Skip to content

Commit

Permalink
Fix: Sort Merge Join Left Semi crashes
Browse files Browse the repository at this point in the history
  • Loading branch information
comphead committed May 4, 2024
1 parent 9385e4e commit 0c8222a
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 28 deletions.
35 changes: 7 additions & 28 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,6 @@ impl SMJStream {
self.streamed_state = StreamedState::Exhausted;
}
Poll::Ready(Some(batch)) => {
//dbg!(&batch);
//println!("{:#?}", &batch);
if batch.num_rows() > 0 {
self.freeze_streamed()?;
self.join_metrics.input_batches.add(1);
Expand Down Expand Up @@ -1057,7 +1055,6 @@ impl SMJStream {
Some(self.buffered_data.scanning_batch_idx)
};

//dbg!(self.buffered_data.scanning_idx());
self.streamed_batch
.append_output_pair(scanning_batch_idx, None);
self.output_size += 1;
Expand Down Expand Up @@ -1129,12 +1126,8 @@ impl SMJStream {
// Produces and stages record batch for all output indices found
// for current streamed batch and clears staged output indices.
fn freeze_streamed(&mut self) -> Result<()> {
//dbg!(&self.streamed_batch.batch);

for chunk in self.streamed_batch.output_indices.iter_mut() {
let streamed_indices = chunk.streamed_indices.finish();
dbg!(&streamed_indices);
//let streamed_indices = PrimitiveArray::<UInt64Type>::try_new(vec![0, 1].into(), None)?;

if streamed_indices.is_empty() {
continue;
Expand All @@ -1149,8 +1142,6 @@ impl SMJStream {
.collect::<Result<Vec<_>, ArrowError>>()?;

let buffered_indices: UInt64Array = chunk.buffered_indices.finish();
dbg!(&buffered_indices);
dbg!(&self.join_type);

let mut buffered_columns =
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
Expand Down Expand Up @@ -1178,15 +1169,17 @@ impl SMJStream {
let filter_columns = if chunk.buffered_batch_idx.is_some() {
if matches!(self.join_type, JoinType::Right) {
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
} else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
} else if matches!(
self.join_type,
JoinType::LeftSemi | JoinType::LeftAnti
) {
let buffered_columns = self.buffered_data.batches
[chunk.buffered_batch_idx.unwrap()]
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()?;
//dbg!(&buffered_columns);
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()?;
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
} else {
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
Expand All @@ -1196,10 +1189,6 @@ impl SMJStream {
vec![]
};

dbg!(&streamed_columns);
dbg!(&buffered_columns);
dbg!(&filter_columns);

let columns = if matches!(self.join_type, JoinType::Right) {
buffered_columns.extend(streamed_columns.clone());
buffered_columns
Expand All @@ -1220,25 +1209,18 @@ impl SMJStream {
filter_columns,
)?;

dbg!(&filter_batch);
dbg!(&f.expression());


let filter_result = f
.expression()
.evaluate(&filter_batch)?
.into_array(filter_batch.num_rows())?;

dbg!(&filter_result);
// The selection mask of the filter
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;

// Push the filtered batch to the output
let filtered_batch =
compute::filter_record_batch(&output_batch, mask)?;

dbg!(&filtered_batch);

self.output_record_batches.push(filtered_batch);

// For outer joins, we need to push the null joined rows to the output.
Expand Down Expand Up @@ -1395,9 +1377,6 @@ fn get_filter_column(
.map(|i| buffered_columns[i.index].clone())
.collect::<Vec<_>>();

// dbg!(&left_columns);
// dbg!(&right_columns);

filter_columns.extend(left_columns);
filter_columns.extend(right_columns);
}
Expand Down
117 changes: 117 additions & 0 deletions datafusion/sqllogictest/test_files/sort_merge_join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -263,5 +263,122 @@ DROP TABLE t1;
statement ok
DROP TABLE t2;


# LEFTSEMI join tests

query II
select * from (
with
t1 as (
select 11 a, 12 b union all
select 11 a, 13 b),
t2 as (
select 11 a, 12 b union all
select 11 a, 13 b
)
select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b = t1.b)
) order by 1, 2
----
11 12
11 13

query II
select * from (
with
t1 as (
select 11 a, 12 b union all
select 11 a, 13 b),
t2 as (
select 11 a, 12 b union all
select 11 a, 13 b
)
select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
) order by 1, 2;
----
11 12
11 13

query II
select * from (
with
t1 as (
select null a, 12 b union all
select 11 a, 13 b),
t2 as (
select 11 a, 12 b union all
select 11 a, 13 b
)
select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
) order by 1, 2;
----
11 13

query II
select * from (
with
t1 as (
select 11 a, 12 b union all
select 11 a, 13 b)
select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b = t1.b)
) order by 1, 2;
----
11 12
11 13

query II
select * from (
with
t1 as (
select 11 a, 12 b union all
select 11 a, 13 b)
select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b)
) order by 1, 2;
----
11 12
11 13

query II
select * from (
with
t1 as (
select null a, 12 b union all
select 11 a, 13 b)
select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b)
) order by 1, 2;
----

# This gives a wrong result for now
#query II
#select * from (
#with
#t1 as (
# select 11 a, 12 b union all
# select 11 a, 13 b),
#t2 as (
# select 11 a, 12 b union all
# select 11 a, 14 b
# )
#select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b)
#) order by 1, 2;
#----

#LEFTANTI tests
# returns no rows instead of correct result
#query III
#select * from (
#with
#t1 as (
# select 11 a, 12 b, 1 c union all
# select 11 a, 13 b, 2 c),
#t2 as (
# select 11 a, 12 b, 3 c union all
# select 11 a, 14 b, 4 c
# )
#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c)
#) order by 1, 2;
----
11 12 1
11 13 2

statement ok
set datafusion.optimizer.prefer_hash_join = true;

0 comments on commit 0c8222a

Please sign in to comment.