Skip to content

Commit

Permalink
Fix outer join
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Feb 3, 2024
1 parent 99940c2 commit be35c90
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 46 deletions.
174 changes: 129 additions & 45 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1104,32 +1104,11 @@ impl SMJStream {
.map(|f| new_null_array(f.data_type(), buffered_indices.len()))
.collect::<Vec<_>>();

let filter_columns =
get_filter_column(&self.filter, &streamed_columns, &buffered_columns);

streamed_columns.extend(buffered_columns);
let columns = streamed_columns;

let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?;

// Apply join filter if any
let output_batch = if let Some(f) = &self.filter {
// Construct batch with only filter columns
let filter_batch =
RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?;

let filter_result = f
.expression()
.evaluate(&filter_batch)?
.into_array(filter_batch.num_rows())?;
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;

compute::filter_record_batch(&output_batch, mask)?
} else {
output_batch
};

self.output_record_batches.push(output_batch);
self.output_record_batches
.push(RecordBatch::try_new(self.schema.clone(), columns)?);
}
Ok(())
}
Expand Down Expand Up @@ -1172,40 +1151,138 @@ impl SMJStream {
.collect::<Vec<_>>()
};

let filter_columns = if matches!(self.join_type, JoinType::Right) {
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
let streamed_columns_length = streamed_columns.len();
let buffered_columns_length = buffered_columns.len();

// Prepare the columns that the join filter is applied to later.
// This is only needed for rows joined between the streamed and buffered sides.
let filter_columns = if chunk.buffered_batch_idx.is_some() {
if matches!(self.join_type, JoinType::Right) {
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
} else {
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
}
} else {
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
vec![]
};

let columns = if matches!(self.join_type, JoinType::Right) {
buffered_columns.extend(streamed_columns);
buffered_columns.extend(streamed_columns.clone());
buffered_columns
} else {
streamed_columns.extend(buffered_columns);
streamed_columns
};

let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
let output_batch =
RecordBatch::try_new(self.schema.clone(), columns.clone())?;

// Apply join filter if any
let output_batch = if let Some(f) = &self.filter {
// Construct batch with only filter columns
let filter_batch =
RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?;

let filter_result = f
.expression()
.evaluate(&filter_batch)?
.into_array(filter_batch.num_rows())?;
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;

compute::filter_record_batch(&output_batch, mask)?
} else {
output_batch
};
if !filter_columns.is_empty() {
if let Some(f) = &self.filter {
// Construct batch with only filter columns
let filter_batch = RecordBatch::try_new(
Arc::new(f.schema().clone()),
filter_columns,
)?;

let filter_result = f
.expression()
.evaluate(&filter_batch)?
.into_array(filter_batch.num_rows())?;

// The selection mask of the filter
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;

// Push the filtered batch to the output
let filtered_batch =
compute::filter_record_batch(&output_batch, mask)?;
self.output_record_batches.push(filtered_batch);

// For outer joins, we need to push the null joined rows to the output.
if matches!(
self.join_type,
JoinType::Left | JoinType::Right | JoinType::Full
) {
// The reverse of the selection mask, which is for null joined rows
let not_mask = compute::not(mask)?;
let null_joined_batch =
compute::filter_record_batch(&output_batch, &not_mask)?;

let mut buffered_columns = self
.buffered_schema
.fields()
.iter()
.map(|f| {
new_null_array(
f.data_type(),
null_joined_batch.num_rows(),
)
})
.collect::<Vec<_>>();

let columns = if matches!(self.join_type, JoinType::Right) {
let streamed_columns = null_joined_batch
.columns()
.iter()
.skip(buffered_columns_length)
.cloned()
.collect::<Vec<_>>();

buffered_columns.extend(streamed_columns);
buffered_columns
} else {
let mut streamed_columns = null_joined_batch
.columns()
.iter()
.take(streamed_columns_length)
.cloned()
.collect::<Vec<_>>();

streamed_columns.extend(buffered_columns);
streamed_columns
};

self.output_record_batches.push(output_batch);
let null_joined_streamed_batch =
RecordBatch::try_new(self.schema.clone(), columns.clone())?;
self.output_record_batches.push(null_joined_streamed_batch);

// For full join, we also need to output the null joined rows from the buffered side
if matches!(self.join_type, JoinType::Full) {
let mut streamed_columns = self
.streamed_schema
.fields()
.iter()
.map(|f| {
new_null_array(
f.data_type(),
null_joined_batch.num_rows(),
)
})
.collect::<Vec<_>>();

let buffered_columns = null_joined_batch
.columns()
.iter()
.skip(streamed_columns_length)
.cloned()
.collect::<Vec<_>>();

streamed_columns.extend(buffered_columns);

let null_joined_buffered_batch = RecordBatch::try_new(
self.schema.clone(),
streamed_columns,
)?;
self.output_record_batches.push(null_joined_buffered_batch);
}
}
} else {
self.output_record_batches.push(output_batch);
}
} else {
self.output_record_batches.push(output_batch);
}
}

