From dbddcd37f5a56d1e24d9c7b092f89c9c1460fe7e Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Sun, 10 Mar 2024 00:30:07 +0800 Subject: [PATCH] feat: Support BloomFilterMightContain expr --- .../expressions/bloom_filter_might_contain.rs | 165 ++++++++++++++++++ .../execution/datafusion/expressions/mod.rs | 1 + core/src/execution/datafusion/mod.rs | 1 + core/src/execution/datafusion/planner.rs | 10 ++ core/src/execution/datafusion/spark_hash.rs | 2 +- core/src/execution/datafusion/util/mod.rs | 19 ++ .../datafusion/util/spark_bit_array.rs | 130 ++++++++++++++ .../datafusion/util/spark_bloom_filter.rs | 117 +++++++++++++ core/src/execution/proto/expr.proto | 6 + pom.xml | 8 + spark/pom.xml | 18 ++ .../comet/CometSparkSessionExtensions.scala | 1 - .../apache/comet/serde/QueryPlanSerde.scala | 19 +- .../comet/shims/ShimQueryPlanSerde.scala | 6 +- .../org/apache/comet/CometCastSuite.scala | 4 +- .../comet/CometExpressionPlusSuite.scala | 102 +++++++++++ 16 files changed, 603 insertions(+), 6 deletions(-) create mode 100644 core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs create mode 100644 core/src/execution/datafusion/util/mod.rs create mode 100644 core/src/execution/datafusion/util/spark_bit_array.rs create mode 100644 core/src/execution/datafusion/util/spark_bloom_filter.rs create mode 100644 spark/src/test/spark-3.3-plus/org/apache/comet/CometExpressionPlusSuite.scala diff --git a/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs new file mode 100644 index 0000000000..c08ed5e3a6 --- /dev/null +++ b/core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use crate::{
+    execution::datafusion::util::spark_bloom_filter::SparkBloomFilter, parquet::data_type::AsBytes,
+};
+use arrow::record_batch::RecordBatch;
+use arrow_array::{BooleanArray, Int64Array};
+use arrow_schema::DataType;
+use datafusion::{common::Result, physical_plan::ColumnarValue};
+use datafusion_common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue};
+use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr};
+use log::info;
+use once_cell::sync::OnceCell;
+use std::{
+    any::Any,
+    fmt::Display,
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+#[derive(Debug)]
+pub struct BloomFilterMightContain {
+    pub bloom_filter_expr: Arc<dyn PhysicalExpr>,
+    pub value_expr: Arc<dyn PhysicalExpr>,
+    bloom_filter: OnceCell<Option<SparkBloomFilter>>,
+}
+
+impl Display for BloomFilterMightContain {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "BloomFilterMightContain [bloom_filter_expr: {}, value_expr: {}]",
+            self.bloom_filter_expr, self.value_expr
+        )
+    }
+}
+
+impl PartialEq<dyn Any> for BloomFilterMightContain {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|other| {
+                self.bloom_filter_expr.eq(&other.bloom_filter_expr)
+                    && self.value_expr.eq(&other.value_expr)
+            })
+            .unwrap_or(false)
+    }
+}
+
+impl BloomFilterMightContain {
+    pub fn new(
+        bloom_filter_expr: Arc<dyn PhysicalExpr>,
+        value_expr: Arc<dyn PhysicalExpr>,
+    ) -> Self {
+        Self {
+            bloom_filter_expr,
+            value_expr,
+            bloom_filter: Default::default(),
+        }
+    }
+}
+
+impl PhysicalExpr for BloomFilterMightContain {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &arrow_schema::Schema) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn nullable(&self, _input_schema: &arrow_schema::Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
+        // Lazily deserialize the Spark bloom filter on the first evaluation.
+        if self.bloom_filter.get().is_none() {
+            let bloom_filter_bytes = self.bloom_filter_expr.evaluate(batch)?;
+            match bloom_filter_bytes {
+                ColumnarValue::Array(_) => {
+                    return internal_err!(
+                        "Bloom filter expression must be evaluated as a scalar value"
+                    );
+                }
+                ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
+                    info!("init for bloom filter");
+                    let filter = v.map(|v| SparkBloomFilter::new_from_buf(v.as_bytes()));
+                    self.bloom_filter.get_or_init(|| filter);
+                }
+                _ => {
+                    return internal_err!("Bloom filter expression must be binary type");
+                }
+            }
+        }
+        let num_rows = batch.num_rows();
+        let lazy_filter = self.bloom_filter.get().unwrap();
+        if lazy_filter.is_none() {
+            // when the bloom filter is null, we should return a boolean array with all nulls
+            return Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
+                num_rows,
+            ))));
+        } else {
+            let spark_filter = lazy_filter.as_ref().unwrap();
+            let values = self.value_expr.evaluate(batch)?;
+            match values {
+                ColumnarValue::Array(array) => {
+                    let array = array
+                        .as_any()
+                        .downcast_ref::<Int64Array>()
+                        .expect("value_expr must be evaluated as an int64 array");
+                    Ok(ColumnarValue::Array(Arc::new(
+                        spark_filter.might_contain_longs(array)?,
+                    )))
+                }
+                ColumnarValue::Scalar(a) => match a {
+                    ScalarValue::Int64(v) => {
+                        let result = v.map(|v| spark_filter.might_contain_long(v));
+                        Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
+                    }
+                    _ => {
+                        return internal_err!(
+                            "value_expr must be evaluated as an int64 array or an int64 scalar"
+                        );
+                    }
+                },
+            }
+        }
+    }
+
+    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.bloom_filter_expr.clone(), self.value_expr.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(BloomFilterMightContain::new(
+            children[0].clone(),
+            children[1].clone(),
+        )))
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.bloom_filter_expr.hash(&mut s);
+        self.value_expr.hash(&mut s);
+    }
+}
diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index cfc312510b..69cdf3e997 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -26,6 +26,7 @@ pub mod scalar_funcs;
 pub use normalize_nan::NormalizeNaNAndZero;
 pub mod avg;
 pub mod avg_decimal;
+pub mod bloom_filter_might_contain;
 pub mod strings;
 pub mod subquery;
 pub mod sum_decimal;
diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs
index f9fafeb292..c464eeed0b 100644
--- a/core/src/execution/datafusion/mod.rs
+++ b/core/src/execution/datafusion/mod.rs
@@ -22,3 +22,4 @@ mod operators;
 pub mod planner;
 pub(crate) mod shuffle_writer;
 mod spark_hash;
+mod util;
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 33cf636a2e..8cf1f86960 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -59,6 +59,7 @@ use crate::{
         avg::Avg,
         avg_decimal::AvgDecimal,
         bitwise_not::BitwiseNotExpr,
+        bloom_filter_might_contain::BloomFilterMightContain,
         cast::Cast,
         checkoverflow::CheckOverflow,
         if_expr::IfExpr,
@@ -525,6 +526,15 @@ impl PhysicalPlanner {
                 let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 Ok(Arc::new(Subquery::new(self.exec_context_id, id, data_type)))
             }
+            ExprStruct::BloomFilterMightContain(expr) => {
+                let bloom_filter_expr =
+                    self.create_expr(expr.bloom_filter.as_ref().unwrap(), input_schema.clone())?;
+                let value_expr = self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
+                Ok(Arc::new(BloomFilterMightContain::new(
+                    bloom_filter_expr,
+                    value_expr,
+                )))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/core/src/execution/datafusion/spark_hash.rs b/core/src/execution/datafusion/spark_hash.rs
index aeefccf5bb..eadce34e03 100644
--- a/core/src/execution/datafusion/spark_hash.rs
+++ b/core/src/execution/datafusion/spark_hash.rs
@@ -32,7 +32,7 @@ use datafusion::{
 };
 
 #[inline]
-fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
+pub fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
     #[inline]
     fn mix_k1(mut k1: i32) -> i32 {
         k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32);
diff --git a/core/src/execution/datafusion/util/mod.rs b/core/src/execution/datafusion/util/mod.rs
new file mode 100644
index 0000000000..75b763af5b
--- /dev/null
+++ b/core/src/execution/datafusion/util/mod.rs
@@ -0,0 +1,19 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub mod spark_bit_array;
+pub mod spark_bloom_filter;
diff --git a/core/src/execution/datafusion/util/spark_bit_array.rs b/core/src/execution/datafusion/util/spark_bit_array.rs
new file mode 100644
index 0000000000..33b55dd2ca
--- /dev/null
+++ b/core/src/execution/datafusion/util/spark_bit_array.rs
@@ -0,0 +1,130 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[derive(Debug, Hash)]
+pub struct SparkBitArray {
+    data: Vec<u64>,
+    bit_count: usize,
+}
+
+impl SparkBitArray {
+    pub fn new(buf: Vec<u64>) -> Self {
+        let num_bits = buf.iter().map(|x| x.count_ones() as usize).sum();
+        Self {
+            data: buf,
+            bit_count: num_bits,
+        }
+    }
+
+    pub fn new_from_bit_count(num_bits: usize) -> Self {
+        let num_words = (num_bits + 63) / 64;
+        debug_assert!(num_words < u32::MAX as usize, "num_words is too large");
+        Self {
+            data: vec![0u64; num_words],
+            bit_count: 0, // a freshly allocated array has no bits set
+        }
+    }
+
+    pub fn set(&mut self, index: usize) -> bool {
+        if !self.get(index) {
+            self.data[index >> 6] |= 1u64 << (index & 0x3f);
+            self.bit_count += 1;
+            true
+        } else {
+            false
+        }
+    }
+
+    pub fn get(&self, index: usize) -> bool {
+        (self.data[index >> 6] & (1u64 << (index & 0x3f))) != 0
+    }
+
+    pub fn bit_size(&self) -> u64 {
+        self.data.len() as u64 * 64
+    }
+
+    pub fn cardinality(&self) -> usize {
+        self.bit_count
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_spark_bit_array() {
+        let buf = vec![0u64; 4];
+        let mut array = SparkBitArray::new(buf);
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 0);
+
+        assert!(!array.get(0));
+        assert!(!array.get(1));
+        assert!(!array.get(63));
+        assert!(!array.get(64));
+        assert!(!array.get(65));
+        assert!(!array.get(127));
+        assert!(!array.get(128));
+        assert!(!array.get(129));
+
+        assert!(array.set(0));
+        assert!(array.set(1));
+        assert!(array.set(63));
+        assert!(array.set(64));
+        assert!(array.set(65));
+        assert!(array.set(127));
+        assert!(array.set(128));
+        assert!(array.set(129));
+
+        assert_eq!(array.cardinality(), 8);
+        assert_eq!(array.bit_size(), 256);
+
+        assert!(array.get(0));
+        // already set so should return false
+        assert!(!array.set(0));
+
+        // not set values should return false for get
+        assert!(!array.get(2));
+        assert!(!array.get(62));
+    }
+
+    #[test]
+    fn test_spark_bit_with_non_empty_buffer() {
+        let buf = vec![8u64; 4];
+        let mut array = SparkBitArray::new(buf);
+        assert_eq!(array.bit_size(), 256);
+        assert_eq!(array.cardinality(), 4);
+
+        // already set bits should return true
+        assert!(array.get(3));
+        assert!(array.get(67));
+        assert!(array.get(131));
+        assert!(array.get(195));
+
+        // other unset bits should return false
+        assert!(!array.get(0));
+        assert!(!array.get(1));
+
+        // set bits
+        assert!(array.set(0));
+        assert!(array.set(1));
+
+        // check cardinality
+        assert_eq!(array.cardinality(), 6);
+    }
+}
diff --git a/core/src/execution/datafusion/util/spark_bloom_filter.rs b/core/src/execution/datafusion/util/spark_bloom_filter.rs
new file mode 100644
index 0000000000..46c975db07
--- /dev/null
+++ b/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -0,0 +1,117 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{
+    errors::CometResult,
+    execution::datafusion::{
+        spark_hash::spark_compatible_murmur3_hash, util::spark_bit_array::SparkBitArray,
+    },
+};
+use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array};
+
+const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
+
+/// A Bloom filter implementation that simulates the behavior of Spark's BloomFilter.
+/// It is not a complete implementation of Spark's BloomFilter; it only adds the minimum
+/// set of methods needed to support mightContainLong on the native side.
+#[derive(Debug, Hash)]
+pub struct SparkBloomFilter {
+    bits: SparkBitArray,
+    num_hashes: u32,
+}
+
+/// Similar to the `read_num_bytes` macro in `bit.rs`, but reads numbers from bytes in
+/// big-endian order.
macro_rules! read_num_be_bytes {
+    ($ty:ty, $size:expr, $src:expr) => {{
+        debug_assert!($size <= $src.len());
+        let mut buffer = <$ty as $crate::common::bit::FromBytes>::Buffer::default();
+        buffer.as_mut()[..$size].copy_from_slice(&$src[..$size]);
+        <$ty>::from_be_bytes(buffer)
+    }};
+}
+
+impl SparkBloomFilter {
+    pub fn new_from_buf(buf: &[u8]) -> Self {
+        let mut offset = 0;
+        let version = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        assert_eq!(
+            version, SPARK_BLOOM_FILTER_VERSION_1,
+            "Unsupported BloomFilter version"
+        );
+        let num_hashes = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        let num_words = read_num_be_bytes!(i32, 4, buf[offset..]);
+        offset += 4;
+        let mut bits = vec![0u64; num_words as usize];
+        for i in 0..num_words {
+            bits[i as usize] = read_num_be_bytes!(i64, 8, buf[offset..]) as u64;
+            offset += 8;
+        }
+        Self {
+            bits: SparkBitArray::new(bits),
+            num_hashes: num_hashes as u32,
+        }
+    }
+
+    pub fn put_long(&mut self, item: i64) -> bool {
+        // Here we first hash the input long element into two int hash values, h1 and h2,
+        // then produce n hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
+        // Note that `CountMinSketch` uses a different strategy: it hashes the input long
+        // element with every i to produce n hash values.
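+        // As an illustration (not Spark code): with num_hashes = 3 and a 256-bit array,
+        // an item hashing to h1 = 100 and h2 = 50 probes bits (100 + 1*50) % 256 = 150,
+        // (100 + 2*50) % 256 = 200 and (100 + 3*50) % 256 = 250. A negative combined
+        // hash is first flipped with `!` (matching Spark's `~`) before the modulo.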
+        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
+        let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
+        let bit_size = self.bits.bit_size() as i32;
+        let mut bit_changed = false;
+        for i in 1..=self.num_hashes {
+            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+            if combined_hash < 0 {
+                combined_hash = !combined_hash;
+            }
+            bit_changed |= self.bits.set((combined_hash % bit_size) as usize);
+        }
+        bit_changed
+    }
+
+    pub fn might_contain_long(&self, item: i64) -> bool {
+        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
+        let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
+        let bit_size = self.bits.bit_size() as i32;
+        for i in 1..=self.num_hashes {
+            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+            if combined_hash < 0 {
+                combined_hash = !combined_hash;
+            }
+            if !self.bits.get((combined_hash % bit_size) as usize) {
+                return false;
+            }
+        }
+        true
+    }
+
+    pub fn might_contain_longs(&self, items: &Int64Array) -> CometResult<BooleanArray> {
+        Ok(items
+            .iter()
+            .map(|v| v.map(|x| self.might_contain_long(x)))
+            .collect())
+    }
+}
diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index 8aa81b7672..2a6df4a701 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -76,6 +76,7 @@ message Expr {
     Abs abs = 49;
     Subquery subquery = 50;
     UnboundReference unbound = 51;
+    BloomFilterMightContain bloom_filter_might_contain = 52;
   }
 }
 
@@ -414,6 +415,11 @@ message Subquery {
   DataType datatype = 2;
 }
 
+message BloomFilterMightContain {
+  Expr bloom_filter = 1;
+  Expr value = 2;
+}
+
 enum SortDirection {
   Ascending = 0;
   Descending = 1;
diff --git a/pom.xml b/pom.xml
index f19ce70deb..12a85d50e0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -494,6 +494,7 @@ under the License.
        <spark.version>3.2.2</spark.version>
        <spark.version.short>3.2</spark.version.short>
        <parquet.version>1.12.0</parquet.version>
+       <additional.test.source>spark-3.2</additional.test.source>
      </properties>
    </profile>
 
@@ -504,6 +505,7 @@ under the License.
        <spark.version>3.3.2</spark.version>
        <spark.version.short>3.3</spark.version.short>
        <parquet.version>1.12.0</parquet.version>
+       <additional.test.source>spark-3.3-plus</additional.test.source>
      </properties>
    </profile>
 
@@ -513,6 +515,7 @@ under the License.
        <scala.version>2.12.17</scala.version>
        <spark.version.short>3.4</spark.version.short>
        <parquet.version>1.13.1</parquet.version>
+       <additional.test.source>spark-3.3-plus</additional.test.source>
      </properties>
    </profile>
 
@@ -777,6 +780,11 @@ under the License.
          <artifactId>jacoco-maven-plugin</artifactId>
          <version>${jacoco.version}</version>
        </plugin>
+       <plugin>
+         <groupId>org.codehaus.mojo</groupId>
+         <artifactId>build-helper-maven-plugin</artifactId>
+         <version>3.2.0</version>
+       </plugin>
      </plugins>
    </pluginManagement>
  </build>
diff --git a/spark/pom.xml b/spark/pom.xml
index 7e54fde060..31d80bbe69 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -233,6 +233,24 @@ under the License.
        <groupId>net.alchim31.maven</groupId>
        <artifactId>scala-maven-plugin</artifactId>
      </plugin>
+     <plugin>
+       <groupId>org.codehaus.mojo</groupId>
+       <artifactId>build-helper-maven-plugin</artifactId>
+       <executions>
+         <execution>
+           <id>add-test-source</id>
+           <phase>generate-test-sources</phase>
+           <goals>
+             <goal>add-test-source</goal>
+           </goals>
+           <configuration>
+             <sources>
+               <source>src/test/${additional.test.source}</source>
+             </sources>
+           </configuration>
+         </execution>
+       </executions>
+     </plugin>
    </plugins>
  </build>
 
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 87c2265fcb..5720b69354 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.SparkSessionExtensions
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index b27fa3a754..03c546c141 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -1566,6 +1566,23 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
           "make_decimal",
           DecimalType(precision, scale),
           childExpr)
+      case b @ BinaryExpression(_, _) if isBloomFilterMightContain(b) =>
+        val bloomFilter = b.left
+        val value = b.right
+        val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs)
+        val valueExpr = exprToProtoInternal(value, inputs)
+        if (bloomFilterExpr.isDefined && valueExpr.isDefined) {
+          val builder = ExprOuterClass.BloomFilterMightContain.newBuilder()
+          builder.setBloomFilter(bloomFilterExpr.get)
+          builder.setValue(valueExpr.get)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setBloomFilterMightContain(builder)
+              .build())
+        } else {
+          None
+        }
 
       case e =>
         emitWarning(s"unsupported Spark expression: '$e' of class '${e.getClass.getName}")
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
index c47b399cf3..fe7a8f0f6c 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -19,7 +19,7 @@
 
 package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic
+import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, BinaryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 
 trait ShimQueryPlanSerde {
@@ -44,4 +44,8 @@ trait ShimQueryPlanSerde {
       failOnError.head
     }
   }
+
+  // Spark 3.2 does not have BloomFilterMightContain, so detect it by class name
+  // instead of referencing the type directly.
+  def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
+    binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
+  }
 }
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 565d2264b7..317371fb90 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -90,13 +90,13 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     Range(0, len).map(_ => chars.charAt(r.nextInt(chars.length))).mkString
   }
 
-  private def fuzzCastFromString(chars: String, maxLen: Int, toType: DataType) {
+  private def fuzzCastFromString(chars: String, maxLen: Int, toType: DataType): Unit = {
     val r = new Random(0)
     val inputs = Range(0, 10000).map(_ => genString(r, chars, maxLen))
     castTest(inputs.toDF("a"), toType)
   }
 
-  private def castTest(input: DataFrame, toType: DataType) {
+  private def castTest(input: DataFrame, toType: DataType): Unit = {
     withTempPath { dir =>
       val df = roundtripParquet(input, dir)
         .withColumn("converted", col("a").cast(toType))
diff --git a/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpressionPlusSuite.scala b/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpressionPlusSuite.scala
new file mode 100644
index 0000000000..56829a0bad
--- /dev/null
+++ b/spark/src/test/spark-3.3-plus/org/apache/comet/CometExpressionPlusSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import java.io.ByteArrayOutputStream
+
+import scala.util.Random
+
+import org.apache.spark.sql.{Column, CometTestBase}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Expression, ExpressionInfo}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.util.sketch.BloomFilter
+
+class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+  import testImplicits._
+
+  val func_might_contain = new FunctionIdentifier("might_contain")
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Register 'might_contain' as a builtin function.
+    spark.sessionState.functionRegistry.registerFunction(
+      func_might_contain,
+      new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+      (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  }
+
+  override def afterAll(): Unit = {
+    spark.sessionState.functionRegistry.dropFunction(func_might_contain)
+    super.afterAll()
+  }
+
+  test("test BloomFilterMightContain can take a constant value input") {
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 int) using parquet")
+      sql(s"insert into $table values (201, 1)")
+      checkSparkAnswerAndOperator(s"""
+           |SELECT might_contain(
+           |X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', col1) FROM $table
+           |""".stripMargin)
+    }
+  }
+
+  test("test NULL inputs for BloomFilterMightContain") {
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 int) using parquet")
+      sql(s"insert into $table values (201, 1), (null, 2)")
+      checkSparkAnswerAndOperator(s"""
+           |SELECT might_contain(null, null) both_null,
+           |       might_contain(null, 1L) null_bf,
+           |       might_contain(
+           |         X'00000001000000050000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', col1) null_value
+           |FROM $table
+           |""".stripMargin)
+    }
+  }
+
+  test("test BloomFilterMightContain from random input") {
+    val bf = BloomFilter.create(1000, 100)
+    val longs = (0 until 100).map(_ => Random.nextLong())
+    longs.foreach(bf.put)
+    val os = new ByteArrayOutputStream()
+    bf.writeTo(os)
+    val bfBytes = os.toByteArray
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col1 long, col2 binary) using parquet")
+      spark
+        .createDataset(longs)
+        .map(x => (x, bfBytes))
+        .toDF("col1", "col2")
+        .write
+        .insertInto(table)
+      val df = spark
+        .table(table)
+        .select(new Column(BloomFilterMightContain(lit(bfBytes).expr, col("col1").expr)))
+      checkSparkAnswerAndOperator(df)
+      // check with scalar subquery
+      checkSparkAnswerAndOperator(s"""
+           |SELECT might_contain((select first(col2) as col2 from $table), col1) FROM $table
+           |""".stripMargin)
+    }
+  }
+}
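
Note on the byte layout consumed by `new_from_buf`: Spark's `BloomFilterImpl` writes,
in big-endian order, an i32 version, an i32 number of hash functions, an i32 word
count, and then that many i64 words of bit data. The sketch below is illustrative
only and is not part of the patch; it assumes placement inside a test module of
`spark_bloom_filter.rs` (where `SparkBloomFilter` is in scope) and hand-assembles
that layout to round-trip a value:

    #[test]
    fn test_new_from_buf_roundtrip() {
        // Hand-assemble the V1 layout: version, numHashFunctions, numWords, words.
        let num_hashes: i32 = 3;
        let words = [0i64; 4]; // a 256-bit, initially empty filter
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&1i32.to_be_bytes()); // SPARK_BLOOM_FILTER_VERSION_1
        buf.extend_from_slice(&num_hashes.to_be_bytes());
        buf.extend_from_slice(&(words.len() as i32).to_be_bytes());
        for w in words {
            buf.extend_from_slice(&w.to_be_bytes());
        }

        let mut filter = SparkBloomFilter::new_from_buf(&buf);
        assert!(!filter.might_contain_long(42)); // nothing inserted yet
        assert!(filter.put_long(42)); // first insert must flip at least one bit
        assert!(filter.might_contain_long(42));
    }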