
Commit e13c168

address comments
advancedxy committed Mar 13, 2024
1 parent 1cdb5e3 commit e13c168
Showing 6 changed files with 76 additions and 89 deletions.
@@ -19,12 +19,11 @@ use crate::{
execution::datafusion::util::spark_bloom_filter::SparkBloomFilter, parquet::data_type::AsBytes,
};
use arrow::record_batch::RecordBatch;
use arrow_array::{BooleanArray, Int64Array};
use arrow_schema::DataType;
use datafusion::{common::Result, physical_plan::ColumnarValue};
use datafusion_common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue};
use arrow_array::{cast::as_primitive_array, BooleanArray};
use arrow_schema::{DataType, Schema};
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr};
use once_cell::sync::OnceCell;
use std::{
any::Any,
fmt::Display,
@@ -34,11 +33,12 @@ use std::{

/// A physical expression that checks if a value might be in a bloom filter. It corresponds to
/// Spark's `BloomFilterMightContain` expression.
#[derive(Debug)]
#[derive(Debug, Hash)]
pub struct BloomFilterMightContain {
pub bloom_filter_expr: Arc<dyn PhysicalExpr>,
pub value_expr: Arc<dyn PhysicalExpr>,
bloom_filter: OnceCell<Option<SparkBloomFilter>>,
bloom_filter: Option<SparkBloomFilter>,
}

impl Display for BloomFilterMightContain {
@@ -63,15 +63,33 @@ impl PartialEq<dyn Any> for BloomFilterMightContain {
}
}

fn evaluate_bloom_filter(
bloom_filter_expr: &Arc<dyn PhysicalExpr>,
) -> Result<Option<SparkBloomFilter>> {
// bloom_filter_expr must be a literal/scalar subquery expression, so we can evaluate it
// with an empty batch and an empty schema
let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
match bloom_filter_bytes {
ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
Ok(v.map(|v| SparkBloomFilter::new_from_buf(v.as_bytes())))
}
_ => internal_err!("Bloom filter expression must be evaluated as a scalar binary value"),
}
}

impl BloomFilterMightContain {
pub fn new(
bloom_filter_expr: Arc<dyn PhysicalExpr>,
value_expr: Arc<dyn PhysicalExpr>,
) -> Self {
// Evaluate bloom_filter_expr eagerly to get the actual bloom filter
let bloom_filter = evaluate_bloom_filter(&bloom_filter_expr)
.expect("bloom_filter_expr could be evaluated statically");
Self {
bloom_filter_expr,
value_expr,
bloom_filter: Default::default(),
bloom_filter,
}
}
}
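
For context, a hedged sketch of how this expression could be constructed: the bloom filter side is expected to be a literal (or an already collapsed scalar subquery) binary value, which is why `new` can evaluate it eagerly against an empty batch. The helper name, the column index, and the DataFusion paths below are illustrative assumptions, not part of this commit.

// Illustrative only: build the expression from a literal binary bloom filter
// (bytes as produced by Spark's BloomFilter.writeTo) and a long-typed column.
use datafusion_common::ScalarValue;
use datafusion_physical_expr::{expressions::{Column, Literal}, PhysicalExpr};
use std::sync::Arc;

fn build_might_contain(serialized_filter: Vec<u8>) -> BloomFilterMightContain {
    let bloom_filter_expr: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Binary(Some(serialized_filter))));
    // Hypothetical long-typed input column at ordinal 0.
    let value_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("col1", 0));
    BloomFilterMightContain::new(bloom_filter_expr, value_expr)
}
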
@@ -81,66 +99,40 @@ impl PhysicalExpr for BloomFilterMightContain {
self
}

fn data_type(&self, _input_schema: &arrow_schema::Schema) -> Result<DataType> {
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(DataType::Boolean)
}

fn nullable(&self, _input_schema: &arrow_schema::Schema) -> Result<bool> {
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
Ok(true)
}

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
// lazily get the spark bloom filter
if self.bloom_filter.get().is_none() {
let bloom_filter_bytes = self.bloom_filter_expr.evaluate(batch)?;
match bloom_filter_bytes {
ColumnarValue::Array(_) => {
return internal_err!(
"Bloom filter expression must be evaluated as a scalar value"
);
}
ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
let filter = v.map(|v| SparkBloomFilter::new_from_buf(v.as_bytes()));
self.bloom_filter.get_or_init(|| filter);
}
_ => {
return internal_err!("Bloom filter expression must be binary type");
}
}
}
let num_rows = batch.num_rows();
let lazy_filter = self.bloom_filter.get().unwrap();
if lazy_filter.is_none() {
// when the bloom filter is null, we should return a boolean array with all nulls
Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
num_rows,
))))
} else {
let spark_filter = lazy_filter.as_ref().unwrap();
let values = self.value_expr.evaluate(batch)?;
match values {
ColumnarValue::Array(array) => {
let array = array
.as_any()
.downcast_ref::<Int64Array>()
.expect("value_expr must be evaluated as an int64 array");
Ok(ColumnarValue::Array(Arc::new(
spark_filter.might_contain_longs(array)?,
)))
}
ColumnarValue::Scalar(a) => match a {
ScalarValue::Int64(v) => {
self.bloom_filter
.as_ref()
.map(|spark_filter| {
let values = self.value_expr.evaluate(batch)?;
match values {
ColumnarValue::Array(array) => {
let boolean_array =
spark_filter.might_contain_longs(as_primitive_array(&array));
Ok(ColumnarValue::Array(Arc::new(boolean_array)))
}
ColumnarValue::Scalar(ScalarValue::Int64(v)) => {
let result = v.map(|v| spark_filter.might_contain_long(v));
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
}
_ => {
internal_err!(
"value_expr must be evaluated as an int64 array or a int64 scalar"
)
}
},
}
}
_ => internal_err!("value expression must be int64 type"),
}
})
.unwrap_or_else(|| {
// when the bloom filter is null, we should return a boolean array with all nulls
Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
num_rows,
))))
})
}

fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -161,5 +153,6 @@ impl PhysicalExpr for BloomFilterMightContain {
let mut s = state;
self.bloom_filter_expr.hash(&mut s);
self.value_expr.hash(&mut s);
self.hash(&mut s);
}
}
2 changes: 1 addition & 1 deletion core/src/execution/datafusion/spark_hash.rs
@@ -32,7 +32,7 @@ use datafusion::{
};

#[inline]
pub fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
pub(crate) fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
#[inline]
fn mix_k1(mut k1: i32) -> i32 {
k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32);
10 changes: 1 addition & 9 deletions core/src/execution/datafusion/util/spark_bit_array.rs
@@ -18,6 +18,7 @@
/// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
/// used in the BloomFilter implementation. Some methods are not implemented as they are not
/// required for the current use case.
#[derive(Debug, Hash)]
pub struct SparkBitArray {
data: Vec<u64>,
@@ -33,15 +34,6 @@ impl SparkBitArray {
}
}

pub fn new_from_bit_count(num_bits: usize) -> Self {
let num_words = (num_bits + 63) / 64;
debug_assert!(num_words < u32::MAX as usize, "num_words is too large");
Self {
data: vec![0u64; num_words],
bit_count: num_bits,
}
}

pub fn set(&mut self, index: usize) -> bool {
if !self.get(index) {
self.data[index >> 6] |= 1u64 << (index & 0x3f);
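
As a side note on the indexing above (not part of the diff): `index >> 6` selects the backing u64 word and `index & 0x3f` selects the bit within it. A minimal sketch of the matching read path, assuming the same Vec<u64> layout:

// Illustrative sketch only: word/bit addressing used by SparkBitArray-style storage.
fn bit_is_set(data: &[u64], index: usize) -> bool {
    let word = data[index >> 6]; // index / 64 picks the u64 word
    let mask = 1u64 << (index & 0x3f); // index % 64 picks the bit within the word
    word & mask != 0
}
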
18 changes: 7 additions & 11 deletions core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -15,11 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use crate::{
errors::CometResult,
execution::datafusion::{
spark_hash::spark_compatible_murmur3_hash, util::spark_bit_array::SparkBitArray,
},
use crate::execution::datafusion::{
spark_hash::spark_compatible_murmur3_hash, util::spark_bit_array::SparkBitArray,
};
use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array};

@@ -28,6 +25,7 @@ const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
/// A Bloom filter implementation that simulates the behavior of Spark's BloomFilter.
/// It's not a complete implementation of Spark's BloomFilter; it just adds the minimum
/// methods to support mightContainLong on the native side.
#[derive(Debug, Hash)]
pub struct SparkBloomFilter {
bits: SparkBitArray,
@@ -60,9 +58,7 @@ impl SparkBloomFilter {

pub fn put_long(&mut self, item: i64) -> bool {
// Here we first hash the input long element into 2 int hash values, h1 and h2, then produce
// n hash values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
// Note that `CountMinSketch` use a different strategy, it hashes the input long element
// with every i to produce n hash values.
// n hash values by `h1 + i * h2` with 1 <= i <= num_hashes.
let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
let bit_size = self.bits.bit_size() as i32;
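
The rest of put_long is collapsed in this view. As a rough sketch of the double-hashing scheme described by the comment above (modeled on Spark's BloomFilterImpl, an assumption rather than the collapsed code itself), the probe positions could be derived like this:

// Sketch only: derive num_hashes probe positions from h1/h2 within bit_size bits.
fn probe_positions(h1: i32, h2: i32, num_hashes: u32, bit_size: i32) -> Vec<usize> {
    (1..=num_hashes as i32)
        .map(|i| {
            let mut combined = h1.wrapping_add(i.wrapping_mul(h2));
            if combined < 0 {
                combined = !combined; // keep the index non-negative, as Spark does
            }
            (combined % bit_size) as usize
        })
        .collect()
}
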
@@ -93,10 +89,10 @@ impl SparkBloomFilter {
true
}

pub fn might_contain_longs(&self, items: &Int64Array) -> CometResult<BooleanArray> {
Ok(items
pub fn might_contain_longs(&self, items: &Int64Array) -> BooleanArray {
items
.iter()
.map(|v| v.map(|x| self.might_contain_long(x)))
.collect())
.collect()
}
}
4 changes: 3 additions & 1 deletion pom.xml
@@ -86,6 +86,7 @@ under the License.
-Djdk.reflect.useDirectMethodHandle=false
</extraJavaTestArgs>
<argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
<additional.test.source>spark-3.3-plus</additional.test.source>
</properties>

<dependencyManagement>
@@ -494,7 +495,8 @@ under the License.
<spark.version>3.2.2</spark.version>
<spark.version.short>3.2</spark.version.short>
<parquet.version>1.12.0</parquet.version>
<additional.test.source>spark-3.2</additional.test.source>
<!-- we don't add special test suites for spark-3.2, so a non-existent dir is specified -->
<additional.test.source>not-needed-yet</additional.test.source>
</properties>
</profile>

@@ -19,7 +19,7 @@

package org.apache.comet

import org.apache.spark.sql.{Column, CometTestBase, DataFrame, Row}
import org.apache.spark.sql.{Column, CometTestBase}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Expression, ExpressionInfo}
@@ -29,7 +29,7 @@ import org.apache.spark.util.sketch.BloomFilter
import java.io.ByteArrayOutputStream
import scala.util.Random

class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {
class CometExpression3_3PlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

val func_might_contain = new FunctionIdentifier("might_contain")
@@ -49,6 +49,7 @@ class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("test BloomFilterMightContain can take a constant value input") {
val table = "test"

withTable(table) {
sql(s"create table $table(col1 long, col2 int) using parquet")
sql(s"insert into $table values (201, 1)")
@@ -62,6 +63,7 @@ class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {

test("test NULL inputs for BloomFilterMightContain") {
val table = "test"

withTable(table) {
sql(s"create table $table(col1 long, col2 int) using parquet")
sql(s"insert into $table values (201, 1), (null, 2)")
@@ -77,13 +79,9 @@ class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("test BloomFilterMightContain from random input") {
val bf = BloomFilter.create(100000, 10000)
val longs = (0 until 10000).map(_ => Random.nextLong())
longs.foreach(bf.put)
val os = new ByteArrayOutputStream()
bf.writeTo(os)
val bfBytes = os.toByteArray
val (longs, bfBytes) = bloomFilterFromRandomInput(10000, 10000)
val table = "test"

withTable(table) {
sql(s"create table $table(col1 long, col2 binary) using parquet")
spark.createDataset(longs).map(x => (x, bfBytes)).toDF("col1", "col2").write.insertInto(table)
@@ -97,6 +95,12 @@ class CometExpressionPlusSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}



private def bloomFilterFromRandomInput(expectedItems: Long, expectedBits: Long): (Seq[Long], Array[Byte]) = {
val bf = BloomFilter.create(expectedItems, expectedBits)
val longs = (0 until expectedItems.toInt).map(_ => Random.nextLong())
longs.foreach(bf.put)
val os = new ByteArrayOutputStream()
bf.writeTo(os)
(longs, os.toByteArray)
}
}
