Skip to content

Commit

Permalink
Add function to parse join parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 20, 2024
1 parent 85afec8 commit 64240fa
Showing 1 changed file with 168 additions and 178 deletions.
346 changes: 168 additions & 178 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>, ExecutionError>;
/// Result of converting a list of Spark expressions into named DataFusion
/// physical expressions (expression plus its output name), or a planning error.
type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>, ExecutionError>;
/// Result of converting partitioning expressions into DataFusion physical
/// expressions (no output names), or a planning error.
type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError>;

/// Common parameters shared by the join operators (`SortMergeJoinExec` and
/// `HashJoinExec`), produced by `parse_join_parameters` from a Spark join plan.
struct JoinParameters {
/// Physical plan of the left (build/streamed) join input.
pub left: Arc<dyn ExecutionPlan>,
/// Physical plan of the right join input.
pub right: Arc<dyn ExecutionPlan>,
/// Equi-join key pairs: each tuple is (left key expr, right key expr),
/// bound to the left and right input schemas respectively.
pub join_on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
/// Optional non-equi join condition, rewritten to DataFusion's
/// intermediate filter schema; `None` when the join has no extra condition.
pub join_filter: Option<JoinFilter>,
/// DataFusion join type translated from the Spark join type.
pub join_type: DFJoinType,
}

pub const TEST_EXEC_CONTEXT_ID: i64 = -1;

/// The query planner for converting Spark query plans to DataFusion query plans.
Expand Down Expand Up @@ -873,214 +881,196 @@ impl PhysicalPlanner {
))
}
OpStruct::SortMergeJoin(join) => {
assert!(children.len() == 2);
let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;

left_scans.append(&mut right_scans);

let left_join_exprs = join
.left_join_keys
.iter()
.map(|expr| self.create_expr(expr, left.schema()))
.collect::<Result<Vec<_>, _>>()?;
let right_join_exprs = join
.right_join_keys
.iter()
.map(|expr| self.create_expr(expr, right.schema()))
.collect::<Result<Vec<_>, _>>()?;

let join_on = left_join_exprs
.into_iter()
.zip(right_join_exprs)
.collect::<Vec<_>>();

let join_type = match join.join_type.try_into() {
Ok(JoinType::Inner) => DFJoinType::Inner,
Ok(JoinType::LeftOuter) => DFJoinType::Left,
Ok(JoinType::RightOuter) => DFJoinType::Right,
Ok(JoinType::FullOuter) => DFJoinType::Full,
Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
Err(_) => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported join type: {:?}",
join.join_type
)));
}
};
let (join_params, scans) = self.parse_join_parameters(
inputs,
children,
&join.left_join_keys,
&join.right_join_keys,
join.join_type,
&None,
)?;

let sort_options = join
.sort_options
.iter()
.map(|sort_option| {
let sort_expr = self.create_sort_expr(sort_option, left.schema()).unwrap();
let sort_expr = self
.create_sort_expr(sort_option, join_params.left.schema())
.unwrap();
SortOptions {
descending: sort_expr.options.descending,
nulls_first: sort_expr.options.nulls_first,
}
})
.collect();

// DataFusion `SortMergeJoinExec` operator keeps the input batch internally. We need
// to copy the input batch to avoid the data corruption from reusing the input
// batch.
let left = if can_reuse_input_batch(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};

let right = if can_reuse_input_batch(&right) {
Arc::new(CopyExec::new(right))
} else {
right
};

let join = Arc::new(SortMergeJoinExec::try_new(
left,
right,
join_on,
None,
join_type,
join_params.left,
join_params.right,
join_params.join_on,
join_params.join_filter,
join_params.join_type,
sort_options,
// null doesn't equal to null in Spark join key. If the join key is
// `EqualNullSafe`, Spark will rewrite it during planning.
false,
)?);

Ok((left_scans, join))
Ok((scans, join))
}
OpStruct::HashJoin(join) => {
assert!(children.len() == 2);
let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;

left_scans.append(&mut right_scans);
let (join_params, scans) = self.parse_join_parameters(
inputs,
children,
&join.left_join_keys,
&join.right_join_keys,
join.join_type,
&join.condition,
)?;
let join = Arc::new(HashJoinExec::try_new(
join_params.left,
join_params.right,
join_params.join_on,
join_params.join_filter,
&join_params.join_type,
PartitionMode::Partitioned,
// null doesn't equal to null in Spark join key. If the join key is
// `EqualNullSafe`, Spark will rewrite it during planning.
false,
)?);
Ok((scans, join))
}
}
}

