From be35c906b01a452fb0e519050c88b7062a4fbb8f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 2 Feb 2024 17:56:37 -0800 Subject: [PATCH] Fix outer join --- .../src/joins/sort_merge_join.rs | 174 +++++++++++++----- datafusion/sqllogictest/test_files/join.slt | 21 +++ .../test_files/sort_merge_join.slt | 119 +++++++++++- 3 files changed, 268 insertions(+), 46 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2d196acdb31c..3f9c9f7393f2 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1104,32 +1104,11 @@ impl SMJStream { .map(|f| new_null_array(f.data_type(), buffered_indices.len())) .collect::>(); - let filter_columns = - get_filter_column(&self.filter, &streamed_columns, &buffered_columns); - streamed_columns.extend(buffered_columns); let columns = streamed_columns; - let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?; - - // Apply join filter if any - let output_batch = if let Some(f) = &self.filter { - // Construct batch with only filter columns - let filter_batch = - RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?; - - let filter_result = f - .expression() - .evaluate(&filter_batch)? - .into_array(filter_batch.num_rows())?; - let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; - - compute::filter_record_batch(&output_batch, mask)? - } else { - output_batch - }; - - self.output_record_batches.push(output_batch); + self.output_record_batches + .push(RecordBatch::try_new(self.schema.clone(), columns)?); } Ok(()) } @@ -1172,40 +1151,138 @@ impl SMJStream { .collect::>() }; - let filter_columns = if matches!(self.join_type, JoinType::Right) { - get_filter_column(&self.filter, &buffered_columns, &streamed_columns) + let streamed_columns_length = streamed_columns.len(); + let buffered_columns_length = buffered_columns.len(); + + // Prepare the columns we apply join filter on later. + // Only for joined rows between streamed and buffered. + let filter_columns = if chunk.buffered_batch_idx.is_some() { + if matches!(self.join_type, JoinType::Right) { + get_filter_column(&self.filter, &buffered_columns, &streamed_columns) + } else { + get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + } } else { - get_filter_column(&self.filter, &streamed_columns, &buffered_columns) + vec![] }; let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns); + buffered_columns.extend(streamed_columns.clone()); buffered_columns } else { streamed_columns.extend(buffered_columns); streamed_columns }; - let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?; + let output_batch = + RecordBatch::try_new(self.schema.clone(), columns.clone())?; // Apply join filter if any - let output_batch = if let Some(f) = &self.filter { - // Construct batch with only filter columns - let filter_batch = - RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?; - - let filter_result = f - .expression() - .evaluate(&filter_batch)? - .into_array(filter_batch.num_rows())?; - let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; - - compute::filter_record_batch(&output_batch, mask)? - } else { - output_batch - }; + if !filter_columns.is_empty() { + if let Some(f) = &self.filter { + // Construct batch with only filter columns + let filter_batch = RecordBatch::try_new( + Arc::new(f.schema().clone()), + filter_columns, + )?; + + let filter_result = f + .expression() + .evaluate(&filter_batch)? + .into_array(filter_batch.num_rows())?; + + // The selection mask of the filter + let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + + // Push the filtered batch to the output + let filtered_batch = + compute::filter_record_batch(&output_batch, mask)?; + self.output_record_batches.push(filtered_batch); + + // For outer joins, we need to push the null joined rows to the output. + if matches!( + self.join_type, + JoinType::Left | JoinType::Right | JoinType::Full + ) { + // The reverse of the selection mask, which is for null joined rows + let not_mask = compute::not(mask)?; + let null_joined_batch = + compute::filter_record_batch(&output_batch, ¬_mask)?; + + let mut buffered_columns = self + .buffered_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + null_joined_batch.num_rows(), + ) + }) + .collect::>(); + + let columns = if matches!(self.join_type, JoinType::Right) { + let streamed_columns = null_joined_batch + .columns() + .iter() + .skip(buffered_columns_length) + .cloned() + .collect::>(); + + buffered_columns.extend(streamed_columns); + buffered_columns + } else { + let mut streamed_columns = null_joined_batch + .columns() + .iter() + .take(streamed_columns_length) + .cloned() + .collect::>(); + + streamed_columns.extend(buffered_columns); + streamed_columns + }; - self.output_record_batches.push(output_batch); + let null_joined_streamed_batch = + RecordBatch::try_new(self.schema.clone(), columns.clone())?; + self.output_record_batches.push(null_joined_streamed_batch); + + // For full join, we also need to output the null joined rows from the buffered side + if matches!(self.join_type, JoinType::Full) { + let mut streamed_columns = self + .streamed_schema + .fields() + .iter() + .map(|f| { + new_null_array( + f.data_type(), + null_joined_batch.num_rows(), + ) + }) + .collect::>(); + + let buffered_columns = null_joined_batch + .columns() + .iter() + .skip(streamed_columns_length) + .cloned() + .collect::>(); + + streamed_columns.extend(buffered_columns); + + let null_joined_buffered_batch = RecordBatch::try_new( + self.schema.clone(), + streamed_columns, + )?; + self.output_record_batches.push(null_joined_buffered_batch); + } + } + } else { + self.output_record_batches.push(output_batch); + } + } else { + self.output_record_batches.push(output_batch); + } } self.streamed_batch.output_indices.clear(); @@ -1217,7 +1294,14 @@ impl SMJStream { let record_batch = concat_batches(&self.schema, &self.output_record_batches)?; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(record_batch.num_rows()); - self.output_size -= record_batch.num_rows(); + // If join filter exists, `self.output_size` is not accurate as we don't know the exact + // number of rows in the output record batch. If streamed row joined with buffered rows, + // once join filter is applied, the number of output rows may be more than 1. + if record_batch.num_rows() > self.output_size { + self.output_size = 0; + } else { + self.output_size -= record_batch.num_rows(); + } self.output_record_batches.clear(); Ok(record_batch) } diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index ca9b918ff3ee..d287d11041eb 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -238,6 +238,27 @@ SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int < NULL 3 11 NULL 3 55 +# equijoin_full +query ITIITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 +NULL NULL NULL 55 w 3 + +# equijoin_full_and_condition_from_both +query ITIITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int +---- +11 a 1 NULL NULL NULL +22 b 2 22 y 1 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 +NULL NULL NULL 11 z 3 +NULL NULL NULL 55 w 3 + # left_join query ITT rowsort SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index d9266dea5ab1..426b9a3a5291 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -84,6 +84,8 @@ SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 +Alice 50 NULL NULL +Bob 1 NULL NULL query TITI rowsort SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b < t1.b @@ -92,6 +94,7 @@ Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 Alice 50 Alice 2 +Bob 1 NULL NULL # right join without join filter query TITI rowsort @@ -109,6 +112,7 @@ SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 +NULL NULL Alice 2 query TITI rowsort SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -132,19 +136,132 @@ Bob 1 NULL NULL query TITI rowsort SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b ---- +Alice 100 NULL NULL +Alice 100 NULL NULL Alice 50 Alice 2 +Alice 50 NULL NULL +Bob 1 NULL NULL +NULL NULL Alice 1 +NULL NULL Alice 1 +NULL NULL Alice 2 query TITI rowsort SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 ---- Alice 100 Alice 1 Alice 100 Alice 2 +Alice 50 NULL NULL +Alice 50 NULL NULL +Bob 1 NULL NULL +NULL NULL Alice 1 +NULL NULL Alice 2 statement ok -set datafusion.optimizer.prefer_hash_join = true; +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +CREATE TABLE IF NOT EXISTS t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4); + +statement ok +CREATE TABLE IF NOT EXISTS t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3); + +# inner join with join filter +query III rowsort +SELECT t1_id, t1_int, t2_int FROM t1 JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int +---- +22 2 1 +44 4 3 + +# equijoin_multiple_condition_ordering +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name +---- +11 a z +22 b y +44 d x + +# equijoin_right_and_condition_from_left +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22 +---- +22 b y +44 d x +NULL NULL w +NULL NULL z + +# equijoin_left_and_condition_from_left +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_id >= 44 +---- +11 a NULL +22 b NULL +33 c NULL +44 d x + +# equijoin_left_and_condition_from_both +query III rowsort +SELECT t1_id, t1_int, t2_int FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int +---- +11 1 NULL +22 2 1 +33 3 NULL +44 4 3 + +# equijoin_right_and_condition_from_right +query ITT rowsort +SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_id >= 22 +---- +22 b y +44 d x +NULL NULL w +NULL NULL z + +# equijoin_right_and_condition_from_both +query III rowsort +SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int +---- +2 1 22 +4 3 44 +NULL 3 11 +NULL 3 55 + +# equijoin_full +query ITIITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id +---- +11 a 1 11 z 3 +22 b 2 22 y 1 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 +NULL NULL NULL 55 w 3 + +# equijoin_full_and_condition_from_both +query ITIITI rowsort +SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int +---- +11 a 1 NULL NULL NULL +22 b 2 22 y 1 +33 c 3 NULL NULL NULL +44 d 4 44 x 3 +NULL NULL NULL 11 z 3 +NULL NULL NULL 55 w 3 statement ok DROP TABLE t1; statement ok DROP TABLE t2; + +statement ok +set datafusion.optimizer.prefer_hash_join = true;