diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 19180fd49..c8869c5f3 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -92,6 +92,14 @@ type PhyAggResult = Result>, ExecutionError>; type PhyExprResult = Result, String)>, ExecutionError>; type PartitionPhyExprResult = Result>, ExecutionError>; +struct JoinParameters { + pub left: Arc, + pub right: Arc, + pub join_on: Vec<(Arc, Arc)>, + pub join_filter: Option, + pub join_type: DFJoinType, +} + pub const TEST_EXEC_CONTEXT_ID: i64 = -1; /// The query planner for converting Spark query plans to DataFusion query plans. @@ -873,50 +881,22 @@ impl PhysicalPlanner { )) } OpStruct::SortMergeJoin(join) => { - assert!(children.len() == 2); - let (mut left_scans, left) = self.create_plan(&children[0], inputs)?; - let (mut right_scans, right) = self.create_plan(&children[1], inputs)?; - - left_scans.append(&mut right_scans); - - let left_join_exprs = join - .left_join_keys - .iter() - .map(|expr| self.create_expr(expr, left.schema())) - .collect::, _>>()?; - let right_join_exprs = join - .right_join_keys - .iter() - .map(|expr| self.create_expr(expr, right.schema())) - .collect::, _>>()?; - - let join_on = left_join_exprs - .into_iter() - .zip(right_join_exprs) - .collect::>(); - - let join_type = match join.join_type.try_into() { - Ok(JoinType::Inner) => DFJoinType::Inner, - Ok(JoinType::LeftOuter) => DFJoinType::Left, - Ok(JoinType::RightOuter) => DFJoinType::Right, - Ok(JoinType::FullOuter) => DFJoinType::Full, - Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi, - Ok(JoinType::RightSemi) => DFJoinType::RightSemi, - Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti, - Ok(JoinType::RightAnti) => DFJoinType::RightAnti, - Err(_) => { - return Err(ExecutionError::GeneralError(format!( - "Unsupported join type: {:?}", - join.join_type - ))); - } - }; + let (join_params, scans) = self.parse_join_parameters( + inputs, + children, + &join.left_join_keys, + &join.right_join_keys, + join.join_type, + &None, + )?; let sort_options = join .sort_options .iter() .map(|sort_option| { - let sort_expr = self.create_sort_expr(sort_option, left.schema()).unwrap(); + let sort_expr = self + .create_sort_expr(sort_option, join_params.left.schema()) + .unwrap(); SortOptions { descending: sort_expr.options.descending, nulls_first: sort_expr.options.nulls_first, @@ -924,163 +904,173 @@ impl PhysicalPlanner { }) .collect(); - // DataFusion `SortMergeJoinExec` operator keeps the input batch internally. We need - // to copy the input batch to avoid the data corruption from reusing the input - // batch. - let left = if can_reuse_input_batch(&left) { - Arc::new(CopyExec::new(left)) - } else { - left - }; - - let right = if can_reuse_input_batch(&right) { - Arc::new(CopyExec::new(right)) - } else { - right - }; - let join = Arc::new(SortMergeJoinExec::try_new( - left, - right, - join_on, - None, - join_type, + join_params.left, + join_params.right, + join_params.join_on, + join_params.join_filter, + join_params.join_type, sort_options, // null doesn't equal to null in Spark join key. If the join key is // `EqualNullSafe`, Spark will rewrite it during planning. false, )?); - Ok((left_scans, join)) + Ok((scans, join)) } OpStruct::HashJoin(join) => { - assert!(children.len() == 2); - let (mut left_scans, left) = self.create_plan(&children[0], inputs)?; - let (mut right_scans, right) = self.create_plan(&children[1], inputs)?; - - left_scans.append(&mut right_scans); + let (join_params, scans) = self.parse_join_parameters( + inputs, + children, + &join.left_join_keys, + &join.right_join_keys, + join.join_type, + &join.condition, + )?; + let join = Arc::new(HashJoinExec::try_new( + join_params.left, + join_params.right, + join_params.join_on, + join_params.join_filter, + &join_params.join_type, + PartitionMode::Partitioned, + // null doesn't equal to null in Spark join key. If the join key is + // `EqualNullSafe`, Spark will rewrite it during planning. + false, + )?); + Ok((scans, join)) + } + } + } - let left_join_exprs: Vec<_> = join - .left_join_keys - .iter() - .map(|expr| self.create_expr(expr, left.schema())) - .collect::, _>>()?; - let right_join_exprs: Vec<_> = join - .right_join_keys - .iter() - .map(|expr| self.create_expr(expr, right.schema())) - .collect::, _>>()?; + fn parse_join_parameters( + &self, + inputs: &mut Vec>, + children: &[Operator], + left_join_keys: &[Expr], + right_join_keys: &[Expr], + join_type: i32, + condition: &Option, + ) -> Result<(JoinParameters, Vec), ExecutionError> { + assert!(children.len() == 2); + let (mut left_scans, left) = self.create_plan(&children[0], inputs)?; + let (mut right_scans, right) = self.create_plan(&children[1], inputs)?; + + left_scans.append(&mut right_scans); + + let left_join_exprs: Vec<_> = left_join_keys + .iter() + .map(|expr| self.create_expr(expr, left.schema())) + .collect::, _>>()?; + let right_join_exprs: Vec<_> = right_join_keys + .iter() + .map(|expr| self.create_expr(expr, right.schema())) + .collect::, _>>()?; - let join_on = left_join_exprs - .into_iter() - .zip(right_join_exprs) - .collect::>(); + let join_on = left_join_exprs + .into_iter() + .zip(right_join_exprs) + .collect::>(); - let join_type = match join.join_type.try_into() { - Ok(JoinType::Inner) => DFJoinType::Inner, - Ok(JoinType::LeftOuter) => DFJoinType::Left, - Ok(JoinType::RightOuter) => DFJoinType::Right, - Ok(JoinType::FullOuter) => DFJoinType::Full, - Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi, - Ok(JoinType::RightSemi) => DFJoinType::RightSemi, - Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti, - Ok(JoinType::RightAnti) => DFJoinType::RightAnti, - Err(_) => { - return Err(ExecutionError::GeneralError(format!( - "Unsupported join type: {:?}", - join.join_type - ))); - } - }; + let join_type = match join_type.try_into() { + Ok(JoinType::Inner) => DFJoinType::Inner, + Ok(JoinType::LeftOuter) => DFJoinType::Left, + Ok(JoinType::RightOuter) => DFJoinType::Right, + Ok(JoinType::FullOuter) => DFJoinType::Full, + Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi, + Ok(JoinType::RightSemi) => DFJoinType::RightSemi, + Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti, + Ok(JoinType::RightAnti) => DFJoinType::RightAnti, + Err(_) => { + return Err(ExecutionError::GeneralError(format!( + "Unsupported join type: {:?}", + join_type + ))); + } + }; - // Handle join filter as DataFusion `JoinFilter` struct - let join_filter = if let Some(expr) = &join.condition { - let left_schema = left.schema(); - let right_schema = right.schema(); - let left_fields = left_schema.fields(); - let right_fields = right_schema.fields(); - let all_fields: Vec<_> = left_fields - .into_iter() - .chain(right_fields) - .cloned() - .collect(); - let full_schema = Arc::new(Schema::new(all_fields)); - - let physical_expr = self.create_expr(expr, full_schema)?; - let (left_field_indices, right_field_indices) = expr_to_columns( - &physical_expr, - left.schema().fields.len(), - right.schema().fields.len(), - )?; - let column_indices = JoinFilter::build_column_indices( - left_field_indices.clone(), - right_field_indices.clone(), - ); - - let filter_fields: Vec = left_field_indices + // Handle join filter as DataFusion `JoinFilter` struct + let join_filter = if let Some(expr) = condition { + let left_schema = left.schema(); + let right_schema = right.schema(); + let left_fields = left_schema.fields(); + let right_fields = right_schema.fields(); + let all_fields: Vec<_> = left_fields + .into_iter() + .chain(right_fields) + .cloned() + .collect(); + let full_schema = Arc::new(Schema::new(all_fields)); + + let physical_expr = self.create_expr(expr, full_schema)?; + let (left_field_indices, right_field_indices) = + expr_to_columns(&physical_expr, left_fields.len(), right_fields.len())?; + let column_indices = JoinFilter::build_column_indices( + left_field_indices.clone(), + right_field_indices.clone(), + ); + + let filter_fields: Vec = left_field_indices + .clone() + .into_iter() + .map(|i| left.schema().field(i).clone()) + .chain( + right_field_indices .clone() .into_iter() - .map(|i| left.schema().field(i).clone()) - .chain( - right_field_indices - .clone() - .into_iter() - .map(|i| right.schema().field(i).clone()), - ) - .collect_vec(); - - let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new()); - - // Rewrite the physical expression to use the new column indices. - // DataFusion's join filter is bound to intermediate schema which contains - // only the fields used in the filter expression. But the Spark's join filter - // expression is bound to the full schema. We need to rewrite the physical - // expression to use the new column indices. - let rewritten_physical_expr = rewrite_physical_expr( - physical_expr, - left_schema.fields.len(), - right_schema.fields.len(), - &left_field_indices, - &right_field_indices, - )?; - - Some(JoinFilter::new( - rewritten_physical_expr, - column_indices, - filter_schema, - )) - } else { - None - }; - - // DataFusion `HashJoinExec` operator keeps the input batch internally. We need - // to copy the input batch to avoid the data corruption from reusing the input - // batch. - let left = if can_reuse_input_batch(&left) { - Arc::new(CopyExec::new(left)) - } else { - left - }; + .map(|i| right.schema().field(i).clone()), + ) + .collect_vec(); + + let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new()); + + // Rewrite the physical expression to use the new column indices. + // DataFusion's join filter is bound to intermediate schema which contains + // only the fields used in the filter expression. But the Spark's join filter + // expression is bound to the full schema. We need to rewrite the physical + // expression to use the new column indices. + let rewritten_physical_expr = rewrite_physical_expr( + physical_expr, + left_schema.fields.len(), + right_schema.fields.len(), + &left_field_indices, + &right_field_indices, + )?; + + Some(JoinFilter::new( + rewritten_physical_expr, + column_indices, + filter_schema, + )) + } else { + None + }; - let right = if can_reuse_input_batch(&right) { - Arc::new(CopyExec::new(right)) - } else { - right - }; + // DataFusion Join operators keep the input batch internally. We need + // to copy the input batch to avoid the data corruption from reusing the input + // batch. + let left = if can_reuse_input_batch(&left) { + Arc::new(CopyExec::new(left)) + } else { + left + }; - let join = Arc::new(HashJoinExec::try_new( - left, - right, - join_on, - join_filter, - &join_type, - PartitionMode::Partitioned, - false, - )?); + let right = if can_reuse_input_batch(&right) { + Arc::new(CopyExec::new(right)) + } else { + right + }; - Ok((left_scans, join)) - } - } + Ok(( + JoinParameters { + left, + right, + join_on, + join_type, + join_filter, + }, + left_scans, + )) } /// Create a DataFusion physical aggregate expression from Spark physical aggregate expression