From f96fc32060f825b117e44e1e48cf868274c55972 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 13 May 2024 20:35:18 -0700 Subject: [PATCH] Fix SortMergeJoin with join filter filtering all rows out --- datafusion/core/tests/sql/joins.rs | 29 +++++++++++++++++++ .../src/joins/sort_merge_join.rs | 4 ++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index f7d5205db0d3..a803ab7ceb36 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -231,3 +231,32 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn test_smj_with_join_filter_fitering_all() -> Result<()> { + let ctx: SessionContext = SessionContext::new(); + + let sql = "set datafusion.optimizer.prefer_hash_join = false;"; + let _ = ctx.sql(sql).await?.collect().await?; + + let sql = "set datafusion.execution.batch_size = 1"; + let _ = ctx.sql(sql).await?.collect().await?; + + let sql = " + select * from ( + with + t1 as ( + select 12 a, 12 b + ), + t2 as ( + select 12 a, 12 b + ) + select t1.* from t1 join t2 on t1.a = t2.b where t1.a > t2.b + ) order by 1, 2; + "; + + let results = ctx.sql(sql).await?.collect().await?; + assert_eq!(results.len(), 0); + + Ok(()) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 4c928a3d2d8d..d4cf6864d7e4 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1323,7 +1323,9 @@ impl SMJStream { // If join filter exists, `self.output_size` is not accurate as we don't know the exact // number of rows in the output record batch. If streamed row joined with buffered rows, // once join filter is applied, the number of output rows may be more than 1. - if record_batch.num_rows() > self.output_size { + // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened + // when the join filter is applied and all rows are filtered out. + if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size { self.output_size = 0; } else { self.output_size -= record_batch.num_rows();