feat: Support BloomFilterMightContain expr #179

Merged 7 commits on Mar 14, 2024
11 changes: 11 additions & 0 deletions core/src/common/bit.rs
@@ -131,6 +131,17 @@ pub fn read_num_bytes_u32(size: usize, src: &[u8]) -> u32 {
trailing_bits(v as u64, size * 8) as u32
}

/// Similar to `read_num_bytes`, but reads the number from bytes in big-endian order.
/// This is used to read bytes written by Java's OutputStream, which writes in big-endian order.
macro_rules! read_num_be_bytes {
($ty:ty, $size:expr, $src:expr) => {{
debug_assert!($size <= $src.len());
let mut buffer = <$ty as $crate::common::bit::FromBytes>::Buffer::default();
buffer.as_mut()[..$size].copy_from_slice(&$src[..$size]);
<$ty>::from_be_bytes(buffer)
}};
}
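
A minimal usage sketch (hypothetical input bytes; assumes the macro is invoked from within this crate, where `FromBytes` is implemented for the integer types):

// Read a 4-byte big-endian i32, as Java's DataOutputStream.writeInt would
// have written it.
let src: [u8; 4] = [0x00, 0x00, 0x01, 0x02];
let v = read_num_be_bytes!(i32, 4, src[..]);
assert_eq!(v, 0x0102);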

/// Converts value `val` of type `T` to a byte vector, by reading `num_bytes` from `val`.
/// NOTE: if `num_bytes` is less than the size of `T`, then `val` can be truncated.
#[inline]
152 changes: 152 additions & 0 deletions core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -0,0 +1,152 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::{
execution::datafusion::util::spark_bloom_filter::SparkBloomFilter, parquet::data_type::AsBytes,
};
use arrow::record_batch::RecordBatch;
use arrow_array::cast::as_primitive_array;
use arrow_schema::{DataType, Schema};
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr};
use std::{
any::Any,
fmt::Display,
hash::{Hash, Hasher},
sync::Arc,
};

/// A physical expression that checks if a value might be contained in a bloom filter. It
/// corresponds to Spark's `BloomFilterMightContain` expression.
#[derive(Debug, Hash)]
pub struct BloomFilterMightContain {
pub bloom_filter_expr: Arc<dyn PhysicalExpr>,
pub value_expr: Arc<dyn PhysicalExpr>,
bloom_filter: Option<SparkBloomFilter>,
}

impl Display for BloomFilterMightContain {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"BloomFilterMightContain [bloom_filter_expr: {}, value_expr: {}]",
self.bloom_filter_expr, self.value_expr
)
}
}

impl PartialEq<dyn Any> for BloomFilterMightContain {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|other| {
self.bloom_filter_expr.eq(&other.bloom_filter_expr)
&& self.value_expr.eq(&other.value_expr)
})
.unwrap_or(false)
}
}

fn evaluate_bloom_filter(
bloom_filter_expr: &Arc<dyn PhysicalExpr>,
) -> Result<Option<SparkBloomFilter>> {
// bloom_filter_expr must be a literal/scalar subquery expression, so we can evaluate it
// against an empty batch with an empty schema
let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
let bloom_filter_bytes = bloom_filter_expr.evaluate(&batch)?;
match bloom_filter_bytes {
ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
Ok(v.map(|v| SparkBloomFilter::new(v.as_bytes())))
}
_ => internal_err!("Bloom filter expression should be evaluated as a scalar binary value"),
}
}
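
To see why the empty batch suffices: a literal physical expression yields the same scalar regardless of how many input rows there are. A minimal sketch, assuming DataFusion's `Literal` expression (hypothetical value):

use std::sync::Arc;
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::{expressions::Literal, PhysicalExpr};

fn main() -> datafusion_common::Result<()> {
    // A literal carries its own value and reads no input columns.
    let expr = Literal::new(ScalarValue::Int64(Some(42)));
    let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
    match expr.evaluate(&batch)? {
        ColumnarValue::Scalar(s) => assert_eq!(s, ScalarValue::Int64(Some(42))),
        _ => unreachable!("a literal always evaluates to a scalar"),
    }
    Ok(())
}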

impl BloomFilterMightContain {
pub fn try_new(
bloom_filter_expr: Arc<dyn PhysicalExpr>,
value_expr: Arc<dyn PhysicalExpr>,
) -> Result<Self> {
// Evaluate the bloom_filter_expr eagerly to materialize the actual bloom filter
let bloom_filter = evaluate_bloom_filter(&bloom_filter_expr)?;
Ok(Self {
bloom_filter_expr,
value_expr,
bloom_filter,
})
}
}

impl PhysicalExpr for BloomFilterMightContain {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(DataType::Boolean)
}

fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
Ok(true)
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
self.bloom_filter
.as_ref()
.map(|spark_filter| {
let values = self.value_expr.evaluate(batch)?;
match values {
ColumnarValue::Array(array) => {
let boolean_array =
spark_filter.might_contain_longs(as_primitive_array(&array));
Ok(ColumnarValue::Array(Arc::new(boolean_array)))
}
ColumnarValue::Scalar(ScalarValue::Int64(v)) => {
let result = v.map(|v| spark_filter.might_contain_long(v));
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
}
_ => internal_err!("value expression should be int64 type"),
}
})
.unwrap_or_else(|| {
// when the bloom filter is null, return null for all inputs
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None)))
[Review thread]
Contributor Author: Rather than use ScalarValue::Null, I think ScalarValue::Boolean(None) is more appropriate here, since it carries the data type info.

@viirya (Member, Mar 14, 2024): ScalarValue::Boolean(None) is correct. ScalarValue::Null is the null type.
})
}

fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.bloom_filter_expr.clone(), self.value_expr.clone()]
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(BloomFilterMightContain::try_new(
children[0].clone(),
children[1].clone(),
)?))
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.bloom_filter_expr.hash(&mut s);
self.value_expr.hash(&mut s);
self.hash(&mut s);
}
}
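
To make the distinction from the review thread above concrete, a small sketch (relying on ScalarValue::data_type): Boolean(None) is a null of Boolean type, while ScalarValue::Null carries the Null data type.

use arrow_schema::DataType;
use datafusion_common::ScalarValue;

fn main() {
    // Both values are "null", but they report different data types.
    assert_eq!(ScalarValue::Boolean(None).data_type(), DataType::Boolean);
    assert_eq!(ScalarValue::Null.data_type(), DataType::Null);
}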
1 change: 1 addition & 0 deletions core/src/execution/datafusion/expressions/mod.rs
@@ -26,6 +26,7 @@ pub mod scalar_funcs;
pub use normalize_nan::NormalizeNaNAndZero;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_might_contain;
pub mod strings;
pub mod subquery;
pub mod sum_decimal;
1 change: 1 addition & 0 deletions core/src/execution/datafusion/mod.rs
@@ -22,3 +22,4 @@ mod operators;
pub mod planner;
pub(crate) mod shuffle_writer;
mod spark_hash;
mod util;
10 changes: 10 additions & 0 deletions core/src/execution/datafusion/planner.rs
@@ -56,6 +56,7 @@ use crate::{
avg::Avg,
avg_decimal::AvgDecimal,
bitwise_not::BitwiseNotExpr,
bloom_filter_might_contain::BloomFilterMightContain,
cast::Cast,
checkoverflow::CheckOverflow,
if_expr::IfExpr,
@@ -534,6 +535,15 @@ impl PhysicalPlanner {
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(Subquery::new(self.exec_context_id, id, data_type)))
}
ExprStruct::BloomFilterMightContain(expr) => {
let bloom_filter_expr =
self.create_expr(expr.bloom_filter.as_ref().unwrap(), input_schema.clone())?;
let value_expr = self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
Ok(Arc::new(BloomFilterMightContain::try_new(
bloom_filter_expr,
value_expr,
)?))
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
2 changes: 1 addition & 1 deletion core/src/execution/datafusion/spark_hash.rs
@@ -32,7 +32,7 @@ use datafusion::{
};

#[inline]
fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
pub(crate) fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
#[inline]
fn mix_k1(mut k1: i32) -> i32 {
k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32);
19 changes: 19 additions & 0 deletions core/src/execution/datafusion/util/mod.rs
@@ -0,0 +1,19 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

pub mod spark_bit_array;
pub mod spark_bloom_filter;
131 changes: 131 additions & 0 deletions core/src/execution/datafusion/util/spark_bit_array.rs
@@ -0,0 +1,131 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

/// A simple bit array implementation that simulates the behavior of Spark's BitArray, which is
/// used in its BloomFilter implementation. Some methods are not implemented as they are not
/// required for the current use case.
#[derive(Debug, Hash)]
pub struct SparkBitArray {
data: Vec<u64>,
bit_count: usize,
}

impl SparkBitArray {
pub fn new(buf: Vec<u64>) -> Self {
let num_bits = buf.iter().map(|x| x.count_ones() as usize).sum();
Self {
data: buf,
bit_count: num_bits,
}
}

pub fn set(&mut self, index: usize) -> bool {
if !self.get(index) {
// see the get method for the explanation of the shift operators
self.data[index >> 6] |= 1u64 << (index & 0x3f);
[Review thread]
@viirya (Member): Why index & 0x3f? Spark's BitArray doesn't do this.

Contributor Author: Java and Rust have different semantics for shift-left. In Java, the shift-left operator appears to rotate bits when the shift distance is larger than 64:

jshell> 1 << 65
$1 ==> 2

jshell> 1 << 129
$5 ==> 2

Rust doesn't support this semantic; it panics on overflow:

1u64 << 65 // panics: shift left with overflow

@viirya (Member): Oh, it is not rotated. The Java shift operators are defined as:

"If the promoted type of the left-hand operand is long, then only the six lowest-order bits of the right-hand operand are used as the shift distance. It is as if the right-hand operand were subjected to a bitwise logical AND operator & with the mask value 0x3f (0b111111). The shift distance actually used is therefore always in the range 0 to 63, inclusive."

https://en.wikipedia.org/wiki/Bitwise_operation

@viirya (Member): Maybe we should add a comment on this?

Contributor Author: Thanks for the correction and the info; I hadn't found an authoritative description of how Java defines its shift operators and thought it was a rotating shift. A comment will be added, of course.

Contributor Author: I think this PR is ready for another round of review. I will address this comment together with any other issues.
self.bit_count += 1;
true
} else {
false
}
}

pub fn get(&self, index: usize) -> bool {
// Java version: (data[(int) (index >> 6)] & (1L << (index))) != 0
// Rust and Java have different semantics for the shift operators. Java's shift operators
// explicitly mask the right-hand operand with 0x3f [1], while Rust's shift operators do
// not; they panic with "shift left with overflow" for a large right-hand operand.
// To fix this, we mask the right-hand operand with 0x3f on the Rust side.
// [1]: https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19
(self.data[index >> 6] & (1u64 << (index & 0x3f))) != 0
}

pub fn bit_size(&self) -> u64 {
self.data.len() as u64 * 64
}

pub fn cardinality(&self) -> usize {
self.bit_count
}
}
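
A standalone sketch of the masking trick used by get and set above (java_shl is a hypothetical helper name): Java masks the shift distance of a long shift with 0x3f implicitly, while Rust must apply the mask explicitly to avoid a panic on an overflowing shift.

/// Hypothetical helper mirroring Java's `long << distance`, which implicitly
/// masks the shift distance with 0x3f (JLS §15.19).
fn java_shl(value: u64, distance: u32) -> u64 {
    value << (distance & 0x3f)
}

fn main() {
    // Java: 1L << 65 == 2, because 65 & 0x3f == 1.
    assert_eq!(java_shl(1, 65), 2);
    // A plain `1u64 << 65` would panic in a debug build:
    // "attempt to shift left with overflow".
    assert_eq!(java_shl(1, 63), 1u64 << 63);
}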

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_spark_bit_array() {
let buf = vec![0u64; 4];
let mut array = SparkBitArray::new(buf);
assert_eq!(array.bit_size(), 256);
assert_eq!(array.cardinality(), 0);

assert!(!array.get(0));
assert!(!array.get(1));
assert!(!array.get(63));
assert!(!array.get(64));
assert!(!array.get(65));
assert!(!array.get(127));
assert!(!array.get(128));
assert!(!array.get(129));

assert!(array.set(0));
assert!(array.set(1));
assert!(array.set(63));
assert!(array.set(64));
assert!(array.set(65));
assert!(array.set(127));
assert!(array.set(128));
assert!(array.set(129));

assert_eq!(array.cardinality(), 8);
assert_eq!(array.bit_size(), 256);

assert!(array.get(0));
// already set so should return false
assert!(!array.set(0));

// not set values should return false for get
assert!(!array.get(2));
assert!(!array.get(62));
}

#[test]
fn test_spark_bit_with_non_empty_buffer() {
let buf = vec![8u64; 4];
let mut array = SparkBitArray::new(buf);
assert_eq!(array.bit_size(), 256);
assert_eq!(array.cardinality(), 4);

// already set bits should return true
assert!(array.get(3));
assert!(array.get(67));
assert!(array.get(131));
assert!(array.get(195));

// other unset bits should return false
assert!(!array.get(0));
assert!(!array.get(1));

// set bits
assert!(array.set(0));
assert!(array.set(1));

// check cardinality
assert_eq!(array.cardinality(), 6);
}
}