diff --git a/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
index 5d3941830..dd90cd8e9 100644
--- a/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
+++ b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -19,7 +19,7 @@ use crate::{
     execution::datafusion::util::spark_bloom_filter::SparkBloomFilter, parquet::data_type::AsBytes,
 };
 use arrow::record_batch::RecordBatch;
-use arrow_array::{cast::as_primitive_array, BooleanArray};
+use arrow_array::cast::as_primitive_array;
 use arrow_schema::{DataType, Schema};
 use datafusion::physical_plan::ColumnarValue;
 use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
@@ -33,7 +33,6 @@ use std::{
 
 /// A physical expression that checks if a value might be in a bloom filter. It corresponds to the
 /// Spark's `BloomFilterMightContain` expression.
-
 #[derive(Debug, Hash)]
 pub struct BloomFilterMightContain {
     pub bloom_filter_expr: Arc<dyn PhysicalExpr>,
@@ -72,25 +71,24 @@ fn evaluate_bloom_filter(
     let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
     match bloom_filter_bytes {
         ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
-            Ok(v.map(|v| SparkBloomFilter::new_from_buf(v.as_bytes())))
+            Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
         }
-        _ => internal_err!("Bloom filter expression must be evaluated as a scalar binary value"),
+        _ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
     }
 }
 
 impl BloomFilterMightContain {
-    pub fn new(
+    pub fn try_new(
         bloom_filter_expr: Arc<dyn PhysicalExpr>,
         value_expr: Arc<dyn PhysicalExpr>,
-    ) -> Self {
+    ) -> Result<Self> {
         // early evaluate the bloom_filter_expr to get the actual bloom filter
-        let bloom_filter = evaluate_bloom_filter(&bloom_filter_expr)
-            .expect("bloom_filter_expr could be evaluated statically");
-        Self {
+        let bloom_filter = evaluate_bloom_filter(&bloom_filter_expr)?;
+        Ok(Self {
             bloom_filter_expr,
             value_expr,
             bloom_filter,
-        }
+        })
     }
 }
 
@@ -108,7 +106,6 @@ impl PhysicalExpr for BloomFilterMightContain {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
-        let num_rows = batch.num_rows();
         self.bloom_filter
             .as_ref()
             .map(|spark_filter| {
@@ -123,14 +120,12 @@ impl PhysicalExpr for BloomFilterMightContain {
                         let result = v.map(|v| spark_filter.might_contain_long(v));
                         Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
                     }
-                    _ => internal_err!("value expression must be int64 type"),
+                    _ => internal_err!("value expression should be int64 type"),
                 }
             })
             .unwrap_or_else(|| {
-                // when the bloom filter is null, we should return a boolean array with all nulls
-                Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
-                    num_rows,
-                ))))
+                // when the bloom filter is null, we should return null for all the input
+                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
             })
     }
 
@@ -142,10 +137,10 @@ impl PhysicalExpr for BloomFilterMightContain {
         self: Arc<Self>,
         children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(BloomFilterMightContain::new(
+        Ok(Arc::new(BloomFilterMightContain::try_new(
             children[0].clone(),
             children[1].clone(),
-        )))
+        )?))
     }
 
     fn dyn_hash(&self, state: &mut dyn Hasher) {
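Reviewer note on the `evaluate` change above: returning `ColumnarValue::Scalar(ScalarValue::Boolean(None))` is semantically equivalent to the all-null `BooleanArray` the old code materialized, but avoids an allocation per batch, since DataFusion broadcasts a scalar to the batch length only when needed. A minimal standalone sketch of the equivalence (not part of this patch; assumes only `arrow_array`):

```rust
use arrow_array::{Array, BooleanArray};

fn main() {
    // What the removed code built eagerly for every record batch:
    let old_result = BooleanArray::new_null(4);
    assert_eq!(old_result.null_count(), 4);
    assert!(old_result.is_null(0));
    // The patched code instead returns a single scalar null,
    // ScalarValue::Boolean(None), which represents the same
    // "null for every row" result without allocating an array.
}
```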
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 7c73a307c..9389a8c23 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -539,10 +539,10 @@ impl PhysicalPlanner {
                 let bloom_filter_expr =
                     self.create_expr(expr.bloom_filter.as_ref().unwrap(), input_schema.clone())?;
                 let value_expr = self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
-                Ok(Arc::new(BloomFilterMightContain::new(
+                Ok(Arc::new(BloomFilterMightContain::try_new(
                     bloom_filter_expr,
                     value_expr,
-                )))
+                )?))
             }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
diff --git a/core/src/execution/datafusion/util/spark_bit_array.rs b/core/src/execution/datafusion/util/spark_bit_array.rs
index d4a24b315..9729627df 100644
--- a/core/src/execution/datafusion/util/spark_bit_array.rs
+++ b/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -18,7 +18,6 @@
 /// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
 /// used in the BloomFilter implementation. Some methods are not implemented as they are not
 /// required for the current use case.
-
 #[derive(Debug, Hash)]
 pub struct SparkBitArray {
     data: Vec<u64>,
@@ -36,6 +35,7 @@ impl SparkBitArray {
 
     pub fn set(&mut self, index: usize) -> bool {
         if !self.get(index) {
+            // see the get method for the explanation of the shift operators
            self.data[index >> 6] |= 1u64 << (index & 0x3f);
             self.bit_count += 1;
             true
@@ -45,6 +45,12 @@ impl SparkBitArray {
     }
 
     pub fn get(&self, index: usize) -> bool {
+        // Java version: (data[(int) (index >> 6)] & (1L << index)) != 0
+        // Rust and Java have different semantics for the shift operators. Java's shift operators
+        // mask the right-hand operand with 0x3f [1], while Rust's do not: they panic with
+        // "shift left with overflow" for a large right-hand operand. To match Java, we mask the
+        // right-hand operand with 0x3f on the Rust side.
+        // [1]: https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19
         (self.data[index >> 6] & (1u64 << (index & 0x3f))) != 0
     }
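A quick standalone demonstration (not part of this patch) of the shift-semantics difference that the new comment in `get` describes: Java masks the shift amount with 0x3f per the JLS, while Rust treats an oversized shift amount as overflow:

```rust
fn main() {
    let index: usize = 70; // larger than u64's bit width of 64
    // Masking reproduces Java's behavior: in Java, (1L << 70) == (1L << 6).
    assert_eq!(1u64 << (index & 0x3f), 1u64 << 6);
    // Unmasked, `1u64 << index` overflows and panics in debug builds with
    // "attempt to shift left with overflow". The checked variant makes the
    // failure explicit by returning None instead:
    assert_eq!(1u64.checked_shl(index as u32), None);
}
```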
diff --git a/core/src/execution/datafusion/util/spark_bloom_filter.rs b/core/src/execution/datafusion/util/spark_bloom_filter.rs
index bf53a1dd8..22957a147 100644
--- a/core/src/execution/datafusion/util/spark_bloom_filter.rs
+++ b/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -25,23 +25,23 @@
 const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
 
 /// A Bloom filter implementation that simulates the behavior of Spark's BloomFilter.
 /// It's not a complete implementation of Spark's BloomFilter, but just add the minimum
 /// methods to support mightContainsLong in the native side.
-
 #[derive(Debug, Hash)]
 pub struct SparkBloomFilter {
     bits: SparkBitArray,
-    num_hashes: u32,
+    num_hash_functions: u32,
 }
 
 impl SparkBloomFilter {
-    pub fn new_from_buf(buf: &[u8]) -> Self {
+    pub fn new(buf: &[u8]) -> Self {
         let mut offset = 0;
         let version = read_num_be_bytes!(i32, 4, buf[offset..]);
         offset += 4;
         assert_eq!(
             version, SPARK_BLOOM_FILTER_VERSION_1,
-            "Unsupported BloomFilter version"
+            "Unsupported BloomFilter version: {}, expecting version: {}",
+            version, SPARK_BLOOM_FILTER_VERSION_1
         );
-        let num_hashes = read_num_be_bytes!(i32, 4, buf[offset..]);
+        let num_hash_functions = read_num_be_bytes!(i32, 4, buf[offset..]);
         offset += 4;
         let num_words = read_num_be_bytes!(i32, 4, buf[offset..]);
         offset += 4;
@@ -52,18 +52,18 @@ impl SparkBloomFilter {
         }
         Self {
             bits: SparkBitArray::new(bits),
-            num_hashes: num_hashes as u32,
+            num_hash_functions: num_hash_functions as u32,
         }
     }
 
     pub fn put_long(&mut self, item: i64) -> bool {
         // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce
-        // n hash values by `h1 + i * h2` with 1 <= i <= num_hashes.
+        // n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions.
         let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
         let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
         let bit_size = self.bits.bit_size() as i32;
         let mut bit_changed = false;
-        for i in 1..=self.num_hashes {
+        for i in 1..=self.num_hash_functions {
             let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
             if combined_hash < 0 {
                 combined_hash = !combined_hash;
@@ -77,7 +77,7 @@
         let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
         let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
         let bit_size = self.bits.bit_size() as i32;
-        for i in 1..=self.num_hashes {
+        for i in 1..=self.num_hash_functions {
             let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
             if combined_hash < 0 {
                 combined_hash = !combined_hash;
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
index dd73e9521..7bdf2c0ef 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -45,7 +45,7 @@ trait ShimQueryPlanSerde {
     }
   }
 
-  // todo: delete after drop Spark 3.2 support
+  // TODO: delete after dropping Spark 3.2 support
   def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
     binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
   }
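For reference, the double-hashing scheme shared by `put_long` and `might_contain_long` (described in the comment above) can be isolated as follows. This is a simplified sketch with a hypothetical `bit_indexes` helper, using std's `wrapping_add`/`wrapping_mul` in place of the `add_wrapping`/`mul_wrapping` calls in the patch:

```rust
/// Derive the k bit positions for one element from two 32-bit Murmur3 hashes
/// h1 and h2: combined = h1 + i * h2 (wrapping) for 1 <= i <= k, bit-flipped
/// to non-negative when it goes negative, then reduced modulo bit_size.
fn bit_indexes(h1: i32, h2: i32, num_hash_functions: u32, bit_size: i32) -> Vec<usize> {
    (1..=num_hash_functions as i32)
        .map(|i| {
            let mut combined_hash = h1.wrapping_add(i.wrapping_mul(h2));
            if combined_hash < 0 {
                // Same trick as the patch: flip the bits instead of abs()
                // so i32::MIN cannot overflow.
                combined_hash = !combined_hash;
            }
            (combined_hash % bit_size) as usize
        })
        .collect()
}

fn main() {
    // With h1 = 11, h2 = 7, three hash functions, and a 64-bit array,
    // the probed positions are 11 + 7, 11 + 14, and 11 + 21:
    assert_eq!(bit_indexes(11, 7, 3, 64), vec![18, 25, 32]);
}
```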