Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Sort Merge Join LeftSemi issues when JoinFilter is set #10304

Merged
merged 20 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 201 additions & 29 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,13 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};

use arrow::array::*;
use arrow::compute::{self, concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
use futures::{Stream, StreamExt};
use hashbrown::HashSet;

use datafusion_common::{
internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result,
};
Expand All @@ -54,7 +45,17 @@ use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};

use futures::{Stream, StreamExt};
use crate::expressions::PhysicalSortExpr;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is by formatter

use crate::joins::utils::{
build_join_schema, check_join_is_valid, estimate_join_statistics,
partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::{
execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};

/// join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
Expand Down Expand Up @@ -491,6 +492,10 @@ struct StreamedBatch {
pub output_indices: Vec<StreamedJoinedChunk>,
/// Index of currently scanned batch from buffered data
pub buffered_batch_idx: Option<usize>,
/// Indices that found a match for the given join filter
/// Used for semi joins to keep track the streaming index which got a join filter match
/// and already emitted to the output.
pub join_filter_matched_idxs: HashSet<u64>,
}

impl StreamedBatch {
Expand All @@ -502,6 +507,7 @@ impl StreamedBatch {
join_arrays,
output_indices: vec![],
buffered_batch_idx: None,
join_filter_matched_idxs: HashSet::new(),
}
}

Expand All @@ -512,6 +518,7 @@ impl StreamedBatch {
join_arrays: vec![],
output_indices: vec![],
buffered_batch_idx: None,
join_filter_matched_idxs: HashSet::new(),
}
}

