From 23bec73cb39c30539552e05531778555c15814dd Mon Sep 17 00:00:00 2001
From: Xianjin
Date: Wed, 13 Mar 2024 19:51:36 +0800
Subject: [PATCH] address comments

---
 .../expressions/bloom_filter_might_contain.rs | 107 ++++++++----------
 core/src/execution/datafusion/planner.rs      |   3 +-
 core/src/execution/datafusion/spark_hash.rs   |   2 +-
 .../datafusion/util/spark_bit_array.rs        |  10 +-
 .../datafusion/util/spark_bloom_filter.rs     |  18 ++-
 pom.xml                                       |   4 +-
 ...cala => CometExpression3_3PlusSuite.scala} |  24 ++--
 7 files changed, 78 insertions(+), 90 deletions(-)
 rename spark/src/test/spark-3.3-plus/org/apache/comet/{CometExpressionPlusSuite.scala => CometExpression3_3PlusSuite.scala} (88%)
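Reviewer note (not part of the patch): the main change in bloom_filter_might_contain.rs below is moving bloom-filter deserialization out of the per-batch path. Instead of lazily materializing the filter inside `evaluate` behind a `OnceCell`, the constructor now evaluates the bloom-filter expression once and stores a plain `Option<SparkBloomFilter>`. A minimal std-only sketch of the before/after shape; `Filter` and `build` are illustrative stand-ins, and `std::cell::OnceCell` stands in for the `once_cell::sync::OnceCell` the old code used:

    struct Filter; // stand-in for SparkBloomFilter

    fn build() -> Option<Filter> {
        // stand-in for evaluate_bloom_filter: deserialize the literal's bytes, if any
        Some(Filter)
    }

    // Before: the filter is materialized on the first call to `evaluate`,
    // so every call has to go through the cell.
    struct Lazy {
        filter: std::cell::OnceCell<Option<Filter>>,
    }

    // After: the filter is materialized once, in the constructor.
    struct Eager {
        filter: Option<Filter>,
    }

    fn main() {
        let lazy = Lazy { filter: std::cell::OnceCell::new() };
        let _ = lazy.filter.get_or_init(build); // the first `evaluate` pays the cost
        let eager = Eager { filter: build() }; // cost paid up front, exactly once
        assert!(lazy.filter.get().is_some() && eager.filter.is_some());
    }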
diff --git a/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
index 03d753f936..b61c9d2b5b 100644
--- a/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
+++ b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -19,12 +19,11 @@ use crate::{
     execution::datafusion::util::spark_bloom_filter::SparkBloomFilter,
     parquet::data_type::AsBytes,
 };
 use arrow::record_batch::RecordBatch;
-use arrow_array::{BooleanArray, Int64Array};
-use arrow_schema::DataType;
-use datafusion::{common::Result, physical_plan::ColumnarValue};
-use datafusion_common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue};
+use arrow_array::{cast::as_primitive_array, BooleanArray};
+use arrow_schema::{DataType, Schema};
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
 use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr};
-use once_cell::sync::OnceCell;
 use std::{
     any::Any,
     fmt::Display,
@@ -34,11 +33,12 @@ use std::{

 /// A physical expression that checks if a value might be in a bloom filter. It corresponds to the
 /// Spark's `BloomFilterMightContain` expression.
-#[derive(Debug)]
+
+#[derive(Debug, Hash)]
 pub struct BloomFilterMightContain {
     pub bloom_filter_expr: Arc<dyn PhysicalExpr>,
     pub value_expr: Arc<dyn PhysicalExpr>,
-    bloom_filter: OnceCell<Option<SparkBloomFilter>>,
+    bloom_filter: Option<SparkBloomFilter>,
 }

 impl Display for BloomFilterMightContain {
@@ -63,15 +63,33 @@ impl PartialEq<dyn Any> for BloomFilterMightContain {
     }
 }

+fn evaluate_bloom_filter(
+    bloom_filter_expr: &Arc<dyn PhysicalExpr>,
+) -> Result<Option<SparkBloomFilter>> {
+    // bloom_filter_expr must be a literal/scalar subquery expression, so we can evaluate it
+    // with an empty batch and an empty schema
+    let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
+    let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
+    match bloom_filter_bytes {
+        ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
+            Ok(v.map(|v| SparkBloomFilter::new_from_buf(v.as_bytes())))
+        }
+        _ => internal_err!("Bloom filter expression must be evaluated as a scalar binary value"),
+    }
+}
+
 impl BloomFilterMightContain {
     pub fn new(
         bloom_filter_expr: Arc<dyn PhysicalExpr>,
         value_expr: Arc<dyn PhysicalExpr>,
     ) -> Self {
+        // eagerly evaluate the bloom_filter_expr to get the actual bloom filter
+        let bloom_filter = evaluate_bloom_filter(&bloom_filter_expr)
+            .expect("bloom_filter_expr could be evaluated statically");
         Self {
             bloom_filter_expr,
             value_expr,
-            bloom_filter: Default::default(),
+            bloom_filter,
         }
     }
 }
@@ -81,66 +99,40 @@ impl PhysicalExpr for BloomFilterMightContain {
         self
     }

-    fn data_type(&self, _input_schema: &arrow_schema::Schema) -> Result<DataType> {
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
         Ok(DataType::Boolean)
     }

-    fn nullable(&self, _input_schema: &arrow_schema::Schema) -> Result<bool> {
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
         Ok(true)
     }

-    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
-        // lazily get the spark bloom filter
-        if self.bloom_filter.get().is_none() {
-            let bloom_filter_bytes = self.bloom_filter_expr.evaluate(batch)?;
-            match bloom_filter_bytes {
-                ColumnarValue::Array(_) => {
-                    return internal_err!(
-                        "Bloom filter expression must be evaluated as a scalar value"
-                    );
-                }
-                ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
-                    let filter = v.map(|v| SparkBloomFilter::new_from_buf(v.as_bytes()));
-                    self.bloom_filter.get_or_init(|| filter);
-                }
-                _ => {
-                    return internal_err!("Bloom filter expression must be binary type");
-                }
-            }
-        }
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
         let num_rows = batch.num_rows();
-        let lazy_filter = self.bloom_filter.get().unwrap();
-        if lazy_filter.is_none() {
-            // when the bloom filter is null, we should return a boolean array with all nulls
-            Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
-                num_rows,
-            ))))
-        } else {
-            let spark_filter = lazy_filter.as_ref().unwrap();
-            let values = self.value_expr.evaluate(batch)?;
-            match values {
-                ColumnarValue::Array(array) => {
-                    let array = array
-                        .as_any()
-                        .downcast_ref::<Int64Array>()
-                        .expect("value_expr must be evaluated as an int64 array");
-                    Ok(ColumnarValue::Array(Arc::new(
-                        spark_filter.might_contain_longs(array)?,
-                    )))
-                }
-                ColumnarValue::Scalar(a) => match a {
-                    ScalarValue::Int64(v) => {
+        self.bloom_filter
+            .as_ref()
+            .map(|spark_filter| {
+                let values = self.value_expr.evaluate(batch)?;
+                match values {
+                    ColumnarValue::Array(array) => {
+                        let boolean_array =
+                            spark_filter.might_contain_longs(as_primitive_array(&array));
+                        Ok(ColumnarValue::Array(Arc::new(boolean_array)))
+                    }
+                    ColumnarValue::Scalar(ScalarValue::Int64(v)) => {
                         let result = v.map(|v| spark_filter.might_contain_long(v));
                         Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
                     }
-                    _ => {
-                        internal_err!(
-                            "value_expr must be evaluated as an int64 array or a int64 scalar"
-                        )
-                    }
-                },
-            }
-        }
+                    _ => internal_err!("value expression must be int64 type"),
+                }
+            })
+            .unwrap_or_else(|| {
+                // when the bloom filter is null, we should return a boolean array with all nulls
+                Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
+                    num_rows,
+                ))))
+            })
     }

     fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -161,5 +153,6 @@ impl PhysicalExpr for BloomFilterMightContain {
         let mut s = state;
         self.bloom_filter_expr.hash(&mut s);
         self.value_expr.hash(&mut s);
+        self.hash(&mut s);
     }
 }
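Reviewer note (not part of the patch): `evaluate_bloom_filter` works because the bloom-filter argument is a literal (or an already-materialized scalar subquery), which reads no input columns and can therefore be evaluated against an empty batch. A self-contained sketch of that trick using the same datafusion/arrow APIs the file above imports; `lit` is the literal helper from datafusion_physical_expr, and exact paths may vary by datafusion version:

    use std::sync::Arc;

    use arrow::record_batch::RecordBatch;
    use arrow_schema::Schema;
    use datafusion::physical_plan::ColumnarValue;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_physical_expr::{expressions::lit, PhysicalExpr};

    fn main() -> Result<()> {
        // a literal needs no input columns, so an empty batch with an empty schema is enough
        let expr = lit(ScalarValue::Binary(Some(vec![0xde, 0xad, 0xbe, 0xef])));
        let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
        match expr.evaluate(&batch)? {
            ColumnarValue::Scalar(ScalarValue::Binary(Some(bytes))) => {
                assert_eq!(bytes, vec![0xde, 0xad, 0xbe, 0xef]);
            }
            other => panic!("expected a scalar binary value, got {other:?}"),
        }
        Ok(())
    }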
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 7c73a307c5..399c8db9c1 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -538,7 +538,8 @@ impl PhysicalPlanner {
             ExprStruct::BloomFilterMightContain(expr) => {
                 let bloom_filter_expr =
                     self.create_expr(expr.bloom_filter.as_ref().unwrap(), input_schema.clone())?;
-                let value_expr = self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
+                let value_expr =
+                    self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
                 Ok(Arc::new(BloomFilterMightContain::new(
                     bloom_filter_expr,
                     value_expr,
diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs
index eadce34e03..1d8d1f2c96 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -32,7 +32,7 @@ use datafusion::{
 };

 #[inline]
-pub fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
+pub(crate) fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
     #[inline]
     fn mix_k1(mut k1: i32) -> i32 {
         k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32);
diff --git a/core/src/execution/datafusion/util/spark_bit_array.rs b/core/src/execution/datafusion/util/spark_bit_array.rs
index f5ba2bce49..d4a24b315f 100644
--- a/core/src/execution/datafusion/util/spark_bit_array.rs
+++ b/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -18,6 +18,7 @@
 /// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
 /// used in the BloomFilter implementation. Some methods are not implemented as they are not
 /// required for the current use case.
+
 #[derive(Debug, Hash)]
 pub struct SparkBitArray {
     data: Vec<u64>,
@@ -33,15 +34,6 @@ impl SparkBitArray {
         }
     }

-    pub fn new_from_bit_count(num_bits: usize) -> Self {
-        let num_words = (num_bits + 63) / 64;
-        debug_assert!(num_words < u32::MAX as usize, "num_words is too large");
-        Self {
-            data: vec![0u64; num_words],
-            bit_count: num_bits,
-        }
-    }
-
     pub fn set(&mut self, index: usize) -> bool {
         if !self.get(index) {
             self.data[index >> 6] |= 1u64 << (index & 0x3f);
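Reviewer note (not part of the patch): for readers new to the bit array above, bit i lives in 64-bit word i >> 6, at position i & 0x3f within that word, and the removed `new_from_bit_count` sized the vector as (num_bits + 63) / 64 words. A minimal sketch of the same set/get arithmetic; the names are illustrative, not the patch's API:

    struct BitArray {
        data: Vec<u64>,
    }

    impl BitArray {
        fn new(num_bits: usize) -> Self {
            // round up to whole 64-bit words, as the removed constructor did
            Self { data: vec![0u64; (num_bits + 63) / 64] }
        }

        fn set(&mut self, index: usize) {
            self.data[index >> 6] |= 1u64 << (index & 0x3f);
        }

        fn get(&self, index: usize) -> bool {
            (self.data[index >> 6] & (1u64 << (index & 0x3f))) != 0
        }
    }

    fn main() {
        let mut bits = BitArray::new(128);
        bits.set(0);
        bits.set(127);
        assert!(bits.get(0) && bits.get(127) && !bits.get(64));
    }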
diff --git a/core/src/execution/datafusion/util/spark_bloom_filter.rs b/core/src/execution/datafusion/util/spark_bloom_filter.rs
index 19f0c96cb6..bf53a1dd88 100644
--- a/core/src/execution/datafusion/util/spark_bloom_filter.rs
+++ b/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -15,11 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.

-use crate::{
-    errors::CometResult,
-    execution::datafusion::{
-        spark_hash::spark_compatible_murmur3_hash, util::spark_bit_array::SparkBitArray,
-    },
+use crate::execution::datafusion::{
+    spark_hash::spark_compatible_murmur3_hash, util::spark_bit_array::SparkBitArray,
 };
 use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array};

@@ -28,6 +25,7 @@ const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
 /// A Bloom filter implementation that simulates the behavior of Spark's BloomFilter.
 /// It's not a complete implementation of Spark's BloomFilter, but just add the minimum
 /// methods to support mightContainsLong in the native side.
+
 #[derive(Debug, Hash)]
 pub struct SparkBloomFilter {
     bits: SparkBitArray,
@@ -60,9 +58,7 @@ impl SparkBloomFilter {

     pub fn put_long(&mut self, item: i64) -> bool {
         // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce
-        // n hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
-        // Note that `CountMinSketch` use a different strategy, it hashes the input long element
-        // with every i to produce n hash values.
+        // n hash values by `h1 + i * h2` with 1 <= i <= num_hashes.
         let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
         let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
         let bit_size = self.bits.bit_size() as i32;
@@ -93,10 +89,10 @@ impl SparkBloomFilter {
         true
     }

-    pub fn might_contain_longs(&self, items: &Int64Array) -> CometResult<BooleanArray> {
-        Ok(items
+    pub fn might_contain_longs(&self, items: &Int64Array) -> BooleanArray {
+        items
             .iter()
             .map(|v| v.map(|x| self.might_contain_long(x)))
-            .collect())
+            .collect()
     }
 }
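Reviewer note (not part of the patch): the comment kept in put_long above is the heart of the scheme — two murmur3 hashes h1 and h2 are combined as h1 + i * h2 for i in 1..=num_hashes to derive the bit positions, and might_contain_long recomputes exactly the same positions. A sketch of that index derivation; the flip-negative-then-modulo step follows Spark's BloomFilterImpl, and the inputs here are arbitrary stand-ins rather than real murmur3 output:

    fn bit_positions(h1: i32, h2: i32, num_hashes: i32, bit_size: i32) -> Vec<i32> {
        (1..=num_hashes)
            .map(|i| {
                let mut combined = h1.wrapping_add(i.wrapping_mul(h2));
                if combined < 0 {
                    combined = !combined; // flip rather than abs(), as Spark does
                }
                combined % bit_size
            })
            .collect()
    }

    fn main() {
        // writer (put_long) and reader (might_contain_long) must derive identical
        // positions, which is what makes the filter answerable at all
        let write = bit_positions(0x1234_5678, 0x0abc_def0, 5, 1 << 20);
        let read = bit_positions(0x1234_5678, 0x0abc_def0, 5, 1 << 20);
        assert_eq!(write, read);
        assert!(write.iter().all(|&p| (0..(1 << 20)).contains(&p)));
    }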
diff --git a/pom.xml b/pom.xml
index 2b4515630f..bdc55ecd92 100644
--- a/pom.xml
+++ b/pom.xml
@@ -86,6 +86,7 @@ under the License.
       -Djdk.reflect.useDirectMethodHandle=false
     </extraJavaTestArgs>
    <argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
+    <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
   </properties>

@@ -494,7 +495,8 @@ under the License.
       <properties>
         <spark.version>3.2.2</spark.version>
         <spark.version.short>3.2</spark.version.short>
         <parquet.version>1.12.0</parquet.version>
-        <additional.3_3.test.source>spark-3.2</additional.3_3.test.source>
+
+        <additional.3_3.test.source>not-needed-yet</additional.3_3.test.source>
       </properties>
     </profile>
diff --git a/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpressionPlusSuite.scala b/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala
similarity index 88%
rename from spark/src/test/spark-3.3-plus/org/apache/comet/CometExpressionPlusSuite.scala
rename to spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala
index 2f751143eb..6102777fc5 100644
--- a/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpressionPlusSuite.scala
+++ b/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpression3_3PlusSuite.scala
@@ -19,7 +19,7 @@

 package org.apache.comet

-import org.apache.spark.sql.{Column, CometTestBase, DataFrame, Row}
+import org.apache.spark.sql.{Column, CometTestBase}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Expression, ExpressionInfo}
@@ -29,7 +29,7 @@ import org.apache.spark.util.sketch.BloomFilter
 import java.io.ByteArrayOutputStream
 import scala.util.Random

-class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+class CometExpression3_3PlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {

   import testImplicits._

   val func_might_contain = new FunctionIdentifier("might_contain")
@@ -49,6 +49,7 @@ class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelpe

   test("test BloomFilterMightContain can take a constant value input") {
     val table = "test"
+
     withTable(table) {
       sql(s"create table $table(col1 long, col2 int) using parquet")
       sql(s"insert into $table values (201, 1)")
@@ -62,6 +63,7 @@ class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelpe

   test("test NULL inputs for BloomFilterMightContain") {
     val table = "test"
+
     withTable(table) {
       sql(s"create table $table(col1 long, col2 int) using parquet")
       sql(s"insert into $table values (201, 1), (null, 2)")
@@ -77,13 +79,9 @@ class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelpe
   }

   test("test BloomFilterMightContain from random input") {
-    val bf = BloomFilter.create(100000, 10000)
-    val longs = (0 until 10000).map(_ => Random.nextLong())
-    longs.foreach(bf.put)
-    val os = new ByteArrayOutputStream()
-    bf.writeTo(os)
-    val bfBytes = os.toByteArray
+    val (longs, bfBytes) = bloomFilterFromRandomInput(10000, 10000)
     val table = "test"
+
     withTable(table) {
       sql(s"create table $table(col1 long, col2 binary) using parquet")
       spark.createDataset(longs).map(x => (x, bfBytes)).toDF("col1", "col2").write.insertInto(table)
@@ -97,6 +95,12 @@ class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelpe
     }
   }

-
-
+  private def bloomFilterFromRandomInput(expectedItems: Long, expectedBits: Long): (Seq[Long], Array[Byte]) = {
+    val bf = BloomFilter.create(expectedItems, expectedBits)
+    val longs = (0 until expectedItems.toInt).map(_ => Random.nextLong())
+    longs.foreach(bf.put)
+    val os = new ByteArrayOutputStream()
+    bf.writeTo(os)
+    (longs, os.toByteArray)
+  }
 }
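Reviewer note (not part of the patch): the suite above checks two properties end to end — inserted values must always come back as "might contain" (a Bloom filter has no false negatives), and NULL inputs must propagate as NULL results. A std-only sketch of both properties with a toy filter; the hash function here is an arbitrary stand-in, not the spark_compatible_murmur3_hash the native code uses:

    // arbitrary stand-in hash over the long's value and a seed
    fn hash(x: i64, seed: u32) -> i32 {
        let mut h = (x as u64).wrapping_mul(0x9e37_79b9_7f4a_7c15) ^ seed as u64;
        h ^= h >> 33;
        h as i32
    }

    struct MiniBloom {
        bits: Vec<u64>,
        num_hashes: i32,
        bit_size: i32,
    }

    impl MiniBloom {
        fn new(bit_size: i32, num_hashes: i32) -> Self {
            Self {
                bits: vec![0u64; (bit_size as usize + 63) / 64],
                num_hashes,
                bit_size,
            }
        }

        // same h1 + i * h2 scheme as the native SparkBloomFilter
        fn positions(&self, item: i64) -> Vec<usize> {
            let h1 = hash(item, 0);
            let h2 = hash(item, h1 as u32);
            (1..=self.num_hashes)
                .map(|i| {
                    let mut c = h1.wrapping_add(i.wrapping_mul(h2));
                    if c < 0 {
                        c = !c;
                    }
                    (c % self.bit_size) as usize
                })
                .collect()
        }

        fn put(&mut self, item: i64) {
            for p in self.positions(item) {
                self.bits[p >> 6] |= 1u64 << (p & 0x3f);
            }
        }

        fn might_contain(&self, item: i64) -> bool {
            self.positions(item)
                .iter()
                .all(|&p| (self.bits[p >> 6] & (1u64 << (p & 0x3f))) != 0)
        }
    }

    fn main() {
        let mut bf = MiniBloom::new(1 << 16, 5);
        for x in 0..1000i64 {
            bf.put(x);
        }
        // no false negatives: everything inserted is reported as possibly present
        assert!((0..1000i64).all(|x| bf.might_contain(x)));

        // nulls propagate, mirroring might_contain_longs over a nullable Int64Array
        let vals: Vec<Option<i64>> = vec![Some(201), None];
        let res: Vec<Option<bool>> =
            vals.iter().map(|v| v.map(|x| bf.might_contain(x))).collect();
        assert_eq!(res, vec![Some(true), None]);
    }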