self.streamed_batch.output_indices.clear();
Expand All @@ -1217,7 +1294,14 @@ impl SMJStream {
let record_batch = concat_batches(&self.schema, &self.output_record_batches)?;
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(record_batch.num_rows());
self.output_size -= record_batch.num_rows();
// If a join filter exists, `self.output_size` is not accurate because the exact number of
// rows in the output record batch is not known in advance: once the join filter is applied
// to a streamed row joined with buffered rows, it may yield more than one output row.
// Clamp at zero rather than subtracting more rows than were counted.
if record_batch.num_rows() > self.output_size {
self.output_size = 0;
} else {
self.output_size -= record_batch.num_rows();
}
self.output_record_batches.clear();
Ok(record_batch)
}
Expand Down
21 changes: 21 additions & 0 deletions datafusion/sqllogictest/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,27 @@ SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int <
NULL 3 11
NULL 3 55

# equijoin_full
query ITIITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id
----
11 a 1 11 z 3
22 b 2 22 y 1
33 c 3 NULL NULL NULL
44 d 4 44 x 3
NULL NULL NULL 55 w 3

# equijoin_full_and_condition_from_both
query ITIITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int
----
11 a 1 NULL NULL NULL
22 b 2 22 y 1
33 c 3 NULL NULL NULL
44 d 4 44 x 3
NULL NULL NULL 11 z 3
NULL NULL NULL 55 w 3

# left_join
query ITT rowsort
SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id
Expand Down
119 changes: 118 additions & 1 deletion datafusion/sqllogictest/test_files/sort_merge_join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b
Alice 100 Alice 1
Alice 100 Alice 2
Alice 50 Alice 1
Alice 50 NULL NULL
Bob 1 NULL NULL

query TITI rowsort
SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b < t1.b
Expand All @@ -92,6 +94,7 @@ Alice 100 Alice 1
Alice 100 Alice 2
Alice 50 Alice 1
Alice 50 Alice 2
Bob 1 NULL NULL

# right join without join filter
query TITI rowsort
Expand All @@ -109,6 +112,7 @@ SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b
Alice 100 Alice 1
Alice 100 Alice 2
Alice 50 Alice 1
NULL NULL Alice 2

query TITI rowsort
SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b
Expand All @@ -132,19 +136,132 @@ Bob 1 NULL NULL
query TITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b
----
Alice 100 NULL NULL
Alice 100 NULL NULL
Alice 50 Alice 2
Alice 50 NULL NULL
Bob 1 NULL NULL
NULL NULL Alice 1
NULL NULL Alice 1
NULL NULL Alice 2

query TITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50
----
Alice 100 Alice 1
Alice 100 Alice 2
Alice 50 NULL NULL
Alice 50 NULL NULL
Bob 1 NULL NULL
NULL NULL Alice 1
NULL NULL Alice 2

statement ok
set datafusion.optimizer.prefer_hash_join = true;
DROP TABLE t1;

statement ok
DROP TABLE t2;

statement ok
CREATE TABLE IF NOT EXISTS t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES
(11, 'a', 1),
(22, 'b', 2),
(33, 'c', 3),
(44, 'd', 4);

statement ok
CREATE TABLE IF NOT EXISTS t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES
(11, 'z', 3),
(22, 'y', 1),
(44, 'x', 3),
(55, 'w', 3);

# inner join with join filter
query III rowsort
SELECT t1_id, t1_int, t2_int FROM t1 JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int
----
22 2 1
44 4 3

# equijoin_multiple_condition_ordering
query ITT rowsort
SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name
----
11 a z
22 b y
44 d x

# equijoin_right_and_condition_from_left
query ITT rowsort
SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22
----
22 b y
44 d x
NULL NULL w
NULL NULL z

# equijoin_left_and_condition_from_left
query ITT rowsort
SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_id >= 44
----
11 a NULL
22 b NULL
33 c NULL
44 d x

# equijoin_left_and_condition_from_both
query III rowsort
SELECT t1_id, t1_int, t2_int FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t1_int >= t2_int
----
11 1 NULL
22 2 1
33 3 NULL
44 4 3

# equijoin_right_and_condition_from_right
query ITT rowsort
SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_id >= 22
----
22 b y
44 d x
NULL NULL w
NULL NULL z

# equijoin_right_and_condition_from_both
query III rowsort
SELECT t1_int, t2_int, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int
----
2 1 22
4 3 44
NULL 3 11
NULL 3 55

# equijoin_full
query ITIITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id
----
11 a 1 11 z 3
22 b 2 22 y 1
33 c 3 NULL NULL NULL
44 d 4 44 x 3
NULL NULL NULL 55 w 3

# equijoin_full_and_condition_from_both
query ITIITI rowsort
SELECT * FROM t1 FULL JOIN t2 ON t1_id = t2_id AND t2_int <= t1_int
----
11 a 1 NULL NULL NULL
22 b 2 22 y 1
33 c 3 NULL NULL NULL
44 d 4 44 x 3
NULL NULL NULL 11 z 3
NULL NULL NULL 55 w 3

statement ok
DROP TABLE t1;

statement ok
DROP TABLE t2;

statement ok
set datafusion.optimizer.prefer_hash_join = true;

0 comments on commit be35c90

Please sign in to comment.