Skip to content

Commit

Permalink
refactor: ScalarVisitor refactor replace_predicate_column (#15439)
Browse files Browse the repository at this point in the history
refactor: ScalarVisitor refactor replace_predicate_x
  • Loading branch information
TCeason authored May 9, 2024
1 parent 3b1b239 commit b232ff1
Showing 1 changed file with 27 additions and 233 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,12 @@ use crate::optimizer::rule::Rule;
use crate::optimizer::rule::TransformResult;
use crate::optimizer::RuleID;
use crate::optimizer::SExpr;
use crate::plans::AggregateFunction;
use crate::plans::BoundColumnRef;
use crate::plans::CastExpr;
use crate::plans::Filter;
use crate::plans::FunctionCall;
use crate::plans::LagLeadFunction;
use crate::plans::LambdaFunc;
use crate::plans::NthValueFunction;
use crate::plans::RelOp;
use crate::plans::Scan;
use crate::plans::UDFCall;
use crate::plans::WindowFunc;
use crate::plans::WindowFuncType;
use crate::plans::WindowOrderBy;
use crate::plans::SubqueryExpr;
use crate::plans::VisitorMut;
use crate::ColumnEntry;
use crate::MetadataRef;
use crate::ScalarExpr;
Expand Down Expand Up @@ -75,10 +67,16 @@ impl RulePushDownFilterScan {
column_entries: &[&ColumnEntry],
replace_view: bool,
) -> Result<ScalarExpr> {
match predicate {
ScalarExpr::BoundColumnRef(column) => {
struct ReplacePredicateColumnVisitor<'a> {
table_entries: &'a [TableEntry],
column_entries: &'a [&'a ColumnEntry],
replace_view: bool,
}

impl<'a> VisitorMut<'a> for ReplacePredicateColumnVisitor<'a> {
fn visit_bound_column_ref(&mut self, column: &mut BoundColumnRef) -> Result<()> {
if let Some(base_column) =
column_entries
self.column_entries
.iter()
.find_map(|column_entry| match column_entry {
ColumnEntry::BaseTableColumn(base_column)
Expand All @@ -89,7 +87,8 @@ impl RulePushDownFilterScan {
_ => None,
})
{
if let Some(table_entry) = table_entries
if let Some(table_entry) = self
.table_entries
.iter()
.find(|table_entry| table_entry.index() == base_column.table_index)
{
Expand All @@ -103,236 +102,31 @@ impl RulePushDownFilterScan {
.database_name(Some(table_entry.database().to_string()))
.table_index(Some(table_entry.index()));

if replace_view {
if self.replace_view {
column_binding_builder = column_binding_builder
.virtual_computed_expr(column.column.virtual_computed_expr.clone());
}

let bound_column_ref = BoundColumnRef {
span: column.span,
column: column_binding_builder.build(),
};
return Ok(ScalarExpr::BoundColumnRef(bound_column_ref));
column.column = column_binding_builder.build();
}
}
Ok(predicate.clone())
Ok(())
}
ScalarExpr::WindowFunction(window) => {
let func = match &window.func {
WindowFuncType::Aggregate(agg) => {
let args = agg
.args
.iter()
.map(|arg| {
Self::replace_predicate_column(
arg,
table_entries,
column_entries,
replace_view,
)
})
.collect::<Result<Vec<ScalarExpr>>>()?;

WindowFuncType::Aggregate(AggregateFunction {
func_name: agg.func_name.clone(),
distinct: agg.distinct,
params: agg.params.clone(),
args,
return_type: agg.return_type.clone(),
display_name: agg.display_name.clone(),
})
}
WindowFuncType::LagLead(ll) => {
let new_arg = Self::replace_predicate_column(
&ll.arg,
table_entries,
column_entries,
replace_view,
)?;
let new_default = match ll.default.clone().map(|d| {
Self::replace_predicate_column(
&d,
table_entries,
column_entries,
replace_view,
)
}) {
None => None,
Some(d) => Some(Box::new(d?)),
};
WindowFuncType::LagLead(LagLeadFunction {
is_lag: ll.is_lag,
arg: Box::new(new_arg),
offset: ll.offset,
default: new_default,
return_type: ll.return_type.clone(),
})
}
WindowFuncType::NthValue(func) => {
let new_arg = Self::replace_predicate_column(
&func.arg,
table_entries,
column_entries,
replace_view,
)?;
WindowFuncType::NthValue(NthValueFunction {
n: func.n,
arg: Box::new(new_arg),
return_type: func.return_type.clone(),
})
}
func => func.clone(),
};

let partition_by = window
.partition_by
.iter()
.map(|arg| {
Self::replace_predicate_column(
arg,
table_entries,
column_entries,
replace_view,
)
})
.collect::<Result<Vec<ScalarExpr>>>()?;

let order_by = window
.order_by
.iter()
.map(|item| {
let replaced_scalar = Self::replace_predicate_column(
&item.expr,
table_entries,
column_entries,
replace_view,
)?;
Ok(WindowOrderBy {
expr: replaced_scalar,
asc: item.asc,
nulls_first: item.nulls_first,
})
})
.collect::<Result<Vec<WindowOrderBy>>>()?;

Ok(ScalarExpr::WindowFunction(WindowFunc {
span: window.span,
display_name: window.display_name.clone(),
func,
partition_by,
order_by,
frame: window.frame.clone(),
}))
fn visit_subquery_expr(&mut self, _subquery: &'a mut SubqueryExpr) -> Result<()> {
Ok(())
}
ScalarExpr::AggregateFunction(agg_func) => {
let args = agg_func
.args
.iter()
.map(|arg| {
Self::replace_predicate_column(
arg,
table_entries,
column_entries,
replace_view,
)
})
.collect::<Result<Vec<ScalarExpr>>>()?;

Ok(ScalarExpr::AggregateFunction(AggregateFunction {
func_name: agg_func.func_name.clone(),
distinct: agg_func.distinct,
params: agg_func.params.clone(),
args,
return_type: agg_func.return_type.clone(),
display_name: agg_func.display_name.clone(),
}))
}
ScalarExpr::LambdaFunction(lambda_func) => {
let args = lambda_func
.args
.iter()
.map(|arg| {
Self::replace_predicate_column(
arg,
table_entries,
column_entries,
replace_view,
)
})
.collect::<Result<Vec<ScalarExpr>>>()?;

Ok(ScalarExpr::LambdaFunction(LambdaFunc {
span: lambda_func.span,
func_name: lambda_func.func_name.clone(),
args,
lambda_expr: lambda_func.lambda_expr.clone(),
lambda_display: lambda_func.lambda_display.clone(),
return_type: lambda_func.return_type.clone(),
}))
}
ScalarExpr::FunctionCall(func) => {
let arguments = func
.arguments
.iter()
.map(|arg| {
Self::replace_predicate_column(
arg,
table_entries,
column_entries,
replace_view,
)
})
.collect::<Result<Vec<ScalarExpr>>>()?;

Ok(ScalarExpr::FunctionCall(FunctionCall {
span: func.span,
params: func.params.clone(),
arguments,
func_name: func.func_name.clone(),
}))
}
ScalarExpr::CastExpr(cast) => {
let arg = Self::replace_predicate_column(
&cast.argument,
table_entries,
column_entries,
replace_view,
)?;
Ok(ScalarExpr::CastExpr(CastExpr {
span: cast.span,
is_try: cast.is_try,
argument: Box::new(arg),
target_type: cast.target_type.clone(),
}))
}
ScalarExpr::UDFCall(udf) => {
let arguments = udf
.arguments
.iter()
.map(|arg| {
Self::replace_predicate_column(
arg,
table_entries,
column_entries,
replace_view,
)
})
.collect::<Result<Vec<ScalarExpr>>>()?;
}

Ok(ScalarExpr::UDFCall(UDFCall {
span: udf.span,
name: udf.name.clone(),
func_name: udf.func_name.clone(),
display_name: udf.display_name.clone(),
udf_type: udf.udf_type.clone(),
arg_types: udf.arg_types.clone(),
return_type: udf.return_type.clone(),
arguments,
}))
}
let mut visitor = ReplacePredicateColumnVisitor {
table_entries,
column_entries,
replace_view,
};
let mut predicate = predicate.clone();
visitor.visit(&mut predicate)?;

_ => Ok(predicate.clone()),
}
Ok(predicate.clone())
}

fn find_push_down_predicates(
Expand Down

0 comments on commit b232ff1

Please sign in to comment.