From ded3dd6703fe62c1ad2b3cb9034975aec410ff32 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 24 Jul 2024 15:11:21 -0600 Subject: [PATCH] perf: Optimize IfExpr by delegating to CaseExpr (#681) * Unify IF and CASE expressions * revert test changes * fix --- native/Cargo.lock | 3 + native/core/Cargo.toml | 8 - native/spark-expr/Cargo.toml | 17 +++ .../benches/cast_from_string.rs | 0 .../benches/cast_numeric.rs | 0 native/spark-expr/benches/conditional.rs | 139 ++++++++++++++++++ native/spark-expr/src/if_expr.rs | 44 ++---- 7 files changed, 173 insertions(+), 38 deletions(-) rename native/{core => spark-expr}/benches/cast_from_string.rs (100%) rename native/{core => spark-expr}/benches/cast_numeric.rs (100%) create mode 100644 native/spark-expr/benches/conditional.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index fb4ce70fe..580610748 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -918,12 +918,15 @@ dependencies = [ "arrow-schema", "chrono", "chrono-tz 0.8.6", + "criterion", "datafusion", "datafusion-common", "datafusion-expr", "datafusion-physical-expr", + "datafusion-physical-expr-common", "datafusion-physical-plan", "num", + "rand", "regex", "thiserror", ] diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 3046c1d8f..e396976d7 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -112,14 +112,6 @@ harness = false name = "row_columnar" harness = false -[[bench]] -name = "cast_from_string" -harness = false - -[[bench]] -name = "cast_numeric" -harness = false - [[bench]] name = "shuffle_writer" harness = false diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 192ed102b..aa4fcfc5f 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -34,6 +34,7 @@ chrono = { workspace = true } datafusion = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } 
datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } chrono-tz = { workspace = true } @@ -41,6 +42,22 @@ num = { workspace = true } regex = { workspace = true } thiserror = { workspace = true } +[dev-dependencies] +criterion = "0.5.1" +rand = "0.8.5" + [lib] name = "datafusion_comet_spark_expr" path = "src/lib.rs" + +[[bench]] +name = "cast_from_string" +harness = false + +[[bench]] +name = "cast_numeric" +harness = false + +[[bench]] +name = "conditional" +harness = false \ No newline at end of file diff --git a/native/core/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs similarity index 100% rename from native/core/benches/cast_from_string.rs rename to native/spark-expr/benches/cast_from_string.rs diff --git a/native/core/benches/cast_numeric.rs b/native/spark-expr/benches/cast_numeric.rs similarity index 100% rename from native/core/benches/cast_numeric.rs rename to native/spark-expr/benches/cast_numeric.rs diff --git a/native/spark-expr/benches/conditional.rs b/native/spark-expr/benches/conditional.rs new file mode 100644 index 000000000..d86ef76f8 --- /dev/null +++ b/native/spark-expr/benches/conditional.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_array::builder::{Int32Builder, StringBuilder}; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::IfExpr; +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr}; +use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr_common::expressions::Literal; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> { + Arc::new(Column::new(name, index)) +} + +fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) +} + +fn make_null_lit() -> Arc<dyn PhysicalExpr> { + Arc::new(Literal::new(ScalarValue::Utf8(None))) +} + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + let mut c3 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(&format!("string {i}")); + } + if i % 9 == 0 { + c3.append_null(); + } else { + c3.append_value(&format!("other string {i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let c3 = Arc::new(c3.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); + + // use same predicate for all benchmarks + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(500), + )); + + // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END + 
c.bench_function("case_when: scalar or scalar", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_lit_i32(1))], + Some(make_lit_i32(0)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: scalar or scalar", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_lit_i32(1), + make_lit_i32(0), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END + c.bench_function("case_when: column or null", |b| { + let expr = Arc::new( + CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None).unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: column or null", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_col("c2", 1), + make_null_lit(), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END + c.bench_function("case_when: expr or expr", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_col("c2", 1))], + Some(make_col("c3", 2)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: expr or expr", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_col("c2", 1), + make_col("c3", 2), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/native/spark-expr/src/if_expr.rs b/native/spark-expr/src/if_expr.rs index fa52c5d5b..a5344140b 100644 --- a/native/spark-expr/src/if_expr.rs +++ b/native/spark-expr/src/if_expr.rs @@ -22,22 +22,24 @@ use std::{ }; use arrow::{ - array::*, - compute::{and, is_null, kernels::zip::zip, not, or_kleene}, datatypes::{DataType, Schema}, record_batch::RecordBatch, 
}; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{cast::as_boolean_array, Result}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_common::Result; +use datafusion_physical_expr::{expressions::CaseExpr, PhysicalExpr}; use crate::utils::down_cast_any_ref; +/// IfExpr is a wrapper around CaseExpr, because `IF(a, b, c)` is semantically equivalent to +/// `CASE WHEN a THEN b ELSE c END`. #[derive(Debug, Hash)] pub struct IfExpr { if_expr: Arc<dyn PhysicalExpr>, true_expr: Arc<dyn PhysicalExpr>, false_expr: Arc<dyn PhysicalExpr>, + // we delegate to case_expr for evaluation + case_expr: Arc<CaseExpr>, } impl std::fmt::Display for IfExpr { @@ -58,9 +60,12 @@ impl IfExpr { false_expr: Arc<dyn PhysicalExpr>, ) -> Self { Self { - if_expr, - true_expr, - false_expr, + if_expr: if_expr.clone(), + true_expr: true_expr.clone(), + false_expr: false_expr.clone(), + case_expr: Arc::new( + CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(), + ), } } } @@ -85,29 +90,7 @@ impl PhysicalExpr for IfExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); - - // evaluate if condition on batch - let if_value = self.if_expr.evaluate_selection(batch, &remainder)?; - let if_value = if_value.into_array(batch.num_rows())?; - let if_value = - as_boolean_array(&if_value).expect("if expression did not return a BooleanArray"); - - let true_value = self.true_expr.evaluate_selection(batch, if_value)?; - let true_value = true_value.into_array(batch.num_rows())?; - - remainder = and( - &remainder, - &or_kleene(&not(if_value)?, &is_null(if_value)?)?, - )?; - - let false_value = self - .false_expr - .evaluate_selection(batch, &remainder)? 
- .into_array(batch.num_rows())?; - let current_value = zip(&remainder, &false_value, &true_value)?; - - Ok(ColumnarValue::Array(current_value)) + self.case_expr.evaluate(batch) } fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { @@ -150,6 +133,7 @@ impl PartialEq for IfExpr { #[cfg(test)] mod tests { use arrow::{array::StringArray, datatypes::*}; + use arrow_array::Int32Array; use datafusion::logical_expr::Operator; use datafusion_common::cast::as_int32_array; use datafusion_physical_expr::expressions::{binary, col, lit};