Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move SMJ filtered Left Anti join out of join_partial phase #13111

Merged
merged 3 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{
};
use datafusion::physical_plan::memory::MemoryExec;

use crate::fuzz_cases::join_fuzz::JoinTestType::NljHj;
use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch_with_seed;

Expand Down Expand Up @@ -223,17 +224,14 @@ async fn test_anti_join_1k() {
}

#[tokio::test]
// flaky for HjSmj case, giving 1 rows difference sometimes
// https://github.com/apache/datafusion/issues/11555
#[ignore]
async fn test_anti_join_1k_filtered() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works now

JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftAnti,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::NljHj], false)
.run_test(&[JoinTestType::HjSmj, NljHj], false)
.await
}

Expand Down
245 changes: 227 additions & 18 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,32 @@ fn get_corrected_filter_mask(

Some(corrected_mask.finish())
}
JoinType::LeftAnti => {
for i in 0..row_indices_length {
let last_index =
last_index_for_row(i, row_indices, batch_ids, row_indices_length);

if filter_mask.value(i) {
seen_true = true;
}

if last_index {
if !seen_true {
corrected_mask.append_value(true);
} else {
corrected_mask.append_null();
}

seen_true = false;
} else {
corrected_mask.append_null();
}
}

let null_matched = expected_size - corrected_mask.len();
corrected_mask.extend(vec![Some(true); null_matched]);
Some(corrected_mask.finish())
}
// Only outer joins need to keep track of processed rows and apply corrected filter mask
_ => None,
}
Expand Down Expand Up @@ -835,15 +861,18 @@ impl Stream for SMJStream {
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
)
{
self.freeze_all()?;

if !self.output_record_batches.batches.is_empty()
&& self.buffered_data.scanning_finished()
{
let out_batch = self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(out_batch)));
let out_filtered_batch =
self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(
out_filtered_batch,
)));
}
}

Expand Down Expand Up @@ -907,15 +936,17 @@ impl Stream for SMJStream {
// because target output batch size can be hit in the middle of
// filtering causing the filtering to be incomplete and causing
// correctness issues
let record_batch = if !(self.filter.is_some()
if self.filter.is_some()
&& matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi | JoinType::Right
)) {
record_batch
} else {
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
)
{
continue;
};
}

return Poll::Ready(Some(Ok(record_batch)));
}
Expand All @@ -929,7 +960,10 @@ impl Stream for SMJStream {
if self.filter.is_some()
&& matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi | JoinType::Right
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
)
{
let out = self.filter_joined_batch()?;
Expand Down Expand Up @@ -1273,11 +1307,7 @@ impl SMJStream {
};

if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() {
join_streamed = !self
.streamed_batch
.join_filter_matched_idxs
.contains(&(self.streamed_batch.idx as u64))
&& !self.streamed_joined;
join_streamed = !self.streamed_joined;
join_buffered = join_streamed;
}
}
Expand Down Expand Up @@ -1519,7 +1549,10 @@ impl SMJStream {
// Push the filtered batch which contains rows passing join filter to the output
if matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi | JoinType::Right
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
) {
self.output_record_batches
.batches
Expand Down Expand Up @@ -1654,7 +1687,10 @@ impl SMJStream {
if !(self.filter.is_some()
&& matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi | JoinType::Right
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
))
{
self.output_record_batches.batches.clear();
Expand Down Expand Up @@ -1727,7 +1763,7 @@ impl SMJStream {
&self.schema,
&[filtered_record_batch, null_joined_streamed_batch],
)?;
} else if matches!(self.join_type, JoinType::LeftSemi) {
} else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
let output_column_indices = (0..streamed_columns_length).collect::<Vec<_>>();
filtered_record_batch =
filtered_record_batch.project(&output_column_indices)?;
Expand Down Expand Up @@ -3349,6 +3385,7 @@ mod tests {
batch_ids: vec![],
};

// Insert already prejoined non-filtered rows
batches.batches.push(RecordBatch::try_new(
Arc::clone(&schema),
vec![
Expand Down Expand Up @@ -3835,6 +3872,178 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_left_anti_join_filtered_mask() -> Result<()> {
let mut joined_batches = build_joined_record_batches()?;
let schema = joined_batches.batches.first().unwrap().schema();

let output = concat_batches(&schema, &joined_batches.batches)?;
let out_mask = joined_batches.filter_mask.finish();
let out_indices = joined_batches.row_indices.finish();

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0]),
&[0usize],
&BooleanArray::from(vec![true]),
1
)
.unwrap(),
BooleanArray::from(vec![None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0]),
&[0usize],
&BooleanArray::from(vec![false]),
1
)
.unwrap(),
BooleanArray::from(vec![Some(true)])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0]),
&[0usize; 2],
&BooleanArray::from(vec![true, true]),
2
)
.unwrap(),
BooleanArray::from(vec![None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![true, true, true]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![true, false, true]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![false, false, true]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![false, true, true]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![false, false, false]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, Some(true)])
);

let corrected_mask = get_corrected_filter_mask(
LeftAnti,
&out_indices,
&joined_batches.batch_ids,
&out_mask,
output.num_rows(),
)
.unwrap();

assert_eq!(
corrected_mask,
BooleanArray::from(vec![
None,
None,
None,
None,
None,
Some(true),
None,
Some(true)
])
);

let filtered_rb = filter_record_batch(&output, &corrected_mask)?;

assert_batches_eq!(
&[
"+---+----+---+----+",
"| a | b | x | y |",
"+---+----+---+----+",
"| 1 | 13 | 1 | 12 |",
"| 1 | 14 | 1 | 11 |",
"+---+----+---+----+",
],
&[filtered_rb]
);

// output null rows
let null_mask = arrow::compute::not(&corrected_mask)?;
assert_eq!(
null_mask,
BooleanArray::from(vec![
None,
None,
None,
None,
None,
Some(false),
None,
Some(false),
])
);

let null_joined_batch = filter_record_batch(&output, &null_mask)?;

assert_batches_eq!(
&[
"+---+---+---+---+",
"| a | b | x | y |",
"+---+---+---+---+",
"+---+---+---+---+",
],
&[null_joined_batch]
);
Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
Loading