Expand Down Expand Up @@ -990,7 +997,22 @@ impl SMJStream {
}
Ordering::Equal => {
if matches!(self.join_type, JoinType::LeftSemi) {
join_streamed = !self.streamed_joined;
// if the join filter is specified then its needed to output the streamed index
// only if it has not been emitted before
// the `join_filter_matched_idxs` keeps track on if streamed index has a successful
// filter match and prevents the same index to go into output more than once
if self.filter.is_some() {
join_streamed = !self
.streamed_batch
.join_filter_matched_idxs
.contains(&(self.streamed_batch.idx as u64))
&& !self.streamed_joined;
// if the join filter specified there can be references to buffered columns
// so buffered columns are needed to access them
join_buffered = join_streamed;
} else {
join_streamed = !self.streamed_joined;
}
}
if matches!(
self.join_type,
Expand Down Expand Up @@ -1134,17 +1156,15 @@ impl SMJStream {
.collect::<Result<Vec<_>, ArrowError>>()?;

let buffered_indices: UInt64Array = chunk.buffered_indices.finish();

comphead marked this conversation as resolved.
Show resolved Hide resolved
let mut buffered_columns =
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
self.buffered_data.batches[buffered_idx]
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()?
get_buffered_columns(
&self.buffered_data,
buffered_idx,
&buffered_indices,
)?
} else {
self.buffered_schema
.fields()
Expand All @@ -1161,6 +1181,15 @@ impl SMJStream {
let filter_columns = if chunk.buffered_batch_idx.is_some() {
if matches!(self.join_type, JoinType::Right) {
get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
} else if matches!(self.join_type, JoinType::LeftSemi) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should also check for JoinType::Left (and the clause above also check for JoinType::RightSemi 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have more to go, LeftAnti is first as it prevents TPCH to run and then double check RightSemi as well, good point

// unwrap is safe here as we check is_some on top of if statement
let buffered_columns = get_buffered_columns(
&self.buffered_data,
chunk.buffered_batch_idx.unwrap(),
&buffered_indices,
)?;

get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
} else {
get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
}
Expand Down Expand Up @@ -1195,7 +1224,17 @@ impl SMJStream {
.into_array(filter_batch.num_rows())?;

// The selection mask of the filter
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;
let mut mask =
datafusion_common::cast::as_boolean_array(&filter_result)?;

let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
get_filtered_join_mask(self.join_type, streamed_indices, mask);
if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
mask = &filtered_join_mask.0;
self.streamed_batch
.join_filter_matched_idxs
.extend(&filtered_join_mask.1);
}

// Push the filtered batch to the output
let filtered_batch =
Expand Down Expand Up @@ -1365,6 +1404,69 @@ fn get_filter_column(
filter_columns
}

/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
#[inline(always)]
fn get_buffered_columns(
buffered_data: &BufferedData,
buffered_batch_idx: usize,
buffered_indices: &UInt64Array,
) -> Result<Vec<ArrayRef>, ArrowError> {
buffered_data.batches[buffered_batch_idx]
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()
}

// Calculate join filter bit mask considering join type specifics
comphead marked this conversation as resolved.
Show resolved Hide resolved
// `streamed_indices` - array of streamed datasource JOINED row indices
// `mask` - array booleans representing computed join filter expression eval result:
// true = the row index matches the join filter
// false = the row index doesn't match the join filter
// `streamed_indices` have the same length as `mask`
fn get_filtered_join_mask(
join_type: JoinType,
streamed_indices: UInt64Array,
mask: &BooleanArray,
) -> Option<(BooleanArray, Vec<u64>)> {
// for LeftSemi Join the filter mask should be calculated in its own way:
// if we find at least one matching row for specific streaming index
// we don't need to check any others for the same index
if matches!(join_type, JoinType::LeftSemi) {
comphead marked this conversation as resolved.
Show resolved Hide resolved
// have we seen a filter match for a streaming index before
let mut seen_as_true: bool = false;
let streamed_indices_length = streamed_indices.len();
let mut corrected_mask: BooleanBuilder =
BooleanBuilder::with_capacity(streamed_indices_length);

let mut filter_matched_indices: Vec<u64> = vec![];

#[allow(clippy::needless_range_loop)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder why ignore clippy here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clippy doesn't like for loops anymore ....

for i in 0..streamed_indices_length {
// LeftSemi respects only first true values for specific streaming index,
// others true values for the same index must be false
if mask.value(i) && !seen_as_true {
seen_as_true = true;
corrected_mask.append_value(true);
filter_matched_indices.push(streamed_indices.value(i));
viirya marked this conversation as resolved.
Show resolved Hide resolved
} else {
corrected_mask.append_value(false);
}

// if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
if i < streamed_indices_length - 1
&& streamed_indices.value(i) != streamed_indices.value(i + 1)
{
seen_as_true = false;
}
}
Some((corrected_mask.finish(), filter_matched_indices))
} else {
None
}
}

/// Buffered data contains all buffered batches with one unique join key
#[derive(Debug, Default)]
struct BufferedData {
Expand Down Expand Up @@ -1604,24 +1706,28 @@ fn is_join_arrays_equal(
mod tests {
use std::sync::Arc;

use crate::expressions::Column;
use crate::joins::utils::JoinOn;
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

use arrow::array::{Date32Array, Date64Array, Int32Array};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::{BooleanArray, UInt64Array};

use datafusion_common::JoinType::LeftSemi;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
};
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_execution::TaskContext;

use crate::expressions::Column;
use crate::joins::sort_merge_join::get_filtered_join_mask;
use crate::joins::utils::JoinOn;
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
use crate::test::build_table_i32;
use crate::{common, ExecutionPlan};

fn build_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
Expand Down Expand Up @@ -2641,6 +2747,72 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn left_semi_join_filtered_mask() -> Result<()> {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should test a type other than LeftSemi as negative test coverage 🤔

UInt64Array::from(vec![0, 0, 1, 1]),
&BooleanArray::from(vec![true, true, false, false])
),
Some((BooleanArray::from(vec![true, false, false, false]), vec![0]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, true])
),
Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![false, true])
),
Some((BooleanArray::from(vec![false, true]), vec![1]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, false])
),
Some((BooleanArray::from(vec![true, false]), vec![0]))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, true, true, true, true, true])
),
Some((
BooleanArray::from(vec![false, true, false, true, false, false]),
vec![0, 1]
))
);

assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, false, false, false, false, true])
),
Some((
BooleanArray::from(vec![false, false, false, false, false, true]),
vec![1]
))
);

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
Loading