feat: Support BloomFilterMightContain expr
advancedxy committed Mar 9, 2024
1 parent 488c523 commit dbddcd3
Showing 16 changed files with 603 additions and 6 deletions.
165 changes: 165 additions & 0 deletions core/src/execution/datafusion/expressions/bloom_filter_might_contain.rs
@@ -0,0 +1,165 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::{
execution::datafusion::util::spark_bloom_filter::SparkBloomFilter, parquet::data_type::AsBytes,
};
use arrow::record_batch::RecordBatch;
use arrow_array::{BooleanArray, Int64Array};
use arrow_schema::DataType;
use datafusion::{common::Result, physical_plan::ColumnarValue};
use datafusion_common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, PhysicalExpr};
use log::info;
use once_cell::sync::OnceCell;
use std::{
any::Any,
fmt::Display,
hash::{Hash, Hasher},
sync::Arc,
};

#[derive(Debug)]
pub struct BloomFilterMightContain {
pub bloom_filter_expr: Arc<dyn PhysicalExpr>,
pub value_expr: Arc<dyn PhysicalExpr>,
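// Lazily initialized on the first call to evaluate(); the inner Option is None when the serialized filter value is NULL.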
bloom_filter: OnceCell<Option<SparkBloomFilter>>,
}

impl Display for BloomFilterMightContain {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"BloomFilterMightContain [bloom_filter_expr: {}, value_expr: {}]",
self.bloom_filter_expr, self.value_expr
)
}
}

impl PartialEq<dyn Any> for BloomFilterMightContain {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|other| {
self.bloom_filter_expr.eq(&other.bloom_filter_expr)
&& self.value_expr.eq(&other.value_expr)
})
.unwrap_or(false)
}
}

impl BloomFilterMightContain {
pub fn new(
bloom_filter_expr: Arc<dyn PhysicalExpr>,
value_expr: Arc<dyn PhysicalExpr>,
) -> Self {
Self {
bloom_filter_expr,
value_expr,
bloom_filter: Default::default(),
}
}
}

impl PhysicalExpr for BloomFilterMightContain {
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, _input_schema: &arrow_schema::Schema) -> Result<DataType> {
Ok(DataType::Boolean)
}

fn nullable(&self, _input_schema: &arrow_schema::Schema) -> Result<bool> {
Ok(true)
}

fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
// Lazily initialize the bloom filter from the filter expression on first use
if self.bloom_filter.get().is_none() {
let bloom_filter_bytes = self.bloom_filter_expr.evaluate(batch)?;
match bloom_filter_bytes {
ColumnarValue::Array(_) => {
return internal_err!(
"Bloom filter expression must be evaluated as a scalar value"
);
}
ColumnarValue::Scalar(ScalarValue::Binary(v)) => {
info!("init for bloom filter");
let filter = v.map(|v| SparkBloomFilter::new_from_buf(v.as_bytes()));
self.bloom_filter.get_or_init(|| filter);
}
_ => {
return internal_err!("Bloom filter expression must evaluate to a binary scalar");
}
}
}
let num_rows = batch.num_rows();
let lazy_filter = self.bloom_filter.get().unwrap();
if lazy_filter.is_none() {
// when the bloom filter is null, we should return a boolean array with all nulls
return Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null(
num_rows,
))));
} else {
let spark_filter = lazy_filter.as_ref().unwrap();
let values = self.value_expr.evaluate(batch)?;
match values {
ColumnarValue::Array(array) => {
let array = array
.as_any()
.downcast_ref::<Int64Array>()
.expect("value_expr must be evaluated as an int64 array");
Ok(ColumnarValue::Array(Arc::new(
spark_filter.might_contain_longs(array)?,
)))
}
ColumnarValue::Scalar(a) => match a {
ScalarValue::Int64(v) => {
let result = v.map(|v| spark_filter.might_contain_long(v));
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
}
_ => {
return internal_err!(
"value_expr must be evaluated as an int64 array or an int64 scalar"
);
}
},
}
}
}

fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.bloom_filter_expr.clone(), self.value_expr.clone()]
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(BloomFilterMightContain::new(
children[0].clone(),
children[1].clone(),
)))
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.bloom_filter_expr.hash(&mut s);
self.value_expr.hash(&mut s);
}
}
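
For reference, here is a minimal usage sketch of the new expression. It is illustrative only: `filter_bytes` is assumed to hold a bloom filter serialized in Spark's format, and the schema and function name are made up for the example.

use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use arrow_array::{ArrayRef, Int64Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::ScalarValue;
use datafusion_physical_expr::{
    expressions::{Column, Literal},
    PhysicalExpr,
};

fn might_contain_example(filter_bytes: Vec<u8>) -> datafusion_common::Result<()> {
    // A single Int64 probe column named "v".
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
    let values: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 2, 3]));
    let batch = RecordBatch::try_new(schema, vec![values])?;
    // The filter side must evaluate to a binary scalar, e.g. a literal.
    let expr = BloomFilterMightContain::new(
        Arc::new(Literal::new(ScalarValue::Binary(Some(filter_bytes)))),
        Arc::new(Column::new("v", 0)),
    );
    // Yields a BooleanArray; a NULL filter would instead yield an all-NULL array.
    let _result = expr.evaluate(&batch)?;
    Ok(())
}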
1 change: 1 addition & 0 deletions core/src/execution/datafusion/expressions/mod.rs
@@ -26,6 +26,7 @@ pub mod scalar_funcs;
pub use normalize_nan::NormalizeNaNAndZero;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_might_contain;
pub mod strings;
pub mod subquery;
pub mod sum_decimal;
1 change: 1 addition & 0 deletions core/src/execution/datafusion/mod.rs
@@ -22,3 +22,4 @@ mod operators;
pub mod planner;
pub(crate) mod shuffle_writer;
mod spark_hash;
mod util;
10 changes: 10 additions & 0 deletions core/src/execution/datafusion/planner.rs
@@ -59,6 +59,7 @@ use crate::{
avg::Avg,
avg_decimal::AvgDecimal,
bitwise_not::BitwiseNotExpr,
bloom_filter_might_contain::BloomFilterMightContain,
cast::Cast,
checkoverflow::CheckOverflow,
if_expr::IfExpr,
@@ -525,6 +526,15 @@ impl PhysicalPlanner {
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(Subquery::new(self.exec_context_id, id, data_type)))
}
ExprStruct::BloomFilterMightContain(expr) => {
let bloom_filter_expr =
self.create_expr(expr.bloom_filter.as_ref().unwrap(), input_schema.clone())?;
let value_expr = self.create_expr(expr.value.as_ref().unwrap(), input_schema)?;
Ok(Arc::new(BloomFilterMightContain::new(
bloom_filter_expr,
value_expr,
)))
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
2 changes: 1 addition & 1 deletion core/src/execution/datafusion/spark_hash.rs
@@ -32,7 +32,7 @@ use datafusion::{
};

#[inline]
fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
pub fn spark_compatible_murmur3_hash<T: AsRef<[u8]>>(data: T, seed: u32) -> u32 {
#[inline]
fn mix_k1(mut k1: i32) -> i32 {
k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32);
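
Context for the visibility change above: Spark's BloomFilterImpl derives its probe positions from two 32-bit murmur3 hashes via Kirsch–Mitzenmacher double hashing, so the bloom filter code needs a Spark-compatible hash. The following sketch shows that scheme under stated assumptions; it is not necessarily this crate's actual implementation (which lives in spark_bloom_filter.rs, collapsed below).

// Sketch: double hashing as in Spark's BloomFilterImpl.
// Assumes `murmur3` is a Spark-compatible murmur3_x86_32 over the value's
// little-endian bytes, matching Murmur3_x86_32.hashLong(item, seed) in Spark.
fn probe_positions(
    item: i64,
    num_hash_functions: i32,
    bit_size: i32,
    murmur3: impl Fn(&[u8], u32) -> u32,
) -> Vec<usize> {
    let bytes = item.to_le_bytes();
    let h1 = murmur3(&bytes, 0) as i32;
    let h2 = murmur3(&bytes, h1 as u32) as i32;
    (1..=num_hash_functions)
        .map(|i| {
            let mut combined = h1.wrapping_add(i.wrapping_mul(h2));
            if combined < 0 {
                combined = !combined; // Spark flips negative hashes rather than taking abs()
            }
            (combined % bit_size) as usize
        })
        .collect()
}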
19 changes: 19 additions & 0 deletions core/src/execution/datafusion/util/mod.rs
@@ -0,0 +1,19 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

pub mod spark_bit_array;
pub mod spark_bloom_filter;
130 changes: 130 additions & 0 deletions core/src/execution/datafusion/util/spark_bit_array.rs
@@ -0,0 +1,130 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#[derive(Debug, Hash)]
pub struct SparkBitArray {
data: Vec<u64>,
bit_count: usize,
}

impl SparkBitArray {
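// Wraps an existing buffer of 64-bit words; bit_count caches the number of set bits (popcount of the buffer).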
pub fn new(buf: Vec<u64>) -> Self {
let num_bits = buf.iter().map(|x| x.count_ones() as usize).sum();
Self {
data: buf,
bit_count: num_bits,
}
}

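// Creates an empty bit array with room for at least num_bits bits, rounded up to whole 64-bit words.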
pub fn new_from_bit_count(num_bits: usize) -> Self {
let num_words = (num_bits + 63) / 64;
debug_assert!(num_words < u32::MAX as usize, "num_words is too large");
Self {
data: vec![0u64; num_words],
bit_count: 0, // the array starts empty: no bits are set yet
}
}

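// Sets the bit at index; returns true only if the bit was previously unset.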
pub fn set(&mut self, index: usize) -> bool {
if !self.get(index) {
self.data[index >> 6] |= 1u64 << (index & 0x3f);
self.bit_count += 1;
true
} else {
false
}
}

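// Tests the bit at index: index >> 6 selects the 64-bit word, index & 0x3f the bit within it.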
pub fn get(&self, index: usize) -> bool {
(self.data[index >> 6] & (1u64 << (index & 0x3f))) != 0
}

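// Total capacity in bits (always a multiple of 64), not the number of set bits.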
pub fn bit_size(&self) -> u64 {
self.data.len() as u64 * 64
}

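// Number of bits currently set.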
pub fn cardinality(&self) -> usize {
self.bit_count
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_spark_bit_array() {
let buf = vec![0u64; 4];
let mut array = SparkBitArray::new(buf);
assert_eq!(array.bit_size(), 256);
assert_eq!(array.cardinality(), 0);

assert!(!array.get(0));
assert!(!array.get(1));
assert!(!array.get(63));
assert!(!array.get(64));
assert!(!array.get(65));
assert!(!array.get(127));
assert!(!array.get(128));
assert!(!array.get(129));

assert!(array.set(0));
assert!(array.set(1));
assert!(array.set(63));
assert!(array.set(64));
assert!(array.set(65));
assert!(array.set(127));
assert!(array.set(128));
assert!(array.set(129));

assert_eq!(array.cardinality(), 8);
assert_eq!(array.bit_size(), 256);

assert!(array.get(0));
// already set so should return false
assert!(!array.set(0));

// not set values should return false for get
assert!(!array.get(2));
assert!(!array.get(62));
}

#[test]
fn test_spark_bit_with_non_empty_buffer() {
let buf = vec![8u64; 4];
let mut array = SparkBitArray::new(buf);
assert_eq!(array.bit_size(), 256);
assert_eq!(array.cardinality(), 4);

// already set bits should return true
assert!(array.get(3));
assert!(array.get(67));
assert!(array.get(131));
assert!(array.get(195));

// other unset bits should return false
assert!(!array.get(0));
assert!(!array.get(1));

// set bits
assert!(array.set(0));
assert!(array.set(1));

// check cardinality
assert_eq!(array.cardinality(), 6);
}
}
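
As a quick illustration of the word addressing exercised above: every word of `vec![8u64; 4]` is 8, i.e. only bit 3 of each 64-bit word is set, which is why bits 3, 67, 131 and 195 read as set.

fn main() {
    let index = 67usize;
    assert_eq!(index >> 6, 1); // word index: 67 / 64
    assert_eq!(index & 0x3f, 3); // bit within the word: 67 % 64
    assert_eq!(1u64 << (index & 0x3f), 8u64); // the 8u64 word pattern above
}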
(The remaining 9 changed files, including core/src/execution/datafusion/util/spark_bloom_filter.rs, are not rendered here.)
