Skip to content

Commit

Permalink
Fix: Sort Merge Join LeftSemi issues when JoinFilter is set (#10304)
Browse files Browse the repository at this point in the history
* Fix: Sort Merge Join Left Semi crashes

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
comphead and alamb authored May 20, 2024
1 parent b716c09 commit 94b5511
Show file tree
Hide file tree
Showing 2 changed files with 352 additions and 29 deletions.
230 changes: 201 additions & 29 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,13 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};

use arrow::array::*;
use arrow::compute::{self, concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
use futures::{Stream, StreamExt};
use hashbrown::HashSet;

use datafusion_common::{
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
};
Expand All @@ -54,7 +45,17 @@ use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};

use futures::{Stream, StreamExt};
use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};

/// join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
Expand Down Expand Up @@ -491,6 +492,10 @@ struct StreamedBatch {
pub output_indices: Vec<StreamedJoinedChunk>,
/// Index of currently scanned batch from buffered data
pub buffered_batch_idx: Option<usize>,
/// Indices that found a match for the given join filter
/// Used for semi joins to keep track the streaming index which got a join filter match
/// and already emitted to the output.
pub join_filter_matched_idxs: HashSet<u64>,
}

impl StreamedBatch {
Expand All @@ -502,6 +507,7 @@ impl StreamedBatch {
join_arrays,
output_indices: vec![],
buffered_batch_idx: None,
join_filter_matched_idxs: HashSet::new(),
}
}

Expand All @@ -512,6 +518,7 @@ impl StreamedBatch {
join_arrays: vec![],
output_indices: vec![],
buffered_batch_idx: None,
join_filter_matched_idxs: HashSet::new(),
}
}

Expand Down Expand Up @@ -990,7 +997,22 @@ impl SMJStream {
}
Ordering::Equal => {
if matches!(self.join_type, JoinType::LeftSemi) {
join_streamed = !self.streamed_joined;
// if the join filter is specified then its needed to output the streamed index
// only if it has not been emitted before
// the `join_filter_matched_idxs` keeps track on if streamed index has a successful
// filter match and prevents the same index to go into output more than once
if self.filter.is_some() {
join_streamed = !self
.streamed_batch
.join_filter_matched_idxs
.contains(&(self.streamed_batch.idx as u64))
&& !self.streamed_joined;
// if the join filter specified there can be references to buffered columns
// so buffered columns are needed to access them
join_buffered = join_streamed;
} else {
join_streamed = !self.streamed_joined;
}
}
if matches!(
self.join_type,
Expand Down Expand Up @@ -1134,17 +1156,15 @@ impl SMJStream {
.collect::<Result<Vec<_>, ArrowError>>()?;

let buffered_indices: UInt64Array = chunk.buffered_indices.finish();

let mut buffered_columns =
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
self.buffered_data.batches[buffered_idx]
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()?
get_buffered_columns(
&self.buffered_data,
buffered_idx,
&buffered_indices,
)?
} else {
self.buffered_schema
.fields()
Expand All @@ -1161,6 +1181,15 @@ impl SMJStream {
let filter_columns = if chunk.buffered_batch_idx.is_some() {
if matches!(self.join_type, JoinType::Right) {
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
} else if matches!(self.join_type, JoinType::LeftSemi) {
// unwrap is safe here as we check is_some on top of if statement
let buffered_columns = get_buffered_columns(
&self.buffered_data,
chunk.buffered_batch_idx.unwrap(),
&buffered_indices,
)?;

get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
} else {
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
}
Expand Down Expand Up @@ -1195,7 +1224,17 @@ impl SMJStream {
.into_array(filter_batch.num_rows())?;

// The selection mask of the filter
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;
let mut mask =
datafusion_common::cast::as_boolean_array(&filter_result)?;

let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
get_filtered_join_mask(self.join_type, streamed_indices, mask);
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
mask = &filtered_join_mask.0;
self.streamed_batch
.join_filter_matched_idxs
.extend(&filtered_join_mask.1);
}

// Push the filtered batch to the output
let filtered_batch =
Expand Down Expand Up @@ -1365,6 +1404,69 @@ fn get_filter_column(
filter_columns
}

/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
#[inline(always)]
fn get_buffered_columns(
buffered_data: &BufferedData,
buffered_batch_idx: usize,
buffered_indices: &UInt64Array,
) -> Result<Vec<ArrayRef>, ArrowError> {
buffered_data.batches[buffered_batch_idx]
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()
}

// Calculate join filter bit mask considering join type specifics
// `streamed_indices` - array of streamed datasource JOINED row indices
// `mask` - array booleans representing computed join filter expression eval result:
// true = the row index matches the join filter
// false = the row index doesn't match the join filter
// `streamed_indices` have the same length as `mask`
fn get_filtered_join_mask(
join_type: JoinType,
streamed_indices: UInt64Array,
mask: &BooleanArray,
) -> Option<(BooleanArray, Vec<u64>)> {
// for LeftSemi Join the filter mask should be calculated in its own way:
// if we find at least one matching row for specific streaming index
// we don't need to check any others for the same index
if matches!(join_type, JoinType::LeftSemi) {
// have we seen a filter match for a streaming index before
let mut seen_as_true: bool = false;
let streamed_indices_length = streamed_indices.len();
let mut corrected_mask: BooleanBuilder =
BooleanBuilder::with_capacity(streamed_indices_length);

let mut filter_matched_indices: Vec<u64> = vec![];

#[allow(clippy::needless_range_loop)]
for i in 0..streamed_indices_length {
// LeftSemi respects only first true values for specific streaming index,
// others true values for the same index must be false
if mask.value(i) && !seen_as_true {
seen_as_true = true;
corrected_mask.append_value(true);
filter_matched_indices.push(streamed_indices.value(i));
} else {
corrected_mask.append_value(false);
}

// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
if i < streamed_indices_length - 1
&& streamed_indices.value(i) != streamed_indices.value(i + 1)
{
seen_as_true = false;
}
}
Some((corrected_mask.finish(), filter_matched_indices))
} else {
None
}
}

/// Buffered data contains all buffered batches with one unique join key
#[derive(Debug, Default)]
struct BufferedData {
Expand Down Expand Up @@ -1604,24 +1706,28 @@ fn is_join_arrays_equal(
mod tests {
use std::sync::Arc;

use crate::expressions::Column;
use crate::joins::utils::JoinOn;
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

use arrow::array::{Date32Array, Date64Array, Int32Array};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::{BooleanArray, UInt64Array};

use datafusion_common::JoinType::LeftSemi;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
};
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_execution::TaskContext;

use crate::expressions::Column;
use crate::joins::sort_merge_join::get_filtered_join_mask;
use crate::joins::utils::JoinOn;
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
Expand Down Expand Up @@ -2641,6 +2747,72 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn left_semi_join_filtered_mask() -> Result<()> {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 1, 1]),
&BooleanArray::from(vec![true, true, false, false])
),
Some((BooleanArray::from(vec![true, false, false, false]), vec![0]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, true])
),
Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![false, true])
),
Some((BooleanArray::from(vec![false, true]), vec![1]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, false])
),
Some((BooleanArray::from(vec![true, false]), vec![0]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, true, true, true, true, true])
),
Some((
BooleanArray::from(vec![false, true, false, true, false, false]),
vec![0, 1]
))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, false, false, false, false, true])
),
Some((
BooleanArray::from(vec![false, false, false, false, false, true]),
vec![1]
))
);

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
Loading

0 comments on commit 94b5511

Please sign in to comment.