diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index e9124a72970ae..783aafc1b2c1e 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -22,7 +22,7 @@ use std::any::Any; use std::cmp::Ordering; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::fmt::Formatter; use std::mem; use std::ops::Range; @@ -1379,23 +1379,32 @@ impl SMJStream { // If it is joined with streamed side, but doesn't match the join filter, // we need to output it with nulls as streamed side. if matches!(self.join_type, JoinType::Full) { + let mut buffered_indices_map: HashMap = + HashMap::new(); + for i in 0..pre_mask.len() { - let buffered_batch = &mut self.buffered_data.batches - [chunk.buffered_batch_idx.unwrap()]; let buffered_index = buffered_indices.value(i); - if !pre_mask.value(i) { - // For a buffered row that is joined with streamed side but doesn't satisfy the join filter, + buffered_indices_map.insert( + buffered_index, + *buffered_indices_map + .get(&buffered_index) + .unwrap_or(&true) + && !pre_mask.value(i), + ); + } + + let buffered_batch = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()]; + for (buffered_index, failed_join_filter) in + buffered_indices_map + { + if failed_join_filter { + // For a buffered row that is joined with streamed side rows but all joined rows don't + // satisfy the join filter, buffered_batch .join_filter_failed_idxs .insert(buffered_index); - } else if buffered_batch - .join_filter_failed_idxs - .contains(&buffered_index) - { - buffered_batch - .join_filter_failed_idxs - .remove(&buffered_index); } } }