diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 96c182cdec834..3f2573183d9da 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1104,6 +1104,28 @@ impl SMJStream { .map(|f| new_null_array(f.data_type(), buffered_indices.len())) .collect::>(); + // Construct batch with only filter columns + let mut filter_columns = vec![]; + + if let Some(f) = &self.filter { + let left_columns = f + .column_indices() + .iter() + .filter(|col_index| (*col_index).side == JoinSide::Left) + .map(|i| streamed_columns[i.index].clone()) + .collect::>(); + + let right_columns = f + .column_indices() + .iter() + .filter(|col_index| (*col_index).side == JoinSide::Right) + .map(|i| buffered_columns[i.index].clone()) + .collect::>(); + + filter_columns.extend(left_columns); + filter_columns.extend(right_columns); + } + streamed_columns.extend(buffered_columns); let columns = streamed_columns; @@ -1113,10 +1135,14 @@ impl SMJStream { let output_batch = if let Some(f) = &self.filter { println!("f: {:?}", f); + // Construct batch with only filter columns + let filter_batch = + RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?; + let filter_result = f .expression() - .evaluate(&output_batch)? - .into_array(output_batch.num_rows())?; + .evaluate(&filter_batch)? + .into_array(filter_batch.num_rows())?; let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; println!("mask: {:?}", mask); @@ -1168,6 +1194,28 @@ impl SMJStream { .collect::>() }; + // Construct batch with only filter columns + let mut filter_columns = vec![]; + + if let Some(f) = &self.filter { + let left_columns = f + .column_indices() + .iter() + .filter(|col_index| (*col_index).side == JoinSide::Left) + .map(|i| streamed_columns[i.index].clone()) + .collect::>(); + + let right_columns = f + .column_indices() + .iter() + .filter(|col_index| (*col_index).side == JoinSide::Right) + .map(|i| buffered_columns[i.index].clone()) + .collect::>(); + + filter_columns.extend(left_columns); + filter_columns.extend(right_columns); + } + let columns = if matches!(self.join_type, JoinType::Right) { buffered_columns.extend(streamed_columns); buffered_columns @@ -1182,10 +1230,14 @@ impl SMJStream { let output_batch = if let Some(f) = &self.filter { println!("f: {:?}", f); + // Construct batch with only filter columns + let filter_batch = + RecordBatch::try_new(Arc::new(f.schema().clone()), filter_columns)?; + let filter_result = f .expression() - .evaluate(&output_batch)? - .into_array(output_batch.num_rows())?; + .evaluate(&filter_batch)? + .into_array(filter_batch.num_rows())?; let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; println!("mask: {:?}", mask);