From c3345a5751f6f81f4fe6026deb7c42f37c2a4d68 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Mon, 20 May 2024 14:56:58 -0700 Subject: [PATCH] Fix: Sort Merge Join LeftSemi issues when JoinFilter is set (#10304) * Fix: Sort Merge Join Left Semi crashes Co-authored-by: Andrew Lamb --- .../src/joins/sort_merge_join.rs | 230 +++++++++++++++--- .../test_files/sort_merge_join.slt | 151 ++++++++++++ 2 files changed, 352 insertions(+), 29 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index d4cf6864d7e49..1cc7bf4700d1f 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -30,22 +30,13 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, -}; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; - use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; + use datafusion_common::{ internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, }; @@ -54,7 +45,17 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use futures::{Stream, StreamExt}; +use crate::expressions::PhysicalSortExpr; +use crate::joins::utils::{ 
+ build_join_schema, check_join_is_valid, estimate_join_statistics, + partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, +}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::{ + execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. @@ -491,6 +492,10 @@ struct StreamedBatch { pub output_indices: Vec, /// Index of currently scanned batch from buffered data pub buffered_batch_idx: Option, + /// Indices that found a match for the given join filter + /// Used for semi joins to keep track of the streaming indices which got a join filter match + /// and were already emitted to the output. + pub join_filter_matched_idxs: HashSet<u64>, } impl StreamedBatch { @@ -502,6 +507,7 @@ impl StreamedBatch { join_arrays, output_indices: vec![], buffered_batch_idx: None, + join_filter_matched_idxs: HashSet::new(), } } @@ -512,6 +518,7 @@ impl StreamedBatch { join_arrays: vec![], output_indices: vec![], buffered_batch_idx: None, + join_filter_matched_idxs: HashSet::new(), } } @@ -990,7 +997,22 @@ impl SMJStream { } Ordering::Equal => { if matches!(self.join_type, JoinType::LeftSemi) { - join_streamed = !self.streamed_joined; + // if the join filter is specified then it's needed to output the streamed index + // only if it has not been emitted before + // the `join_filter_matched_idxs` keeps track of whether the streamed index has a successful + // filter match and prevents the same index from going into the output more than once + if self.filter.is_some() { + join_streamed = !self + .streamed_batch + .join_filter_matched_idxs + .contains(&(self.streamed_batch.idx as u64)) + && !self.streamed_joined; + // if the join filter is specified there can be references to buffered columns + // so buffered columns 
are needed to access them + join_buffered = join_streamed; + } else { + join_streamed = !self.streamed_joined; + } } if matches!( self.join_type, @@ -1134,17 +1156,15 @@ impl SMJStream { .collect::<Result<Vec<_>, ArrowError>>()?; let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); - let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { - self.buffered_data.batches[buffered_idx] - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::<Result<Vec<_>, ArrowError>>()? + get_buffered_columns( + &self.buffered_data, + buffered_idx, + &buffered_indices, + )? } else { self.buffered_schema .fields() @@ -1161,6 +1181,15 @@ impl SMJStream { let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) + } else if matches!(self.join_type, JoinType::LeftSemi) { + // unwrap is safe here because `is_some` is checked at the top of the enclosing if statement + let buffered_columns = get_buffered_columns( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &buffered_indices, + )?; + + get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } else { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } @@ -1195,7 +1224,17 @@ impl SMJStream { .into_array(filter_batch.num_rows())?; // The selection mask of the filter - let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + let mut mask = + datafusion_common::cast::as_boolean_array(&filter_result)?; + + let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> = + get_filtered_join_mask(self.join_type, streamed_indices, mask); + if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { + mask = &filtered_join_mask.0; + self.streamed_batch + .join_filter_matched_idxs + .extend(&filtered_join_mask.1); + } // Push the filtered batch to the 
output let filtered_batch = @@ -1365,6 +1404,69 @@ fn get_filter_column( filter_columns } +/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` +#[inline(always)] +fn get_buffered_columns( + buffered_data: &BufferedData, + buffered_batch_idx: usize, + buffered_indices: &UInt64Array, +) -> Result<Vec<ArrayRef>, ArrowError> { + buffered_data.batches[buffered_batch_idx] + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::<Result<Vec<_>, ArrowError>>() +} + +// Calculate join filter bit mask considering join type specifics +// `streamed_indices` - array of streamed datasource JOINED row indices +// `mask` - array of booleans representing the computed join filter expression eval result: +// true = the row index matches the join filter +// false = the row index doesn't match the join filter +// `streamed_indices` has the same length as `mask` +fn get_filtered_join_mask( + join_type: JoinType, + streamed_indices: UInt64Array, + mask: &BooleanArray, +) -> Option<(BooleanArray, Vec<u64>)> { + // for LeftSemi Join the filter mask should be calculated in its own way: + // if we find at least one matching row for a specific streaming index + // we don't need to check any others for the same index + if matches!(join_type, JoinType::LeftSemi) { + // have we seen a filter match for a streaming index before + let mut seen_as_true: bool = false; + let streamed_indices_length = streamed_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(streamed_indices_length); + + let mut filter_matched_indices: Vec<u64> = vec![]; + + #[allow(clippy::needless_range_loop)] + for i in 0..streamed_indices_length { + // LeftSemi respects only the first true value for a specific streaming index, + // other true values for the same index must be false + if mask.value(i) && !seen_as_true { + seen_as_true = true; + corrected_mask.append_value(true); + filter_matched_indices.push(streamed_indices.value(i)); + } else { + 
corrected_mask.append_value(false); + } + + // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag + if i < streamed_indices_length - 1 + && streamed_indices.value(i) != streamed_indices.value(i + 1) + { + seen_as_true = false; + } + } + Some((corrected_mask.finish(), filter_matched_indices)) + } else { + None + } +} + /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { @@ -1604,17 +1706,13 @@ fn is_join_arrays_equal( mod tests { use std::sync::Arc; - use crate::expressions::Column; - use crate::joins::utils::JoinOn; - use crate::joins::SortMergeJoinExec; - use crate::memory::MemoryExec; - use crate::test::build_table_i32; - use crate::{common, ExecutionPlan}; - use arrow::array::{Date32Array, Date64Array, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::{BooleanArray, UInt64Array}; + + use datafusion_common::JoinType::LeftSemi; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; @@ -1622,6 +1720,14 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_execution::TaskContext; + use crate::expressions::Column; + use crate::joins::sort_merge_join::get_filtered_join_mask; + use crate::joins::utils::JoinOn; + use crate::joins::SortMergeJoinExec; + use crate::memory::MemoryExec; + use crate::test::build_table_i32; + use crate::{common, ExecutionPlan}; + fn build_table( a: (&str, &Vec), b: (&str, &Vec), @@ -2641,6 +2747,72 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn left_semi_join_filtered_mask() -> Result<()> { + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 0, 1, 1]), + &BooleanArray::from(vec![true, true, false, false]) + ), + Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) + ); + + assert_eq!( 
+ get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![true, true]) + ), + Some((BooleanArray::from(vec![true, true]), vec![0, 1])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![false, true]) + ), + Some((BooleanArray::from(vec![false, true]), vec![1])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 1]), + &BooleanArray::from(vec![true, false]) + ), + Some((BooleanArray::from(vec![true, false]), vec![0])) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![false, true, true, true, true, true]) + ), + Some(( + BooleanArray::from(vec![false, true, false, true, false, false]), + vec![0, 1] + )) + ); + + assert_eq!( + get_filtered_join_mask( + LeftSemi, + UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![false, false, false, false, false, true]) + ), + Some(( + BooleanArray::from(vec![false, false, false, false, false, true]), + vec![1] + )) + ); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 7b7e355fa2b52..3a27d9693d00f 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -263,6 +263,139 @@ DROP TABLE t1; statement ok DROP TABLE t2; +# LEFTSEMI join tests + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b = t1.b) +) order by 1, 2 +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union 
all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b = t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +#LEFTANTI tests +# returns no rows instead of correct result +#query III +#select * from ( +#with +#t1 as ( +# select 11 a, 12 b, 1 c union all +# select 11 a, 13 b, 2 c), +#t2 as ( +# select 11 a, 12 b, 3 c union all +# select 11 a, 14 b, 4 c 
+# ) +#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +#) order by 1, 2; +#---- +#11 12 1 +#11 13 2 + # Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches statement ok set datafusion.execution.batch_size = 1; @@ -280,5 +413,23 @@ SELECT * FROM ( ) ORDER BY 1, 2; ---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + statement ok set datafusion.optimizer.prefer_hash_join = true;