Skip to content

Commit

Permalink
Fix: Sort Merge Join Left Semi crashes. Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
comphead committed May 9, 2024
1 parent 2ab98e8 commit 8e9b490
Showing 1 changed file with 107 additions and 30 deletions.
137 changes: 107 additions & 30 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1207,37 +1207,11 @@ impl SMJStream {
// The selection mask of the filter
let mut mask =
datafusion_common::cast::as_boolean_array(&filter_result)?;
// For a LeftSemi join the filter mask must be calculated in its own way:
// once we find at least one matching row for a specific streamed key/filter, we don't need to check the others for the same key/filter
let mut maybe_left_semi_mask: Option<BooleanArray> = None;
if matches!(self.join_type, JoinType::LeftSemi) {
// did we get a filter match for a streaming index
let mut seen_as_true: bool = false;
let streamed_indices_length = streamed_indices.len();
let mut corrected_mask: Vec<bool> =
vec![false; streamed_indices_length];

#[allow(clippy::needless_range_loop)]
for i in 0..streamed_indices_length {
// if for a streaming index its a match first time, set it as true
if mask.value(i) && !seen_as_true {
seen_as_true = true;
corrected_mask[i] = true;
}

// if the next row belongs to a different streamed index (e.g. from 0 to 1, or from 1 to 2), reset the seen_as_true flag
if i < streamed_indices_length - 1
&& streamed_indices.value(i)
!= streamed_indices.value(i + 1)
{
seen_as_true = false;
}
}
maybe_left_semi_mask = Some(BooleanArray::from(corrected_mask))
};

if let Some(ref left_semi_mask) = maybe_left_semi_mask {
mask = left_semi_mask;
let maybe_filtered_join_mask: Option<BooleanArray> =
get_filtered_join_mask(self.join_type, streamed_indices, mask);
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
mask = filtered_join_mask;
}

// Push the filtered batch to the output
Expand Down Expand Up @@ -1421,6 +1395,42 @@ fn get_buffered_columns(
.collect::<Result<Vec<_>, ArrowError>>()
}

/// Calculates the join filter bit mask, taking join type specifics into account.
///
/// For `JoinType::LeftSemi` the returned mask has the same length as
/// `streamed_indices`. Within every run of consecutive equal streamed indices
/// only the first row whose filter value is `true` stays `true`; every other
/// position is `false`. Once a streamed row has one match there is no need to
/// consider further matches for the same streamed index.
///
/// A null filter value is treated as "no match". The raw value bit behind a
/// null slot is not meaningful, so it must neither emit a row nor hide a later
/// genuine match for the same streamed index.
///
/// Returns `None` for every other join type, meaning the caller should keep
/// using the original filter mask.
///
/// # Panics
///
/// Panics if `mask` has fewer elements than `streamed_indices`.
fn get_filtered_join_mask(
    join_type: JoinType,
    streamed_indices: UInt64Array,
    mask: &BooleanArray,
) -> Option<BooleanArray> {
    if !matches!(join_type, JoinType::LeftSemi) {
        return None;
    }

    let streamed_indices_length = streamed_indices.len();
    // Every streamed index needs a filter result; a shorter mask is a caller bug.
    assert!(
        mask.len() >= streamed_indices_length,
        "filter mask is shorter than the streamed indices"
    );

    // have we already seen a filter match for the current streamed index
    let mut seen_as_true = false;
    let mut corrected_mask: Vec<bool> = Vec::with_capacity(streamed_indices_length);

    for (i, filter_value) in mask.iter().take(streamed_indices_length).enumerate() {
        // if switched to the next streamed index (e.g. from 0 to 1, or from 1 to 2),
        // forget the matches seen for the previous one
        if i > 0 && streamed_indices.value(i) != streamed_indices.value(i - 1) {
            seen_as_true = false;
        }

        // `iter()` yields `None` for null slots, so nulls never count as a match
        let is_match = filter_value == Some(true);

        // only the first match for a streamed index is kept
        corrected_mask.push(is_match && !seen_as_true);
        seen_as_true |= is_match;
    }

    Some(BooleanArray::from(corrected_mask))
}

/// Buffered data contains all buffered batches with one unique join key
#[derive(Debug, Default)]
struct BufferedData {
Expand Down Expand Up @@ -1667,10 +1677,13 @@ mod tests {
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

use crate::joins::sort_merge_join::get_filtered_join_mask;
use arrow::array::{Date32Array, Date64Array, Int32Array};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::{BooleanArray, UInt64Array};
use datafusion_common::JoinType::LeftSemi;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
};
Expand Down Expand Up @@ -2697,6 +2710,70 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn left_semi_join_filtered_mask() -> Result<()> {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 1, 1]),
&BooleanArray::from(vec![true, true, false, false])
),
Some(BooleanArray::from(vec![true, false, false, false]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, true])
),
Some(BooleanArray::from(vec![true, true]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![false, true])
),
Some(BooleanArray::from(vec![false, true]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, false])
),
Some(BooleanArray::from(vec![true, false]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, true, true, true, true, true])
),
Some(BooleanArray::from(vec![
false, true, false, true, false, false
]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, false, false, false, false, true])
),
Some(BooleanArray::from(vec![
false, false, false, false, false, true
]))
);

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down

0 comments on commit 8e9b490

Please sign in to comment.