-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix: Sort Merge Join LeftSemi issues when JoinFilter is set #10304
Changes from all commits
ca564ce
fd21ccf
64d7e5c
8c6010e
f9e1133
ed0035b
4c2c8f3
4052b0d
9da2c45
9c71eef
fe0bb60
c0fd73e
f993b3c
1354f83
c129846
22c61fc
30f28fe
823f396
f0e60da
a06acaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,22 +30,13 @@ use std::pin::Pin; | |
use std::sync::Arc; | ||
use std::task::{Context, Poll}; | ||
|
||
use crate::expressions::PhysicalSortExpr; | ||
use crate::joins::utils::{ | ||
build_join_schema, check_join_is_valid, estimate_join_statistics, | ||
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, | ||
}; | ||
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; | ||
use crate::{ | ||
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, | ||
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, | ||
RecordBatchStream, SendableRecordBatchStream, Statistics, | ||
}; | ||
|
||
use arrow::array::*; | ||
use arrow::compute::{self, concat_batches, take, SortOptions}; | ||
use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; | ||
use arrow::error::ArrowError; | ||
use futures::{Stream, StreamExt}; | ||
use hashbrown::HashSet; | ||
|
||
use datafusion_common::{ | ||
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, | ||
}; | ||
|
@@ -54,7 +45,17 @@ use datafusion_execution::TaskContext; | |
use datafusion_physical_expr::equivalence::join_equivalence_properties; | ||
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; | ||
|
||
use futures::{Stream, StreamExt}; | ||
use crate::expressions::PhysicalSortExpr; | ||
use crate::joins::utils::{ | ||
build_join_schema, check_join_is_valid, estimate_join_statistics, | ||
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, | ||
}; | ||
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; | ||
use crate::{ | ||
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, | ||
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, | ||
RecordBatchStream, SendableRecordBatchStream, Statistics, | ||
}; | ||
|
||
/// join execution plan executes partitions in parallel and combines them into a set of | ||
/// partitions. | ||
|
@@ -491,6 +492,10 @@ struct StreamedBatch { | |
pub output_indices: Vec<StreamedJoinedChunk>, | ||
/// Index of currently scanned batch from buffered data | ||
pub buffered_batch_idx: Option<usize>, | ||
/// Indices that found a match for the given join filter | ||
/// Used for semi joins to keep track the streaming index which got a join filter match | ||
/// and already emitted to the output. | ||
pub join_filter_matched_idxs: HashSet<u64>, | ||
} | ||
|
||
impl StreamedBatch { | ||
|
@@ -502,6 +507,7 @@ impl StreamedBatch { | |
join_arrays, | ||
output_indices: vec![], | ||
buffered_batch_idx: None, | ||
join_filter_matched_idxs: HashSet::new(), | ||
} | ||
} | ||
|
||
|
@@ -512,6 +518,7 @@ impl StreamedBatch { | |
join_arrays: vec![], | ||
output_indices: vec![], | ||
buffered_batch_idx: None, | ||
join_filter_matched_idxs: HashSet::new(), | ||
} | ||
} | ||
|
||
|
@@ -990,7 +997,22 @@ impl SMJStream { | |
} | ||
Ordering::Equal => { | ||
if matches!(self.join_type, JoinType::LeftSemi) { | ||
join_streamed = !self.streamed_joined; | ||
// if the join filter is specified then its needed to output the streamed index | ||
// only if it has not been emitted before | ||
// the `join_filter_matched_idxs` keeps track on if streamed index has a successful | ||
// filter match and prevents the same index to go into output more than once | ||
if self.filter.is_some() { | ||
join_streamed = !self | ||
.streamed_batch | ||
.join_filter_matched_idxs | ||
.contains(&(self.streamed_batch.idx as u64)) | ||
&& !self.streamed_joined; | ||
// if the join filter specified there can be references to buffered columns | ||
// so buffered columns are needed to access them | ||
join_buffered = join_streamed; | ||
} else { | ||
join_streamed = !self.streamed_joined; | ||
} | ||
} | ||
if matches!( | ||
self.join_type, | ||
|
@@ -1134,17 +1156,15 @@ impl SMJStream { | |
.collect::<Result<Vec<_>, ArrowError>>()?; | ||
|
||
let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); | ||
|
||
comphead marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let mut buffered_columns = | ||
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { | ||
vec![] | ||
} else if let Some(buffered_idx) = chunk.buffered_batch_idx { | ||
self.buffered_data.batches[buffered_idx] | ||
.batch | ||
.columns() | ||
.iter() | ||
.map(|column| take(column, &buffered_indices, None)) | ||
.collect::<Result<Vec<_>, ArrowError>>()? | ||
get_buffered_columns( | ||
&self.buffered_data, | ||
buffered_idx, | ||
&buffered_indices, | ||
)? | ||
} else { | ||
self.buffered_schema | ||
.fields() | ||
|
@@ -1161,6 +1181,15 @@ impl SMJStream { | |
let filter_columns = if chunk.buffered_batch_idx.is_some() { | ||
if matches!(self.join_type, JoinType::Right) { | ||
get_filter_column(&self.filter, &buffered_columns, &streamed_columns) | ||
} else if matches!(self.join_type, JoinType::LeftSemi) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if this should also check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have more to go, LeftAnti is first as it prevents TPCH to run and then double check RightSemi as well, good point |
||
// unwrap is safe here as we check is_some on top of if statement | ||
let buffered_columns = get_buffered_columns( | ||
&self.buffered_data, | ||
chunk.buffered_batch_idx.unwrap(), | ||
&buffered_indices, | ||
)?; | ||
|
||
get_filter_column(&self.filter, &streamed_columns, &buffered_columns) | ||
} else { | ||
get_filter_column(&self.filter, &streamed_columns, &buffered_columns) | ||
} | ||
|
@@ -1195,7 +1224,17 @@ impl SMJStream { | |
.into_array(filter_batch.num_rows())?; | ||
|
||
// The selection mask of the filter | ||
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; | ||
let mut mask = | ||
datafusion_common::cast::as_boolean_array(&filter_result)?; | ||
|
||
let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> = | ||
get_filtered_join_mask(self.join_type, streamed_indices, mask); | ||
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { | ||
mask = &filtered_join_mask.0; | ||
self.streamed_batch | ||
.join_filter_matched_idxs | ||
.extend(&filtered_join_mask.1); | ||
} | ||
|
||
// Push the filtered batch to the output | ||
let filtered_batch = | ||
|
@@ -1365,6 +1404,69 @@ fn get_filter_column( | |
filter_columns | ||
} | ||
|
||
/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` | ||
#[inline(always)] | ||
fn get_buffered_columns( | ||
buffered_data: &BufferedData, | ||
buffered_batch_idx: usize, | ||
buffered_indices: &UInt64Array, | ||
) -> Result<Vec<ArrayRef>, ArrowError> { | ||
buffered_data.batches[buffered_batch_idx] | ||
.batch | ||
.columns() | ||
.iter() | ||
.map(|column| take(column, &buffered_indices, None)) | ||
.collect::<Result<Vec<_>, ArrowError>>() | ||
} | ||
|
||
// Calculate join filter bit mask considering join type specifics | ||
comphead marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// `streamed_indices` - array of streamed datasource JOINED row indices | ||
// `mask` - array booleans representing computed join filter expression eval result: | ||
// true = the row index matches the join filter | ||
// false = the row index doesn't match the join filter | ||
// `streamed_indices` have the same length as `mask` | ||
fn get_filtered_join_mask( | ||
join_type: JoinType, | ||
streamed_indices: UInt64Array, | ||
mask: &BooleanArray, | ||
) -> Option<(BooleanArray, Vec<u64>)> { | ||
// for LeftSemi Join the filter mask should be calculated in its own way: | ||
// if we find at least one matching row for specific streaming index | ||
// we don't need to check any others for the same index | ||
if matches!(join_type, JoinType::LeftSemi) { | ||
comphead marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// have we seen a filter match for a streaming index before | ||
let mut seen_as_true: bool = false; | ||
let streamed_indices_length = streamed_indices.len(); | ||
let mut corrected_mask: BooleanBuilder = | ||
BooleanBuilder::with_capacity(streamed_indices_length); | ||
|
||
let mut filter_matched_indices: Vec<u64> = vec![]; | ||
|
||
#[allow(clippy::needless_range_loop)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder why ignore clippy here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clippy doesn't like for loops anymore .... |
||
for i in 0..streamed_indices_length { | ||
// LeftSemi respects only first true values for specific streaming index, | ||
// others true values for the same index must be false | ||
if mask.value(i) && !seen_as_true { | ||
seen_as_true = true; | ||
corrected_mask.append_value(true); | ||
filter_matched_indices.push(streamed_indices.value(i)); | ||
viirya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} else { | ||
corrected_mask.append_value(false); | ||
} | ||
|
||
// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag | ||
if i < streamed_indices_length - 1 | ||
&& streamed_indices.value(i) != streamed_indices.value(i + 1) | ||
{ | ||
seen_as_true = false; | ||
} | ||
} | ||
Some((corrected_mask.finish(), filter_matched_indices)) | ||
} else { | ||
None | ||
} | ||
} | ||
|
||
/// Buffered data contains all buffered batches with one unique join key | ||
#[derive(Debug, Default)] | ||
struct BufferedData { | ||
|
@@ -1604,24 +1706,28 @@ fn is_join_arrays_equal( | |
mod tests { | ||
use std::sync::Arc; | ||
|
||
use crate::expressions::Column; | ||
use crate::joins::utils::JoinOn; | ||
use crate::joins::SortMergeJoinExec; | ||
use crate::memory::MemoryExec; | ||
use crate::test::build_table_i32; | ||
use crate::{common, ExecutionPlan}; | ||
|
||
use arrow::array::{Date32Array, Date64Array, Int32Array}; | ||
use arrow::compute::SortOptions; | ||
use arrow::datatypes::{DataType, Field, Schema}; | ||
use arrow::record_batch::RecordBatch; | ||
use arrow_array::{BooleanArray, UInt64Array}; | ||
|
||
use datafusion_common::JoinType::LeftSemi; | ||
use datafusion_common::{ | ||
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, | ||
}; | ||
use datafusion_execution::config::SessionConfig; | ||
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; | ||
use datafusion_execution::TaskContext; | ||
|
||
use crate::expressions::Column; | ||
use crate::joins::sort_merge_join::get_filtered_join_mask; | ||
use crate::joins::utils::JoinOn; | ||
use crate::joins::SortMergeJoinExec; | ||
use crate::memory::MemoryExec; | ||
use crate::test::build_table_i32; | ||
use crate::{common, ExecutionPlan}; | ||
|
||
fn build_table( | ||
a: (&str, &Vec<i32>), | ||
b: (&str, &Vec<i32>), | ||
|
@@ -2641,6 +2747,72 @@ mod tests { | |
|
||
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn left_semi_join_filtered_mask() -> Result<()> { | ||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftSemi, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we should test a type other than LeftSemi as negative test coverage 🤔 |
||
UInt64Array::from(vec![0, 0, 1, 1]), | ||
&BooleanArray::from(vec![true, true, false, false]) | ||
), | ||
Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![true, true]) | ||
), | ||
Some((BooleanArray::from(vec![true, true]), vec![0, 1])) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![false, true]) | ||
), | ||
Some((BooleanArray::from(vec![false, true]), vec![1])) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 1]), | ||
&BooleanArray::from(vec![true, false]) | ||
), | ||
Some((BooleanArray::from(vec![true, false]), vec![0])) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), | ||
&BooleanArray::from(vec![false, true, true, true, true, true]) | ||
), | ||
Some(( | ||
BooleanArray::from(vec![false, true, false, true, false, false]), | ||
vec![0, 1] | ||
)) | ||
); | ||
|
||
assert_eq!( | ||
get_filtered_join_mask( | ||
LeftSemi, | ||
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), | ||
&BooleanArray::from(vec![false, false, false, false, false, true]) | ||
), | ||
Some(( | ||
BooleanArray::from(vec![false, false, false, false, false, true]), | ||
vec![1] | ||
)) | ||
); | ||
|
||
Ok(()) | ||
} | ||
|
||
/// Returns the column names on the schema | ||
fn columns(schema: &Schema) -> Vec<String> { | ||
schema.fields().iter().map(|f| f.name().clone()).collect() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this change is by formatter