From e06ddaeb8e2a1ec12055de7ce674cfcac98d2199 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 12 Jun 2024 10:01:03 -0700 Subject: [PATCH] fix: Fix the incorrect null joined rows for outer join with join filter --- datafusion/core/tests/sql/joins.rs | 44 ++++ .../src/joins/sort_merge_join.rs | 230 +++++++++++------- .../test_files/sort_merge_join.slt | 8 - 3 files changed, 181 insertions(+), 101 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index fad9b94b01120..1e690b45a09e0 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -235,3 +235,47 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn test_smj_right_filtered() -> Result<()> { + let ctx: SessionContext = SessionContext::new(); + + let sql = "set datafusion.optimizer.prefer_hash_join = false;"; + let _ = ctx.sql(sql).await?.collect().await?; + + let sql = "set datafusion.execution.batch_size = 100"; + let _ = ctx.sql(sql).await?.collect().await?; + + let sql = " + select * from ( + with t as ( + select id, id % 5 id1 from (select unnest(range(0,10)) id) + ), t1 as ( + select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) + ) + select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 + ) order by 1, 2, 3, 4 + "; + + let actual = ctx.sql(sql).await?.collect().await?; + + let expected: Vec<&str> = vec![ + "+----+-----+----+-----+", + "| id | id1 | id | id1 |", + "+----+-----+----+-----+", + "| 5 | 0 | 0 | 2 |", + "| 6 | 1 | 1 | 3 |", + "| 7 | 2 | 2 | 4 |", + "| 8 | 3 | 3 | 5 |", + "| 9 | 4 | 4 | 6 |", + "| | | 5 | 7 |", + "| | | 6 | 8 |", + "| | | 7 | 9 |", + "| | | 8 | 10 |", + "| | | 9 | 11 |", + "+----+-----+----+-----+", + ]; + datafusion_common::assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 8da345cdfca6e..54c7965afe0b0 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -34,6 +34,7 @@ use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +use arrow_array::types::UInt64Type; use futures::{Stream, StreamExt}; use hashbrown::HashSet; @@ -476,6 +477,7 @@ struct StreamedJoinedChunk { /// Array builder for streamed indices streamed_indices: UInt64Builder, /// Array builder for buffered indices + /// This could contain nulls if the join is null-joined buffered_indices: UInt64Builder, } @@ -564,6 +566,9 @@ struct BufferedBatch { pub null_joined: Vec, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, + /// The indices of buffered batch that failed the join filter. + /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. + pub join_filter_failed_idxs: HashSet, } impl BufferedBatch { @@ -595,6 +600,7 @@ impl BufferedBatch { join_arrays, null_joined: vec![], size_estimation, + join_filter_failed_idxs: HashSet::new(), } } } @@ -852,6 +858,7 @@ impl SMJStream { // pop previous buffered batches while !self.buffered_data.batches.is_empty() { let head_batch = self.buffered_data.head_batch(); + // If the head batch is fully processed, dequeue it and produce output of it. if head_batch.range.end == head_batch.batch.num_rows() { self.freeze_dequeuing_buffered()?; if let Some(buffered_batch) = @@ -860,6 +867,8 @@ impl SMJStream { self.reservation.shrink(buffered_batch.size_estimation); } } else { + // If the head batch is not fully processed, break the loop. + // Streamed batch will be joined with the head batch in the next step. break; } } @@ -1055,7 +1064,7 @@ impl SMJStream { Some(scanning_idx), ); } else { - // Join nulls and buffered row + // Join nulls and buffered row for full join self.buffered_data .scanning_batch_mut() .null_joined @@ -1098,6 +1107,7 @@ impl SMJStream { // 2. freezes NULLs joined to dequeued buffered batch to "release" it fn freeze_dequeuing_buffered(&mut self) -> Result<()> { self.freeze_streamed()?; + // Only freeze and produce the first batch in buffered_data as the batch is fully processed self.freeze_buffered(1)?; Ok(()) } @@ -1114,33 +1124,29 @@ impl SMJStream { let buffered_indices = UInt64Array::from_iter_values( buffered_batch.null_joined.iter().map(|&index| index as u64), ); - if buffered_indices.is_empty() { - continue; + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + self.output_record_batches.push(record_batch); } buffered_batch.null_joined.clear(); - // Take buffered (right) columns - let buffered_columns = buffered_batch - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() - .map_err(Into::::into)?; - - // Create null streamed (left) columns - let mut streamed_columns = self - .streamed_schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), buffered_indices.len())) - .collect::>(); - - streamed_columns.extend(buffered_columns); - let columns = streamed_columns; - - self.output_record_batches - .push(RecordBatch::try_new(self.schema.clone(), columns)?); + // For buffered rows which are joined with streamed side but failed on join filter + let buffered_indices = UInt64Array::from_iter_values( + buffered_batch.join_filter_failed_idxs.iter().copied(), + ); + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + self.output_record_batches.push(record_batch); + } + buffered_batch.join_filter_failed_idxs.clear(); } Ok(()) } @@ -1149,6 +1155,7 @@ impl SMJStream { // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { for chunk in self.streamed_batch.output_indices.iter_mut() { + // The row indices of joined streamed batch let streamed_indices = chunk.streamed_indices.finish(); if streamed_indices.is_empty() { @@ -1163,6 +1170,7 @@ impl SMJStream { .map(|column| take(column, &streamed_indices, None)) .collect::, ArrowError>>()?; + // The row indices of joined buffered batch let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { @@ -1174,6 +1182,8 @@ impl SMJStream { &buffered_indices, )? } else { + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. self.buffered_schema .fields() .iter() @@ -1205,7 +1215,8 @@ impl SMJStream { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } } else { - // This chunk is for null joined rows (outer join), we don't need to apply join filter. + // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. + // Any join filter applied only on either streamed or buffered side will be pushed already. vec![] }; @@ -1238,10 +1249,11 @@ impl SMJStream { let mut mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + // For certain join types, we need to adjust the initial mask to handle the join filter. let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = get_filtered_join_mask( self.join_type, - streamed_indices, + &streamed_indices, mask, &self.streamed_batch.join_filter_matched_idxs, &self.buffered_data.scanning_offset, @@ -1254,29 +1266,39 @@ impl SMJStream { .extend(&filtered_join_mask.1); } - // Push the filtered batch to the output + // Push the filtered batch which contains rows passing join filter to the output let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; self.output_record_batches.push(filtered_batch); - // For outer joins, we need to push the null joined rows to the output. + // For outer joins, we need to push the null joined rows to the output if + // all joined rows are failed on the join filter. + // I.e., if all rows joined from a streamed row are failed with the join filter, + // we need to join it with nulls as buffered side. if matches!( self.join_type, JoinType::Left | JoinType::Right | JoinType::Full ) { - // The reverse of the selection mask. For the rows not pass join filter above, - // we need to join them (left or right) with null rows for outer joins. - let not_mask = if mask.null_count() > 0 { - // If the mask contains nulls, we need to use `prep_null_mask_filter` to - // handle the nulls in the mask as false to produce rows where the mask - // was null itself. - compute::not(&compute::prep_null_mask_filter(mask))? - } else { - compute::not(mask)? - }; + // We need to get the mask for row indices that the joined rows are failed + // on the join filter. I.e., for a row in streamed side, if all joined rows + // between it and all buffered rows are failed on the join filter, we need to + // output it with null columns from buffered side. For the mask here, it + // behaves like LeftAnti join. + let null_mask: BooleanArray = get_filtered_join_mask( + // Set a mask slot as true only if all joined rows of same streamed index + // are failed on the join filter. + // The masking behavior is like LeftAnti join. + JoinType::LeftAnti, + &streamed_indices, + mask, + &self.streamed_batch.join_filter_matched_idxs, + &self.buffered_data.scanning_offset, + ) + .unwrap() + .0; let null_joined_batch = - compute::filter_record_batch(&output_batch, ¬_mask)?; + compute::filter_record_batch(&output_batch, &null_mask)?; let mut buffered_columns = self .buffered_schema @@ -1313,51 +1335,37 @@ impl SMJStream { streamed_columns }; + // Push the streamed/buffered batch joined nulls to the output let null_joined_streamed_batch = RecordBatch::try_new(self.schema.clone(), columns.clone())?; self.output_record_batches.push(null_joined_streamed_batch); - // For full join, we also need to output the null joined rows from the buffered side + // For full join, we also need to output the null joined rows from the buffered side. + // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with + // streamed side, it won't be outputted by `freeze_buffered`. + // We need to check if a buffered row is joined with streamed side and output. + // If it is joined with streamed side, but finally fails on the join filter, + // we need to output it with nulls as streamed side. if matches!(self.join_type, JoinType::Full) { - // Handle not mask for buffered side further. - // For buffered side, we want to output the rows that are not null joined with - // the streamed side. i.e. the rows that are not null in the `buffered_indices`. - let not_mask = if let Some(nulls) = buffered_indices.nulls() { - let mask = not_mask.values() & nulls.inner(); - BooleanArray::new(mask, None) - } else { - not_mask - }; - - let null_joined_batch = - compute::filter_record_batch(&output_batch, ¬_mask)?; - - let mut streamed_columns = self - .streamed_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - null_joined_batch.num_rows(), - ) - }) - .collect::>(); - - let buffered_columns = null_joined_batch - .columns() - .iter() - .skip(streamed_columns_length) - .cloned() - .collect::>(); - - streamed_columns.extend(buffered_columns); - - let null_joined_buffered_batch = RecordBatch::try_new( - self.schema.clone(), - streamed_columns, - )?; - self.output_record_batches.push(null_joined_buffered_batch); + for i in 0..mask.len() { + let buffered_batch = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()]; + let buffered_index = buffered_indices.value(i); + + if !mask.value(i) { + // For a buffered row that is joined with streamed side but failed on the join filter, + buffered_batch + .join_filter_failed_idxs + .insert(buffered_index); + } else if buffered_batch + .join_filter_failed_idxs + .contains(&buffered_index) + { + buffered_batch + .join_filter_failed_idxs + .remove(&buffered_index); + } + } } } } else { @@ -1422,6 +1430,38 @@ fn get_filter_column( filter_columns } +fn produce_buffered_null_batch( + schema: &SchemaRef, + streamed_schema: &SchemaRef, + buffered_indices: &PrimitiveArray, + buffered_batch: &BufferedBatch, +) -> Result> { + if buffered_indices.is_empty() { + return Ok(None); + } + + // Take buffered (right) columns + let buffered_columns = buffered_batch + .batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>() + .map_err(Into::::into)?; + + // Create null streamed (left) columns + let mut streamed_columns = streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), buffered_indices.len())) + .collect::>(); + + streamed_columns.extend(buffered_columns); + let columns = streamed_columns; + + Ok(Some(RecordBatch::try_new(schema.clone(), columns)?)) +} + /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` #[inline(always)] fn get_buffered_columns( @@ -1445,9 +1485,13 @@ fn get_buffered_columns( /// `streamed_indices` have the same length as `mask` /// `matched_indices` array of streaming indices that already has a join filter match /// `scanning_buffered_offset` current buffered offset across batches +/// +/// This return a tuple of: +/// - corrected mask with respect to the join type +/// - indices of rows in streamed batch that have a join filter match fn get_filtered_join_mask( join_type: JoinType, - streamed_indices: UInt64Array, + streamed_indices: &UInt64Array, mask: &BooleanArray, matched_indices: &HashSet, scanning_buffered_offset: &usize, @@ -2808,7 +2852,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 1, 1]), + &UInt64Array::from(vec![0, 0, 1, 1]), &BooleanArray::from(vec![true, true, false, false]), &HashSet::new(), &0, @@ -2819,7 +2863,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, true]), &HashSet::new(), &0, @@ -2830,7 +2874,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![false, true]), &HashSet::new(), &0, @@ -2841,7 +2885,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, false]), &HashSet::new(), &0, @@ -2852,7 +2896,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, true, true, true, true, true]), &HashSet::new(), &0, @@ -2866,7 +2910,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, false, false, false, false, true]), &HashSet::new(), &0, @@ -2885,7 +2929,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 1, 1]), + &UInt64Array::from(vec![0, 0, 1, 1]), &BooleanArray::from(vec![true, true, false, false]), &HashSet::new(), &0, @@ -2896,7 +2940,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, true]), &HashSet::new(), &0, @@ -2907,7 +2951,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![false, true]), &HashSet::new(), &0, @@ -2918,7 +2962,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, false]), &HashSet::new(), &0, @@ -2929,7 +2973,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, true, true, true, true, true]), &HashSet::new(), &0, @@ -2943,7 +2987,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, false, false, false, false, true]), &HashSet::new(), &0, diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index b4deb43a728e5..d120d366ce83f 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -84,7 +84,6 @@ SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 -Alice 50 NULL NULL Bob 1 NULL NULL query TITI rowsort @@ -112,7 +111,6 @@ SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 -NULL NULL Alice 2 query TITI rowsort SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -137,12 +135,9 @@ query TITI rowsort SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b ---- Alice 100 NULL NULL -Alice 100 NULL NULL Alice 50 Alice 2 -Alice 50 NULL NULL Bob 1 NULL NULL NULL NULL Alice 1 -NULL NULL Alice 1 NULL NULL Alice 2 query TITI rowsort @@ -151,10 +146,7 @@ SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 NULL NULL -Alice 50 NULL NULL Bob 1 NULL NULL -NULL NULL Alice 1 -NULL NULL Alice 2 statement ok DROP TABLE t1;