let left_join_exprs: Vec<_> = join
.left_join_keys
.iter()
.map(|expr| self.create_expr(expr, left.schema()))
.collect::<Result<Vec<_>, _>>()?;
let right_join_exprs: Vec<_> = join
.right_join_keys
.iter()
.map(|expr| self.create_expr(expr, right.schema()))
.collect::<Result<Vec<_>, _>>()?;
fn parse_join_parameters(
&self,
inputs: &mut Vec<Arc<GlobalRef>>,
children: &[Operator],
left_join_keys: &[Expr],
right_join_keys: &[Expr],
join_type: i32,
condition: &Option<Expr>,
) -> Result<(JoinParameters, Vec<ScanExec>), ExecutionError> {
assert!(children.len() == 2);
let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;

left_scans.append(&mut right_scans);

let left_join_exprs: Vec<_> = left_join_keys
.iter()
.map(|expr| self.create_expr(expr, left.schema()))
.collect::<Result<Vec<_>, _>>()?;
let right_join_exprs: Vec<_> = right_join_keys
.iter()
.map(|expr| self.create_expr(expr, right.schema()))
.collect::<Result<Vec<_>, _>>()?;

let join_on = left_join_exprs
.into_iter()
.zip(right_join_exprs)
.collect::<Vec<_>>();
let join_on = left_join_exprs
.into_iter()
.zip(right_join_exprs)
.collect::<Vec<_>>();

let join_type = match join.join_type.try_into() {
Ok(JoinType::Inner) => DFJoinType::Inner,
Ok(JoinType::LeftOuter) => DFJoinType::Left,
Ok(JoinType::RightOuter) => DFJoinType::Right,
Ok(JoinType::FullOuter) => DFJoinType::Full,
Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
Err(_) => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported join type: {:?}",
join.join_type
)));
}
};
let join_type = match join_type.try_into() {
Ok(JoinType::Inner) => DFJoinType::Inner,
Ok(JoinType::LeftOuter) => DFJoinType::Left,
Ok(JoinType::RightOuter) => DFJoinType::Right,
Ok(JoinType::FullOuter) => DFJoinType::Full,
Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
Err(_) => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported join type: {:?}",
join_type
)));
}
};

// Handle join filter as DataFusion `JoinFilter` struct
let join_filter = if let Some(expr) = &join.condition {
let left_schema = left.schema();
let right_schema = right.schema();
let left_fields = left_schema.fields();
let right_fields = right_schema.fields();
let all_fields: Vec<_> = left_fields
.into_iter()
.chain(right_fields)
.cloned()
.collect();
let full_schema = Arc::new(Schema::new(all_fields));

let physical_expr = self.create_expr(expr, full_schema)?;
let (left_field_indices, right_field_indices) = expr_to_columns(
&physical_expr,
left.schema().fields.len(),
right.schema().fields.len(),
)?;
let column_indices = JoinFilter::build_column_indices(
left_field_indices.clone(),
right_field_indices.clone(),
);

let filter_fields: Vec<Field> = left_field_indices
// Handle join filter as DataFusion `JoinFilter` struct
let join_filter = if let Some(expr) = condition {
let left_schema = left.schema();
let right_schema = right.schema();
let left_fields = left_schema.fields();
let right_fields = right_schema.fields();
let all_fields: Vec<_> = left_fields
.into_iter()
.chain(right_fields)
.cloned()
.collect();
let full_schema = Arc::new(Schema::new(all_fields));

let physical_expr = self.create_expr(expr, full_schema)?;
let (left_field_indices, right_field_indices) =
expr_to_columns(&physical_expr, left_fields.len(), right_fields.len())?;
let column_indices = JoinFilter::build_column_indices(
left_field_indices.clone(),
right_field_indices.clone(),
);

let filter_fields: Vec<Field> = left_field_indices
.clone()
.into_iter()
.map(|i| left.schema().field(i).clone())
.chain(
right_field_indices
.clone()
.into_iter()
.map(|i| left.schema().field(i).clone())
.chain(
right_field_indices
.clone()
.into_iter()
.map(|i| right.schema().field(i).clone()),
)
.collect_vec();

let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());

// Rewrite the physical expression to use the new column indices.
// DataFusion's join filter is bound to intermediate schema which contains
// only the fields used in the filter expression. But the Spark's join filter
// expression is bound to the full schema. We need to rewrite the physical
// expression to use the new column indices.
let rewritten_physical_expr = rewrite_physical_expr(
physical_expr,
left_schema.fields.len(),
right_schema.fields.len(),
&left_field_indices,
&right_field_indices,
)?;

Some(JoinFilter::new(
rewritten_physical_expr,
column_indices,
filter_schema,
))
} else {
None
};

// DataFusion `HashJoinExec` operator keeps the input batch internally. We need
// to copy the input batch to avoid the data corruption from reusing the input
// batch.
let left = if can_reuse_input_batch(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};
.map(|i| right.schema().field(i).clone()),
)
.collect_vec();

let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());

// Rewrite the physical expression to use the new column indices.
// DataFusion's join filter is bound to intermediate schema which contains
// only the fields used in the filter expression. But the Spark's join filter
// expression is bound to the full schema. We need to rewrite the physical
// expression to use the new column indices.
let rewritten_physical_expr = rewrite_physical_expr(
physical_expr,
left_schema.fields.len(),
right_schema.fields.len(),
&left_field_indices,
&right_field_indices,
)?;

Some(JoinFilter::new(
rewritten_physical_expr,
column_indices,
filter_schema,
))
} else {
None
};

let right = if can_reuse_input_batch(&right) {
Arc::new(CopyExec::new(right))
} else {
right
};
// DataFusion Join operators keep the input batch internally. We need
// to copy the input batch to avoid the data corruption from reusing the input
// batch.
let left = if can_reuse_input_batch(&left) {
Arc::new(CopyExec::new(left))
} else {
left
};

let join = Arc::new(HashJoinExec::try_new(
left,
right,
join_on,
join_filter,
&join_type,
PartitionMode::Partitioned,
false,
)?);
let right = if can_reuse_input_batch(&right) {
Arc::new(CopyExec::new(right))
} else {
right
};

Ok((left_scans, join))
}
}
Ok((
JoinParameters {
left,
right,
join_on,
join_type,
join_filter,
},
left_scans,
))
}

/// Create a DataFusion physical aggregate expression from Spark physical aggregate expression
Expand Down

0 comments on commit 64240fa

Please sign in to comment.