diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 01c79e0c0..29bf30959 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -45,7 +45,7 @@ use datafusion::{
     },
 };
 use datafusion_common::{
-    tree_node::{TreeNode, VisitRecursion},
+    tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion},
     JoinType as DFJoinType, ScalarValue,
 };
 use itertools::Itertools;
@@ -904,7 +904,18 @@ impl PhysicalPlanner {
                 // Handle join filter as DataFusion `JoinFilter` struct
                 let join_filter = if let Some(expr) = &join.condition {
-                    let physical_expr = self.create_expr(expr, left.schema())?;
+                    let left_schema = left.schema();
+                    let right_schema = right.schema();
+                    let left_fields = left_schema.fields();
+                    let right_fields = right_schema.fields();
+                    let all_fields: Vec<_> = left_fields
+                        .into_iter()
+                        .chain(right_fields.into_iter())
+                        .map(|f| f.clone())
+                        .collect();
+                    let full_schema = Arc::new(Schema::new(all_fields));
+
+                    let physical_expr = self.create_expr(expr, full_schema)?;
                     let (left_field_indices, right_field_indices) = expr_to_columns(
                         &physical_expr,
                         left.schema().fields.len(),
@@ -916,10 +927,12 @@ impl PhysicalPlanner {
                     );
                     let filter_fields: Vec<Field> = left_field_indices
+                        .clone()
                         .into_iter()
                         .map(|i| left.schema().field(i).clone())
                         .chain(
                             right_field_indices
+                                .clone()
                                 .into_iter()
                                 .map(|i| right.schema().field(i).clone()),
                         )
@@ -927,8 +940,21 @@ impl PhysicalPlanner {
                     let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());
-                    Some(JoinFilter::new(
+                    // Rewrite the physical expression to use the new column indices.
+                    // DataFusion's join filter is bound to an intermediate schema that contains
+                    // only the fields used in the filter expression, whereas Spark's join filter
+                    // expression is bound to the full schema. We therefore need to rewrite the
+                    // physical expression to use the new column indices.
+                    let rewritten_physical_expr = rewrite_physical_expr(
                         physical_expr,
+                        left_schema.fields.len(),
+                        right_schema.fields.len(),
+                        &left_field_indices,
+                        &right_field_indices,
+                    )?;
+
+                    Some(JoinFilter::new(
+                        rewritten_physical_expr,
                         column_indices,
                         filter_schema,
                     ))
@@ -1173,6 +1199,99 @@ fn expr_to_columns(
     Ok((left_field_indices, right_field_indices))
 }

+/// A physical join filter rewriter that rewrites the column indices in the expression
+/// to use the new column indices. See `rewrite_physical_expr`.
+struct JoinFilterRewriter<'a> {
+    left_field_len: usize,
+    right_field_len: usize,
+    left_field_indices: &'a [usize],
+    right_field_indices: &'a [usize],
+}
+
+impl JoinFilterRewriter<'_> {
+    fn new<'a>(
+        left_field_len: usize,
+        right_field_len: usize,
+        left_field_indices: &'a [usize],
+        right_field_indices: &'a [usize],
+    ) -> JoinFilterRewriter<'a> {
+        JoinFilterRewriter {
+            left_field_len,
+            right_field_len,
+            left_field_indices,
+            right_field_indices,
+        }
+    }
+}
+
+impl TreeNodeRewriter for JoinFilterRewriter<'_> {
+    type N = Arc<dyn PhysicalExpr>;
+
+    fn mutate(&mut self, node: Self::N) -> datafusion_common::Result<Self::N> {
+        let new_expr: Arc<dyn PhysicalExpr> =
+            if let Some(column) = node.as_any().downcast_ref::<Column>() {
+                if column.index() < self.left_field_len {
+                    // left side
+                    let new_index = self
+                        .left_field_indices
+                        .iter()
+                        .position(|&x| x == column.index())
+                        .ok_or_else(|| {
+                            DataFusionError::Internal(format!(
+                                "Column index {} not found in left field indices",
+                                column.index()
+                            ))
+                        })?;
+                    Arc::new(Column::new(column.name(), new_index))
+                } else if column.index() < self.left_field_len + self.right_field_len {
+                    // right side
+                    let new_index = self
+                        .right_field_indices
+                        .iter()
+                        .position(|&x| x + self.left_field_len == column.index())
+                        .ok_or_else(|| {
+                            DataFusionError::Internal(format!(
+                                "Column index {} not found in right field indices",
+                                column.index()
+                            ))
+                        })?;
+                    Arc::new(Column::new(
+                        column.name(),
+                        new_index + self.left_field_indices.len(),
+                    ))
+                } else {
+                    return Err(DataFusionError::Internal(format!(
+                        "Column index {} out of range",
+                        column.index()
+                    )));
+                }
+            } else {
+                node.clone()
+            };
+        Ok(new_expr)
+    }
+}
+
+/// Rewrites the physical expression to use the new column indices.
+/// This is necessary when the physical expression is used in a join filter, as the column
+/// indices in the intermediate filter schema differ from those in the original schema.
+fn rewrite_physical_expr(
+    expr: Arc<dyn PhysicalExpr>,
+    left_field_len: usize,
+    right_field_len: usize,
+    left_field_indices: &[usize],
+    right_field_indices: &[usize],
+) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+    let mut rewriter = JoinFilterRewriter::new(
+        left_field_len,
+        right_field_len,
+        left_field_indices,
+        right_field_indices,
+    );
+
+    Ok(expr.rewrite(&mut rewriter)?)
+}
+
 #[cfg(test)]
 mod tests {
     use std::{sync::Arc, task::Poll};
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index b8237c835..09d3a6263 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -105,6 +105,38 @@ class CometExecSuite extends CometTestBase {
     }
   }

+  test("HashJoin with join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df1)
+
+          // Right join: build left
+          val df2 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df2)
+
+          // Full join: build left
+          val df3 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df3)
+        }
+      }
+    }
+  }
+
   test("Fix corrupted AggregateMode when transforming plan parameters") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
       val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
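
Note on the core of the change: `JoinFilterRewriter` rebinds a column that was resolved against the concatenated (left ++ right) join schema to its position in the intermediate filter schema, which holds only the columns the filter actually references. The following is a minimal, standalone sketch of that index remapping using plain integers instead of DataFusion's `Column` expressions; the function and variable names here are illustrative only and are not part of the patch:

    /// Maps a column index from the concatenated (left ++ right) schema to its position in the
    /// intermediate filter schema, which holds only left_field_indices followed by right_field_indices.
    fn remap_index(
        idx: usize,
        left_field_len: usize,
        right_field_len: usize,
        left_field_indices: &[usize],
        right_field_indices: &[usize],
    ) -> Option<usize> {
        if idx < left_field_len {
            // Left-side column: its new position is its rank among the referenced left columns.
            left_field_indices.iter().position(|&x| x == idx)
        } else if idx < left_field_len + right_field_len {
            // Right-side column: find it among the referenced right columns (stored relative to
            // the right schema), then offset by the number of left columns kept in the filter schema.
            right_field_indices
                .iter()
                .position(|&x| x + left_field_len == idx)
                .map(|p| p + left_field_indices.len())
        } else {
            // Index beyond both inputs: the rewriter in the patch reports this as an internal error.
            None
        }
    }

    fn main() {
        // Example: left has 3 fields, right has 2; the filter references left column 1 and right
        // column 3 (i.e. right-relative index 0). The intermediate schema is [left#1, right#3].
        let (left_len, right_len) = (3, 2);
        let (left_used, right_used) = (vec![1], vec![0]);
        assert_eq!(remap_index(1, left_len, right_len, &left_used, &right_used), Some(0));
        assert_eq!(remap_index(3, left_len, right_len, &left_used, &right_used), Some(1));
        assert_eq!(remap_index(5, left_len, right_len, &left_used, &right_used), None);
    }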