diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b7fee96bba1c..9931dd15aec8 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -28,7 +28,7 @@ use datafusion::logical_expr::{ }; use datafusion::logical_expr::{ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, - Repartition, WindowFrameBound, WindowFrameUnits, + Repartition, Subquery, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -39,6 +39,7 @@ use datafusion::{ scalar::ScalarValue, }; use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, @@ -61,7 +62,7 @@ use substrait::proto::{ use substrait::proto::{FunctionArgument, SortField}; use datafusion::common::plan_err; -use datafusion::logical_expr::expr::{InList, Sort}; +use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -230,7 +231,8 @@ pub async fn from_substrait_rel( let mut exprs: Vec = vec![]; for e in &p.expressions { let x = - from_substrait_rex(e, input.clone().schema(), extensions).await?; + from_substrait_rex(ctx, e, input.clone().schema(), extensions) + .await?; // if the expression is WindowFunction, wrap in a Window relation // before returning and do not add to list of this Projection's expression list // otherwise, add expression to the Projection's expression list @@ -256,7 +258,8 @@ pub async fn from_substrait_rel( ); if let Some(condition) = filter.condition.as_ref() { let expr = - from_substrait_rex(condition, input.schema(), extensions).await?; + from_substrait_rex(ctx, condition, input.schema(), extensions) + .await?; input.filter(expr.as_ref().clone())?.build() } else { not_impl_err!("Filter without an condition is not valid") @@ -288,7 +291,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let sorts = - from_substrait_sorts(&sort.sorts, input.schema(), extensions).await?; + from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) + .await?; input.sort(sorts)?.build() } else { not_impl_err!("Sort without an input is not valid") @@ -306,7 +310,8 @@ pub async fn from_substrait_rel( 1 => { for e in &agg.groupings[0].grouping_expressions { let x = - from_substrait_rex(e, input.schema(), extensions).await?; + from_substrait_rex(ctx, e, input.schema(), extensions) + .await?; group_expr.push(x.as_ref().clone()); } } @@ -315,8 +320,13 @@ pub async fn from_substrait_rel( for grouping in &agg.groupings { let mut grouping_set = vec![]; for e in &grouping.grouping_expressions { - let x = from_substrait_rex(e, input.schema(), extensions) - .await?; + let x = from_substrait_rex( + ctx, + e, + input.schema(), + extensions, + ) + .await?; grouping_set.push(x.as_ref().clone()); } grouping_sets.push(grouping_set); @@ -334,7 +344,7 @@ pub async fn from_substrait_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(fil, input.schema(), extensions) + from_substrait_rex(ctx, fil, input.schema(), extensions) .await? .as_ref() .clone(), @@ -402,8 +412,8 @@ pub async fn from_substrait_rel( // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = - from_substrait_rex(expr, &in_join_schema, extensions).await?; + let on = from_substrait_rex(ctx, expr, &in_join_schema, extensions) + .await?; // The join expression can contain both equal and non-equal ops. // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. // So we extract each part as follows: @@ -612,14 +622,16 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( + ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { - let expr = from_substrait_rex(s.expr.as_ref().unwrap(), input_schema, extensions) - .await?; + let expr = + from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) + .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { @@ -660,13 +672,14 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( + ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(expr, input_schema, extensions).await?; + let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; expressions.push(expression.as_ref().clone()); } Ok(expressions) @@ -674,6 +687,7 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substriat_func_args( + ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, extensions: &HashMap, @@ -682,7 +696,7 @@ pub async fn from_substriat_func_args( for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -707,7 +721,7 @@ pub async fn from_substrait_agg_func( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => { not_impl_err!("Aggregated function argument non-Value type not supported") @@ -745,6 +759,7 @@ pub async fn from_substrait_agg_func( /// Convert Substrait Rex to DataFusion Expr #[async_recursion] pub async fn from_substrait_rex( + ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, extensions: &HashMap, @@ -755,13 +770,18 @@ pub async fn from_substrait_rex( let substrait_list = s.options.as_ref(); Ok(Arc::new(Expr::InList(InList { expr: Box::new( - from_substrait_rex(substrait_expr, input_schema, extensions) + from_substrait_rex(ctx, substrait_expr, input_schema, extensions) .await? .as_ref() .clone(), ), - list: from_substrait_rex_vec(substrait_list, input_schema, extensions) - .await?, + list: from_substrait_rex_vec( + ctx, + substrait_list, + input_schema, + extensions, + ) + .await?, negated: false, }))) } @@ -779,6 +799,7 @@ pub async fn from_substrait_rex( if if_expr.then.is_none() { expr = Some(Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -793,6 +814,7 @@ pub async fn from_substrait_rex( when_then_expr.push(( Box::new( from_substrait_rex( + ctx, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -803,6 +825,7 @@ pub async fn from_substrait_rex( ), Box::new( from_substrait_rex( + ctx, if_expr.then.as_ref().unwrap(), input_schema, extensions, @@ -816,7 +839,7 @@ pub async fn from_substrait_rex( // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(e, input_schema, extensions) + from_substrait_rex(ctx, e, input_schema, extensions) .await? .as_ref() .clone(), @@ -843,7 +866,7 @@ pub async fn from_substrait_rex( for arg in &f.arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(e, input_schema, extensions).await + from_substrait_rex(ctx, e, input_schema, extensions).await } _ => not_impl_err!( "Aggregated function argument non-Value type not supported" @@ -868,14 +891,14 @@ pub async fn from_substrait_rex( (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { left: Box::new( - from_substrait_rex(l, input_schema, extensions) + from_substrait_rex(ctx, l, input_schema, extensions) .await? .as_ref() .clone(), ), op, right: Box::new( - from_substrait_rex(r, input_schema, extensions) + from_substrait_rex(ctx, r, input_schema, extensions) .await? .as_ref() .clone(), @@ -888,7 +911,7 @@ pub async fn from_substrait_rex( } } ScalarFunctionType::Expr(builder) => { - builder.build(f, input_schema, extensions).await + builder.build(ctx, f, input_schema, extensions).await } } } @@ -900,6 +923,7 @@ pub async fn from_substrait_rex( Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new( Box::new( from_substrait_rex( + ctx, cast.as_ref().input.as_ref().unwrap().as_ref(), input_schema, extensions, @@ -921,7 +945,8 @@ pub async fn from_substrait_rex( ), }; let order_by = - from_substrait_sorts(&window.sorts, input_schema, extensions).await?; + from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) + .await?; // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row @@ -934,12 +959,14 @@ pub async fn from_substrait_rex( Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { fun: fun?.unwrap(), args: from_substriat_func_args( + ctx, &window.arguments, input_schema, extensions, ) .await?, partition_by: from_substrait_rex_vec( + ctx, &window.partitions, input_schema, extensions, @@ -953,6 +980,51 @@ pub async fn from_substrait_rex( }, }))) } + Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + Err(DataFusionError::Substrait( + "InPredicate Subquery type must have exactly one Needle expression" + .to_string(), + )) + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = + from_substrait_rel(ctx, haystack_expr, extensions) + .await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Arc::new(Expr::InSubquery(InSubquery { + expr: Box::new( + from_substrait_rex( + ctx, + needle_expr, + input_schema, + extensions, + ) + .await? + .as_ref() + .clone(), + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + }, + negated: false, + }))) + } else { + substrait_err!("InPredicate Subquery type must have a Haystack expression") + } + } + } + _ => substrait_err!("Subquery type not implemented"), + }, + None => { + substrait_err!("Subquery experssion without SubqueryType is not allowed") + } + }, _ => not_impl_err!("unsupported rex_type"), } } @@ -1312,16 +1384,22 @@ impl BuiltinExprBuilder { pub async fn build( self, + ctx: &SessionContext, f: &ScalarFunction, input_schema: &DFSchema, extensions: &HashMap, ) -> Result> { match self.expr_name.as_str() { - "like" => Self::build_like_expr(false, f, input_schema, extensions).await, - "ilike" => Self::build_like_expr(true, f, input_schema, extensions).await, + "like" => { + Self::build_like_expr(ctx, false, f, input_schema, extensions).await + } + "ilike" => { + Self::build_like_expr(ctx, true, f, input_schema, extensions).await + } "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { - Self::build_unary_expr(&self.expr_name, f, input_schema, extensions).await + Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) + .await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -1330,6 +1408,7 @@ impl BuiltinExprBuilder { } async fn build_unary_expr( + ctx: &SessionContext, fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, @@ -1341,7 +1420,7 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let arg = from_substrait_rex(expr_substrait, input_schema, extensions) + let arg = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); @@ -1365,6 +1444,7 @@ impl BuiltinExprBuilder { } async fn build_like_expr( + ctx: &SessionContext, case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, @@ -1378,22 +1458,23 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let expr = from_substrait_rex(expr_substrait, input_schema, extensions) + let expr = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) .await? .as_ref() .clone(); let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); + let pattern = + from_substrait_rex(ctx, pattern_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let escape_char_expr = - from_substrait_rex(escape_char_substrait, input_schema, extensions) + from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) .await? .as_ref() .clone(); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 50f872544298..926883251a63 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -36,12 +36,13 @@ use datafusion::common::{substrait_err, DFSchemaRef}; use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunctionDefinition, Sort, WindowFunction, + InSubquery, ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ @@ -58,7 +59,8 @@ use substrait::{ window_function::bound::Kind as BoundKind, window_function::Bound, FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - ScalarFunction, SingularOrList, WindowFunction as SubstraitWindowFunction, + ScalarFunction, SingularOrList, Subquery, + WindowFunction as SubstraitWindowFunction, }, extensions::{ self, @@ -167,7 +169,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -181,6 +183,7 @@ pub fn to_substrait_rel( LogicalPlan::Filter(filter) => { let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; let filter_expr = to_substrait_rex( + ctx, &filter.predicate, filter.input.schema(), 0, @@ -214,7 +217,9 @@ pub fn to_substrait_rel( let sort_fields = sort .expr .iter() - .map(|e| substrait_sort_field(e, sort.input.schema(), extension_info)) + .map(|e| { + substrait_sort_field(ctx, e, sort.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -228,6 +233,7 @@ pub fn to_substrait_rel( LogicalPlan::Aggregate(agg) => { let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; let groupings = to_substrait_groupings( + ctx, &agg.group_expr, agg.input.schema(), extension_info, @@ -235,7 +241,9 @@ pub fn to_substrait_rel( let measures = agg .aggr_expr .iter() - .map(|e| to_substrait_agg_measure(e, agg.input.schema(), extension_info)) + .map(|e| { + to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) + }) .collect::>>()?; Ok(Box::new(Rel { @@ -283,6 +291,7 @@ pub fn to_substrait_rel( let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { Some(filter) => Some(to_substrait_rex( + ctx, filter, &Arc::new(in_join_schema), 0, @@ -299,6 +308,7 @@ pub fn to_substrait_rel( Operator::Eq }; let join_on = to_substrait_join_expr( + ctx, &join.on, eq_op, join.left.schema(), @@ -401,6 +411,7 @@ pub fn to_substrait_rel( let mut window_exprs = vec![]; for expr in &window.window_expr { window_exprs.push(to_substrait_rex( + ctx, expr, window.input.schema(), 0, @@ -500,6 +511,7 @@ pub fn to_substrait_rel( } fn to_substrait_join_expr( + ctx: &SessionContext, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, @@ -513,9 +525,10 @@ fn to_substrait_join_expr( let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; // Parse right let r = to_substrait_rex( + ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index @@ -576,6 +589,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { } pub fn parse_flat_grouping_exprs( + ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, extension_info: &mut ( @@ -585,7 +599,7 @@ pub fn parse_flat_grouping_exprs( ) -> Result { let grouping_expressions = exprs .iter() - .map(|e| to_substrait_rex(e, schema, 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) .collect::>>()?; Ok(Grouping { grouping_expressions, @@ -593,6 +607,7 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( + ctx: &SessionContext, exprs: &Vec, schema: &DFSchemaRef, extension_info: &mut ( @@ -608,7 +623,9 @@ pub fn to_substrait_groupings( )), GroupingSet::GroupingSets(sets) => Ok(sets .iter() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?), GroupingSet::Rollup(set) => { let mut sets: Vec> = vec![vec![]]; @@ -618,17 +635,21 @@ pub fn to_substrait_groupings( Ok(sets .iter() .rev() - .map(|set| parse_flat_grouping_exprs(set, schema, extension_info)) + .map(|set| { + parse_flat_grouping_exprs(ctx, set, schema, extension_info) + }) .collect::>>()?) } }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( + ctx, exprs, schema, extension_info, @@ -638,6 +659,7 @@ pub fn to_substrait_groupings( #[allow(deprecated)] pub fn to_substrait_agg_measure( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -650,13 +672,13 @@ pub fn to_substrait_agg_measure( match func_def { AggregateFunctionDefinition::BuiltIn (fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } let function_anchor = _register_function(fun.to_string(), extension_info); Ok(Measure { @@ -674,20 +696,20 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), None => None } }) } AggregateFunctionDefinition::UDF(fun) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } let function_anchor = _register_function(fun.name().to_string(), extension_info); Ok(Measure { @@ -702,7 +724,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), None => None } }) @@ -714,7 +736,7 @@ pub fn to_substrait_agg_measure( } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extension_info) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -726,6 +748,7 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -743,6 +766,7 @@ fn to_substrait_sort_field( }; Ok(SortField { expr: Some(to_substrait_rex( + ctx, sort.expr.deref(), schema, 0, @@ -851,6 +875,7 @@ pub fn make_binary_op_scalar_func( /// * `extension_info` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, @@ -867,10 +892,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) .collect::>>()?; let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -903,6 +928,7 @@ pub fn to_substrait_rex( for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -937,11 +963,11 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -965,11 +991,11 @@ pub fn to_substrait_rex( } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let substrait_low = - to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; let substrait_high = - to_substrait_rex(high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -997,8 +1023,8 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -1013,6 +1039,7 @@ pub fn to_substrait_rex( // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -1025,12 +1052,14 @@ pub fn to_substrait_rex( for (r#if, then) in when_then_expr { ifs.push(IfClause { r#if: Some(to_substrait_rex( + ctx, r#if, schema, col_ref_offset, extension_info, )?), then: Some(to_substrait_rex( + ctx, then, schema, col_ref_offset, @@ -1042,6 +1071,7 @@ pub fn to_substrait_rex( // Parse outer `else` let r#else: Option> = match else_expr { Some(e) => Some(Box::new(to_substrait_rex( + ctx, e, schema, col_ref_offset, @@ -1060,6 +1090,7 @@ pub fn to_substrait_rex( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type)?), input: Some(Box::new(to_substrait_rex( + ctx, expr, schema, col_ref_offset, @@ -1072,7 +1103,7 @@ pub fn to_substrait_rex( } Expr::Literal(value) => to_substrait_literal(value), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } Expr::WindowFunction(WindowFunction { fun, @@ -1088,6 +1119,7 @@ pub fn to_substrait_rex( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( + ctx, arg, schema, col_ref_offset, @@ -1098,12 +1130,12 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1124,6 +1156,7 @@ pub fn to_substrait_rex( escape_char, case_insensitive, }) => make_substrait_like_expr( + ctx, *case_insensitive, *negated, expr, @@ -1133,7 +1166,50 @@ pub fn to_substrait_rex( col_ref_offset, extension_info, ), + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => { + let substrait_expr = + to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + + let subquery_plan = + to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new(Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), + ), + ), + }))), + }; + if *negated { + let function_anchor = + _register_function("not".to_string(), extension_info); + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) + } + } Expr::Not(arg) => to_substrait_unary_scalar_fn( + ctx, "not", arg, schema, @@ -1141,6 +1217,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( + ctx, "is_null", arg, schema, @@ -1148,6 +1225,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_null", arg, schema, @@ -1155,6 +1233,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( + ctx, "is_true", arg, schema, @@ -1162,6 +1241,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( + ctx, "is_false", arg, schema, @@ -1169,6 +1249,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, "is_unknown", arg, schema, @@ -1176,6 +1257,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_true", arg, schema, @@ -1183,6 +1265,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_false", arg, schema, @@ -1190,6 +1273,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( + ctx, "is_not_unknown", arg, schema, @@ -1197,6 +1281,7 @@ pub fn to_substrait_rex( extension_info, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( + ctx, "negative", arg, schema, @@ -1421,6 +1506,7 @@ fn make_substrait_window_function( #[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( + ctx: &SessionContext, ignore_case: bool, negated: bool, expr: &Expr, @@ -1438,8 +1524,8 @@ fn make_substrait_like_expr( } else { _register_function("like".to_string(), extension_info) }; - let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(pattern, schema, col_ref_offset, extension_info)?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; let escape_char = to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; let arguments = vec![ @@ -1669,6 +1755,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { /// Util to generate substrait [RexType::ScalarFunction] with one argument fn to_substrait_unary_scalar_fn( + ctx: &SessionContext, fn_name: &str, arg: &Expr, schema: &DFSchemaRef, @@ -1679,7 +1766,8 @@ fn to_substrait_unary_scalar_fn( ), ) -> Result { let function_anchor = _register_function(fn_name.to_string(), extension_info); - let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, extension_info)?; + let substrait_expr = + to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1880,6 +1968,7 @@ fn try_to_substrait_field_reference( } fn substrait_sort_field( + ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, extension_info: &mut ( @@ -1893,7 +1982,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, 0, extension_info)?; + let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 47eb5a8f73f5..d7327caee43d 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -394,6 +394,24 @@ async fn roundtrip_inlist_4() -> Result<()> { roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await } +#[tokio::test] +async fn roundtrip_inlist_5() -> Result<()> { + // on roundtrip there is an additional projection during TableScan which includes all column of the table, + // using assert_expected_plan here as a workaround + assert_expected_plan( + "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]\ + \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ + \n Subquery:\ + \n Projection: data2.a\ + \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ + \n TableScan: data2 projection=[a, b, c, d, e, f]").await +} + #[tokio::test] async fn roundtrip_cross_join() -> Result<()> { roundtrip("SELECT * FROM data CROSS JOIN data2").await