diff --git a/native/Cargo.lock b/native/Cargo.lock
index c3aae93af..95092f754 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -917,11 +917,13 @@ dependencies = [
  "arrow-schema",
  "chrono",
  "chrono-tz 0.8.6",
+ "criterion",
  "datafusion",
  "datafusion-common",
  "datafusion-expr",
  "datafusion-functions",
  "datafusion-physical-expr",
+ "datafusion-physical-expr-common",
  "datafusion-physical-plan",
  "num",
  "regex",
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 03635dd7b..9eb4e32a3 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -108,7 +108,8 @@ use datafusion_comet_proto::{
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
 };
 use datafusion_comet_spark_expr::{
-    Abs, Cast, DateTruncExpr, HourExpr, IfExpr, MinuteExpr, SecondExpr, TimestampTruncExpr,
+    Abs, CaseWhenExprOrNull, Cast, DateTruncExpr, HourExpr, IfExpr, MinuteExpr, SecondExpr,
+    TimestampTruncExpr,
 };
 
 // For clippy error on type_complexity.
@@ -541,6 +542,27 @@ impl PhysicalPlanner {
                         Some(self.create_expr(case_when.else_expr.as_ref().unwrap(), input_schema)?)
                     }
                 };
+
+                // TODO remove this optimization when we upgrade to DataFusion 41,
+                // which contains https://github.com/apache/datafusion/pull/11534
+
+                // optimized path for CASE WHEN predicate THEN expr ELSE null END
+                if else_phy_expr.is_none() && when_then_pairs.len() == 1 {
+                    let when_then = &when_then_pairs[0];
+                    // CaseWhenExprOrNull is only safe to use for expressions that do not
+                    // have side effects, and it is only suitable to use for expressions
+                    // that are inexpensive to compute (such as a column reference)
+                    // because it will be evaluated for all rows in the batch rather
+                    // than just the rows where the predicate is true.
+                    // For now, we limit the use to raw column references
+                    if when_then.1.as_any().is::<Column>() {
+                        return Ok(Arc::new(CaseWhenExprOrNull::new(
+                            Arc::clone(&when_then.0),
+                            Arc::clone(&when_then.1),
+                        )));
+                    }
+                }
+
                 Ok(Arc::new(CaseExpr::try_new(
                     None,
                     when_then_pairs,
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 976a1f36f..7ad824768 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -36,12 +36,20 @@ datafusion-common = { workspace = true }
 datafusion-functions = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-physical-expr = { workspace = true }
+datafusion-physical-expr-common = { workspace = true }
 datafusion-physical-plan = { workspace = true }
 chrono-tz = { workspace = true }
 num = { workspace = true }
 regex = { workspace = true }
 thiserror = { workspace = true }
 
+[dev-dependencies]
+criterion = "0.5"
+
 [lib]
 name = "datafusion_comet_spark_expr"
 path = "src/lib.rs"
+
+[[bench]]
+harness = false
+name = "case_when"
diff --git a/native/spark-expr/benches/case_when.rs b/native/spark-expr/benches/case_when.rs
new file mode 100644
index 000000000..6a233d46e
--- /dev/null
+++ b/native/spark-expr/benches/case_when.rs
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{Field, Schema};
+use arrow::record_batch::RecordBatch;
+use arrow_array::builder::{Int32Builder, StringBuilder};
+use arrow_schema::DataType;
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_comet_spark_expr::CaseWhenExprOrNull;
+use datafusion_common::ScalarValue;
+use datafusion_expr::Operator;
+use datafusion_physical_expr::expressions::BinaryExpr;
+use datafusion_physical_expr_common::expressions::column::Column;
+use datafusion_physical_expr_common::expressions::Literal;
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+use std::sync::Arc;
+
+/// Creates a physical expression that references column `name` at position `index`.
+fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
+    Arc::new(Column::new(name, index))
+}
+
+/// Creates a physical expression for the `Int32` literal `n`.
+fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
+    Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
+}
+
+/// Benchmarks `CaseWhenExprOrNull::evaluate` on a 1000-row batch made of an
+/// `Int32` column `c1` and a `Utf8` column `c2` that is null in every 7th row.
+fn criterion_benchmark(c: &mut Criterion) {
+    // create input data
+    let mut c1 = Int32Builder::new();
+    let mut c2 = StringBuilder::new();
+    for i in 0..1000 {
+        c1.append_value(i);
+        if i % 7 == 0 {
+            c2.append_null();
+        } else {
+            c2.append_value(format!("string {i}"));
+        }
+    }
+    let c1 = Arc::new(c1.finish());
+    let c2 = Arc::new(c2.finish());
+    let schema = Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Utf8, true),
+    ]);
+    let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
+
+    // use same predicate for all benchmarks
+    let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+        make_col("c1", 0),
+        Operator::LtEq,
+        make_lit_i32(500),
+    ));
+
+    // CASE WHEN expr THEN col ELSE null END
+    c.bench_function("expr_or_null", |b| {
+        let expr = Arc::new(CaseWhenExprOrNull::new(
+            Arc::clone(&predicate),
+            make_col("c2", 1),
+        ));
+        b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
+    });
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git 
a/native/spark-expr/src/case_expr_or_null.rs b/native/spark-expr/src/case_expr_or_null.rs
new file mode 100644
index 000000000..a9cf4c351
--- /dev/null
+++ b/native/spark-expr/src/case_expr_or_null.rs
@@ -0,0 +1,240 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{
+    any::Any,
+    fmt::{Debug, Display, Formatter},
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+use arrow::compute::{not, nullif, prep_null_mask_filter};
+use arrow_array::{Array, BooleanArray, RecordBatch};
+use arrow_schema::{DataType, Schema};
+
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::ColumnarValue;
+use datafusion_physical_plan::{DisplayAs, DisplayFormatType, PhysicalExpr};
+
+use crate::utils::down_cast_any_ref;
+
+/// Specialization of `CASE WHEN .. THEN .. ELSE null END` where
+/// the else condition is a null literal or is absent.
+///
+/// Rows for which the predicate evaluates to false or null produce null.
+///
+/// CaseWhenExprOrNull is only safe to use for expressions that do not
+/// have side effects, and it is only suitable to use for expressions
+/// that are inexpensive to compute (such as a column reference)
+/// because it will be evaluated for all rows in the batch rather
+/// than just the rows where the predicate is true.
+///
+/// The performance advantage of this expression is that it
+/// avoids copying data and simply modifies the null bitmask
+/// of the evaluated expression based on the inverse of the
+/// predicate expression.
+#[derive(Debug, Hash)]
+pub struct CaseWhenExprOrNull {
+    /// The WHEN predicate
+    predicate: Arc<dyn PhysicalExpr>,
+    /// The THEN expression
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl CaseWhenExprOrNull {
+    /// Creates an expression that evaluates to `input` for rows where
+    /// `predicate` is true and to null for all other rows.
+    pub fn new(predicate: Arc<dyn PhysicalExpr>, input: Arc<dyn PhysicalExpr>) -> Self {
+        Self {
+            predicate,
+            expr: input,
+        }
+    }
+}
+
+impl Display for CaseWhenExprOrNull {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "ExprOrNull(predicate={}, expr={})",
+            self.predicate, self.expr
+        )
+    }
+}
+
+impl DisplayAs for CaseWhenExprOrNull {
+    fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        // all display formats share the `Display` representation
+        Display::fmt(self, f)
+    }
+}
+
+impl PhysicalExpr for CaseWhenExprOrNull {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
+        self.expr.data_type(input_schema)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        // rows where the predicate is not true are always null
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let ColumnarValue::Array(predicate_array) = self.predicate.evaluate(batch)? else {
+            return exec_err!("predicate did not evaluate to an array");
+        };
+        let Some(predicate) = predicate_array.as_any().downcast_ref::<BooleanArray>() else {
+            return exec_err!("predicate should evaluate to a boolean array");
+        };
+        // `nullif` keeps the value wherever its mask is null, but a null
+        // predicate must produce null, so null predicate rows are mapped
+        // to false before the mask is inverted
+        let not_true = match predicate.null_count() {
+            0 => not(predicate)?,
+            _ => not(&prep_null_mask_filter(predicate))?,
+        };
+        match self.expr.evaluate(batch)? {
+            ColumnarValue::Array(array) => Ok(ColumnarValue::Array(nullif(&array, &not_true)?)),
+            ColumnarValue::Scalar(_) => exec_err!("expression did not evaluate to an array"),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.predicate, &self.expr]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        // children are expected in the order reported by `children`
+        match children.as_slice() {
+            [predicate, expr] => Ok(Arc::new(Self::new(
+                Arc::clone(predicate),
+                Arc::clone(expr),
+            ))),
+            _ => exec_err!(
+                "CaseWhenExprOrNull expects 2 children but got {}",
+                children.len()
+            ),
+        }
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        // the derived `Hash` already covers both child expressions
+        let mut s = state;
+        self.hash(&mut s);
+    }
+}
+
+impl PartialEq<dyn Any> for CaseWhenExprOrNull {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| self.predicate.eq(&x.predicate) && self.expr.eq(&x.expr))
+            .unwrap_or(false)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use arrow_array::builder::{Int32Builder, StringBuilder};
+    use arrow_array::{Array, RecordBatch};
+    use arrow_schema::{DataType, Field, Schema};
+
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{ColumnarValue, Operator};
+    use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr};
+    use datafusion_physical_expr_common::expressions::column::Column;
+    use datafusion_physical_expr_common::expressions::Literal;
+    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+
+    use crate::CaseWhenExprOrNull;
+
+    #[test]
+    fn test() -> Result<()> {
+        assert_same_as_case_expr(&make_batch(false))
+    }
+
+    #[test]
+    fn test_null_predicate() -> Result<()> {
+        // rows with a null predicate must be null even when `c2` is not
+        assert_same_as_case_expr(&make_batch(true))
+    }
+
+    /// Builds a 1000-row batch with an `Int32` column `c1` and a `Utf8`
+    /// column `c2`. `c2` is null in every 7th row and, when
+    /// `null_predicate_input` is set, `c1` is null in every 5th row.
+    fn make_batch(null_predicate_input: bool) -> RecordBatch {
+        let mut c1 = Int32Builder::new();
+        let mut c2 = StringBuilder::new();
+        for i in 0..1000 {
+            if null_predicate_input && i % 5 == 0 {
+                c1.append_null();
+            } else {
+                c1.append_value(i);
+            }
+            if i % 7 == 0 {
+                c2.append_null();
+            } else {
+                c2.append_value(format!("string {i}"));
+            }
+        }
+        let c1 = Arc::new(c1.finish());
+        let c2 = Arc::new(c2.finish());
+        let schema = Schema::new(vec![
+            Field::new("c1", DataType::Int32, true),
+            Field::new("c2", DataType::Utf8, true),
+        ]);
+        RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap()
+    }
+
+    /// Asserts that `CaseWhenExprOrNull` and `CaseExpr` produce the same
+    /// array for `CASE WHEN c1 <= 250 THEN c2 END` over `batch`.
+    fn assert_same_as_case_expr(batch: &RecordBatch) -> Result<()> {
+        let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            make_col("c1", 0),
+            Operator::LtEq,
+            make_lit_i32(250),
+        ));
+        let expr1 = CaseWhenExprOrNull::new(Arc::clone(&predicate), make_col("c2", 1));
+        let expr2 = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
+        match (expr1.evaluate(batch)?, expr2.evaluate(batch)?) {
+            (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => {
+                assert_eq!(array1.len(), array2.len());
+                assert_eq!(array1.null_count(), array2.null_count());
+                assert_eq!(&array1, &array2);
+            }
+            _ => unreachable!(),
+        }
+        Ok(())
+    }
+
+    fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
+        Arc::new(Column::new(name, index))
+    }
+
+    fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
+        Arc::new(Literal::new(ScalarValue::Int32(Some(n))))
+    }
+}
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 91d61f70a..ed6966154 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -20,12 +20,14 @@ mod cast;
 mod error;
 mod if_expr;
 
+mod case_expr_or_null;
 mod kernels;
 mod temporal;
 pub mod timezone;
 pub mod utils;
 
 pub use abs::Abs;
+pub use case_expr_or_null::CaseWhenExprOrNull;
 pub use cast::Cast;
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 65217767d..54b23a790 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1281,6 +1281,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         withTable(table) {
           sql(s"create table $table(id int) using parquet")
           sql(s"insert into $table values(1), (NULL), (2), (2), (3), (3), (4), (5), (NULL)")
+          checkSparkAnswerAndOperator(
+            s"SELECT CASE WHEN id > 2 THEN 3333 ELSE NULL END FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT CASE WHEN id > 2 THEN id ELSE NULL END FROM $table")
+          checkSparkAnswerAndOperator(
+            s"SELECT CASE WHEN id > 2 THEN id + 1 ELSE NULL END FROM $table")
+          checkSparkAnswerAndOperator(s"SELECT CASE WHEN id > 2 THEN id + 1 END FROM $table")
           checkSparkAnswerAndOperator(
             s"SELECT CASE WHEN id > 2 THEN 3333 WHEN id > 1 THEN 2222 ELSE 1111 END FROM $table")
           checkSparkAnswerAndOperator(