Fix join filter
viirya committed Mar 14, 2024
1 parent cc13619 commit 334d7d9
Showing 2 changed files with 154 additions and 3 deletions.
125 changes: 122 additions & 3 deletions core/src/execution/datafusion/planner.rs
@@ -45,7 +45,7 @@ use datafusion::{
},
};
use datafusion_common::{
-    tree_node::{TreeNode, VisitRecursion},
+    tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion},
JoinType as DFJoinType, ScalarValue,
};
use itertools::Itertools;
@@ -904,7 +904,18 @@ impl PhysicalPlanner {

// Handle join filter as DataFusion `JoinFilter` struct
let join_filter = if let Some(expr) = &join.condition {
-    let physical_expr = self.create_expr(expr, left.schema())?;
let left_schema = left.schema();
let right_schema = right.schema();
let left_fields = left_schema.fields();
let right_fields = right_schema.fields();
let all_fields: Vec<_> = left_fields
.into_iter()
.chain(right_fields.into_iter())
.map(|f| f.clone())
.collect();
let full_schema = Arc::new(Schema::new(all_fields));

let physical_expr = self.create_expr(expr, full_schema)?;
let (left_field_indices, right_field_indices) = expr_to_columns(
&physical_expr,
left.schema().fields.len(),
@@ -916,19 +927,34 @@
);

let filter_fields: Vec<Field> = left_field_indices
.clone()
.into_iter()
.map(|i| left.schema().field(i).clone())
.chain(
right_field_indices
.clone()
.into_iter()
.map(|i| right.schema().field(i).clone()),
)
.collect_vec();

let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());

-    Some(JoinFilter::new(
-        physical_expr,
// Rewrite the physical expression to use the new column indices.
// DataFusion's join filter is bound to an intermediate schema that contains
// only the fields used in the filter expression, whereas Spark's join filter
// expression is bound to the full schema, so the column indices in the
// expression must be remapped.
let rewritten_physical_expr = rewrite_physical_expr(
physical_expr,
left_schema.fields.len(),
right_schema.fields.len(),
&left_field_indices,
&right_field_indices,
)?;

Some(JoinFilter::new(
rewritten_physical_expr,
column_indices,
filter_schema,
))
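
For context on the fix (an editorial sketch, not part of the commit): Spark binds the join filter against the concatenation of the left and right schemas, while DataFusion's JoinFilter is bound to an intermediate schema that contains only the columns the filter actually uses. A minimal, self-contained illustration of the two schemas, using the arrow_schema crate (the planner gets these types via DataFusion's arrow re-export) and made-up column names:

use arrow_schema::{DataType, Field, Schema};

fn main() {
    // Hypothetical join: the left plan produces 2 columns, the right plan 2.
    let left_fields = vec![
        Field::new("l_key", DataType::Int32, true),
        Field::new("l_val", DataType::Int32, true),
    ];
    let right_fields = vec![
        Field::new("r_key", DataType::Int32, true),
        Field::new("r_val", DataType::Int32, true),
    ];

    // Spark-style full schema: filter column indices range over 0..4.
    let all_fields: Vec<Field> = left_fields
        .iter()
        .chain(right_fields.iter())
        .cloned()
        .collect();
    let full_schema = Schema::new(all_fields);
    assert_eq!(full_schema.fields().len(), 4);

    // DataFusion-style intermediate schema for a filter that touches only
    // l_val (index 1) and r_val (index 3): just those two columns, reindexed.
    let filter_schema = Schema::new(vec![
        full_schema.field(1).clone(),
        full_schema.field(3).clone(),
    ]);
    assert_eq!(filter_schema.fields().len(), 2);
}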
@@ -1173,6 +1199,99 @@ fn expr_to_columns(
Ok((left_field_indices, right_field_indices))
}
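
The body of expr_to_columns is collapsed in this view. Judging only from its signature and call site, a plausible sketch of such a helper (an assumption, not the committed implementation) walks the expression with the TreeNode visitor API of this DataFusion era and partitions the column indices it finds by join side:

use std::collections::BTreeSet;
use std::sync::Arc;

use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};

fn expr_to_columns_sketch(
    expr: &Arc<dyn PhysicalExpr>,
    left_field_len: usize,
    right_field_len: usize,
) -> datafusion_common::Result<(Vec<usize>, Vec<usize>)> {
    let mut left_field_indices = BTreeSet::new();
    let mut right_field_indices = BTreeSet::new();
    expr.apply(&mut |node| {
        if let Some(column) = node.as_any().downcast_ref::<Column>() {
            if column.index() < left_field_len {
                left_field_indices.insert(column.index());
            } else if column.index() < left_field_len + right_field_len {
                // Right-side indices are stored relative to the right schema,
                // matching how the rewriter below compares them.
                right_field_indices.insert(column.index() - left_field_len);
            }
        }
        Ok(VisitRecursion::Continue)
    })?;
    Ok((
        left_field_indices.into_iter().collect(),
        right_field_indices.into_iter().collect(),
    ))
}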

/// A physical join filter rewriter that rewrites the column indices in the expression
/// to match the filter's intermediate schema. See `rewrite_physical_expr`.
struct JoinFilterRewriter<'a> {
left_field_len: usize,
right_field_len: usize,
left_field_indices: &'a [usize],
right_field_indices: &'a [usize],
}

impl JoinFilterRewriter<'_> {
fn new<'a>(
left_field_len: usize,
right_field_len: usize,
left_field_indices: &'a [usize],
right_field_indices: &'a [usize],
) -> JoinFilterRewriter<'a> {
JoinFilterRewriter {
left_field_len,
right_field_len,
left_field_indices,
right_field_indices,
}
}
}

impl TreeNodeRewriter for JoinFilterRewriter<'_> {
type N = Arc<dyn PhysicalExpr>;

fn mutate(&mut self, node: Self::N) -> datafusion_common::Result<Self::N> {
let new_expr: Arc<dyn PhysicalExpr> =
if let Some(column) = node.as_any().downcast_ref::<Column>() {
if column.index() < self.left_field_len {
// left side
let new_index = self
.left_field_indices
.iter()
.position(|&x| x == column.index())
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Column index {} not found in left field indices",
column.index()
))
})?;
Arc::new(Column::new(column.name(), new_index))
} else if column.index() < self.left_field_len + self.right_field_len {
// right side
let new_index = self
.right_field_indices
.iter()
.position(|&x| x + self.left_field_len == column.index())
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Column index {} not found in right field indices",
column.index()
))
})?;
Arc::new(Column::new(
column.name(),
new_index + self.left_field_indices.len(),
))
} else {
return Err(DataFusionError::Internal(format!(
"Column index {} out of range",
column.index()
)));
}
} else {
node.clone()
};
Ok(new_expr)
}
}
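
The remapping that mutate performs can be checked in isolation. A self-contained sketch of the same index arithmetic (no DataFusion needed), with hypothetical sizes: the left side has 3 fields, the right has 2, and the filter touches left column 2 and global column 4, i.e. right-local column 1:

// Mirrors the index arithmetic in `mutate` above, outside of any expression tree.
fn remap(
    global: usize,
    left_len: usize,
    right_len: usize,
    left_used: &[usize],
    right_used: &[usize],
) -> Option<usize> {
    if global < left_len {
        // Left side: new index = position among the left columns the filter uses.
        left_used.iter().position(|&x| x == global)
    } else if global < left_len + right_len {
        // Right side: find the right-local index, then offset by the number of
        // left columns that made it into the intermediate schema.
        right_used
            .iter()
            .position(|&x| x + left_len == global)
            .map(|p| p + left_used.len())
    } else {
        None // out of range; the rewriter reports an internal error here
    }
}

fn main() {
    let (left_len, right_len) = (3, 2);
    let left_used = [2]; // filter touches left field 2
    let right_used = [1]; // filter touches right field 1, i.e. global index 4
    assert_eq!(remap(2, left_len, right_len, &left_used, &right_used), Some(0));
    assert_eq!(remap(4, left_len, right_len, &left_used, &right_used), Some(1));
}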

/// Rewrites the physical expression to use the new column indices.
/// This is necessary when the physical expression is used in a join filter, as the
/// column indices in the filter's intermediate schema differ from those of the full
/// join schema.
fn rewrite_physical_expr(
expr: Arc<dyn PhysicalExpr>,
left_field_len: usize,
right_field_len: usize,
left_field_indices: &[usize],
right_field_indices: &[usize],
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let mut rewriter = JoinFilterRewriter::new(
left_field_len,
right_field_len,
left_field_indices,
right_field_indices,
);

Ok(expr.rewrite(&mut rewriter)?)
}
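
A hypothetical unit test (not in the commit) sketching the end-to-end effect, assuming it lives next to rewrite_physical_expr in this module: a filter l_val@1 > r_val@3 over a 2 + 2 column join should come back bound to indices 0 and 1 of the intermediate schema.

#[test]
fn rewrites_filter_columns_to_intermediate_schema() -> Result<(), ExecutionError> {
    use datafusion::logical_expr::Operator;
    use datafusion::physical_expr::expressions::{BinaryExpr, Column};

    // `l_val@1 > r_val@3` bound to the full (left ++ right) schema.
    let filter: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("l_val", 1)),
        Operator::Gt,
        Arc::new(Column::new("r_val", 3)),
    ));
    // Left and right each have 2 fields; the filter uses left index 1 and
    // right-local index 1 (global index 3).
    let rewritten = rewrite_physical_expr(filter, 2, 2, &[1], &[1])?;

    let bin = rewritten.as_any().downcast_ref::<BinaryExpr>().unwrap();
    let l = bin.left().as_any().downcast_ref::<Column>().unwrap();
    let r = bin.right().as_any().downcast_ref::<Column>().unwrap();
    assert_eq!((l.index(), r.index()), (0, 1));
    Ok(())
}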

#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
32 changes: 32 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -105,6 +105,38 @@ class CometExecSuite extends CometTestBase {
}
}

test("HashJoin with join filter") {
withSQLConf(
SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
// Inner join: build left
val df1 =
sql(
"SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
"ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
checkSparkAnswerAndOperator(df1)

// Right join: build left
val df2 =
sql(
"SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " +
"ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
checkSparkAnswerAndOperator(df2)

// Full join: build left
val df3 =
sql(
"SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b " +
"ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
checkSparkAnswerAndOperator(df3)
}
}
}
}

test("Fix corrupted AggregateMode when transforming plan parameters") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
