Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Produce buffered null join row only if all joined rows are failed on join filter in SMJ full join #12090

Merged
merged 8 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

use std::any::Any;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::collections::{HashMap, VecDeque};
use std::fmt::Formatter;
use std::fs::File;
use std::io::BufReader;
Expand Down Expand Up @@ -595,8 +595,10 @@ struct BufferedBatch {
/// Size estimation used for reserving / releasing memory
pub size_estimation: usize,
/// The indices of buffered batch that failed the join filter.
/// This is a map between a buffered row index and a boolean value indicating whether all joined rows
/// of that buffered row failed the join filter.
/// When dequeuing the buffered batch, we need to produce null joined rows for these indices.
pub join_filter_failed_idxs: HashSet<u64>,
pub join_filter_failed_map: HashMap<u64, bool>,
viirya marked this conversation as resolved.
Show resolved Hide resolved
/// Current buffered batch number of rows. Equal to batch.num_rows()
/// but if batch is spilled to disk this property is preferable
/// and less expensive
Expand Down Expand Up @@ -637,7 +639,7 @@ impl BufferedBatch {
join_arrays,
null_joined: vec![],
size_estimation,
join_filter_failed_idxs: HashSet::new(),
join_filter_failed_map: HashMap::new(),
num_rows,
spill_file: None,
}
Expand Down Expand Up @@ -1229,11 +1231,19 @@ impl SMJStream {
}
buffered_batch.null_joined.clear();

// For buffered rows which are joined with streamed side but doesn't satisfy the join filter
// For a buffered row that is joined with streamed side rows but where none of the
// joined rows satisfy the join filter
if output_not_matched_filter {
let not_matched_buffered_indices = buffered_batch
.join_filter_failed_map
.iter()
.filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None })
.collect::<Vec<_>>();

let buffered_indices = UInt64Array::from_iter_values(
buffered_batch.join_filter_failed_idxs.iter().copied(),
not_matched_buffered_indices.iter().copied(),
);

if let Some(record_batch) = produce_buffered_null_batch(
&self.schema,
&self.streamed_schema,
Expand All @@ -1242,7 +1252,7 @@ impl SMJStream {
)? {
self.output_record_batches.push(record_batch);
}
buffered_batch.join_filter_failed_idxs.clear();
buffered_batch.join_filter_failed_map.clear();
}
}
Ok(())
Expand Down Expand Up @@ -1459,24 +1469,20 @@ impl SMJStream {
// If it is joined with streamed side, but doesn't match the join filter,
// we need to output it with nulls as streamed side.
if matches!(self.join_type, JoinType::Full) {
let buffered_batch = &mut self.buffered_data.batches
[chunk.buffered_batch_idx.unwrap()];

for i in 0..pre_mask.len() {
let buffered_batch = &mut self.buffered_data.batches
[chunk.buffered_batch_idx.unwrap()];
let buffered_index = buffered_indices.value(i);

if !pre_mask.value(i) {
// For a buffered row that is joined with streamed side but doesn't satisfy the join filter,
buffered_batch
.join_filter_failed_idxs
.insert(buffered_index);
} else if buffered_batch
.join_filter_failed_idxs
.contains(&buffered_index)
{
buffered_batch
.join_filter_failed_idxs
.remove(&buffered_index);
}
buffered_batch.join_filter_failed_map.insert(
buffered_index,
*buffered_batch
.join_filter_failed_map
.get(&buffered_index)
.unwrap_or(&true)
&& !pre_mask.value(i),
);
}
}
}
Expand Down
27 changes: 26 additions & 1 deletion datafusion/sqllogictest/test_files/sort_merge_join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ Alice 100 NULL NULL
Alice 50 Alice 2
Bob 1 NULL NULL
NULL NULL Alice 1
NULL NULL Alice 2
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is incorrect. Alice 2 was joined with Alice 50 above, so we should not produce a null joined row for it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double-checked the joins; the result looks correct.


query TITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50
Expand Down Expand Up @@ -624,6 +623,32 @@ NULL NULL 7 9
NULL NULL 8 10
NULL NULL 9 11

query IIII
select * from (
with t as (
select id_a id_a_1, id_a % 5 id_a_2 from (select unnest(make_array(5, 6, 7, 8, 9, 0, 1, 2, 3, 4)) id_a)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order of the values in the array is important. With the order 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, the test would not reproduce the bug.

), t1 as (
select id_b % 10 id_b_1, id_b + 2 id_b_2 from (select unnest(make_array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) id_b)
)
select * from t full join t1 on t.id_a_2 = t1.id_b_1 and t.id_a_1 > t1.id_b_2
) order by 1, 2, 3, 4
----
0 0 NULL NULL
1 1 NULL NULL
2 2 NULL NULL
3 3 NULL NULL
4 4 NULL NULL
5 0 0 2
6 1 1 3
7 2 2 4
8 3 3 5
9 4 4 6
NULL NULL 5 7
NULL NULL 6 8
NULL NULL 7 9
NULL NULL 8 10
NULL NULL 9 11

# return sql params back to default values
statement ok
set datafusion.optimizer.prefer_hash_join = true;
Expand Down