diff --git a/native/Cargo.lock b/native/Cargo.lock index c0f22fa1a..9bf8247d0 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -867,6 +867,7 @@ dependencies = [ "criterion", "datafusion", "datafusion-comet-spark-expr", + "datafusion-comet-utils", "datafusion-common", "datafusion-expr", "datafusion-physical-expr", @@ -909,8 +910,17 @@ dependencies = [ "arrow", "arrow-schema", "datafusion", + "datafusion-comet-utils", "datafusion-common", "datafusion-functions", + "datafusion-physical-expr", +] + +[[package]] +name = "datafusion-comet-utils" +version = "0.1.0" +dependencies = [ + "datafusion-physical-plan", ] [[package]] diff --git a/native/Cargo.toml b/native/Cargo.toml index 13860fbdf..53afed85a 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [workspace] -members = ["core", "spark-expr"] +members = ["core", "spark-expr", "utils"] resolver = "2" [workspace.package] @@ -43,8 +43,11 @@ datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "4 datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", features = ["unicode_expressions", "crypto_expressions"] } datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", features = ["crypto_expressions"] } datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false } +datafusion-physical-plan = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false } datafusion-physical-expr-common = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false } datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false } +datafusion-comet-spark-expr = { path = "spark-expr", version = "0.1.0" } +datafusion-comet-utils = { path = "utils", version = "0.1.0" } [profile.release] debug = true diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 6432118d6..be135d4e9 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -77,7 +77,8 @@ once_cell = "1.18.0" regex = "1.9.6" crc32fast = "1.3.2" simd-adler32 = "0.3.7" -datafusion-comet-spark-expr = { path = "../spark-expr", version = "0.1.0" } +datafusion-comet-spark-expr = { workspace = true } +datafusion-comet-utils = { workspace = true } [build-dependencies] prost-build = "0.9.0" diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs index 98b422dce..d573c2377 100644 --- a/native/core/src/execution/datafusion/expressions/mod.rs +++ b/native/core/src/execution/datafusion/expressions/mod.rs @@ -20,7 +20,6 @@ pub mod bitwise_not; pub mod cast; pub mod checkoverflow; -pub mod if_expr; mod normalize_nan; pub mod scalar_funcs; pub use normalize_nan::NormalizeNaNAndZero; diff --git a/native/core/src/execution/datafusion/expressions/utils.rs b/native/core/src/execution/datafusion/expressions/utils.rs index ee8646a78..6a7ec2e12 100644 --- a/native/core/src/execution/datafusion/expressions/utils.rs +++ b/native/core/src/execution/datafusion/expressions/utils.rs @@ -30,24 +30,10 @@ use arrow_array::{cast::AsArray, types::ArrowPrimitiveType}; use arrow_schema::DataType; use chrono::{DateTime, Offset, TimeZone}; use datafusion_common::cast::as_generic_string_array; -use datafusion_physical_expr::PhysicalExpr; use num::integer::div_floor; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; -/// An utility function from DataFusion. It is not exposed by DataFusion. -pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { - if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else if any.is::>() { - any.downcast_ref::>() - .unwrap() - .as_any() - } else { - any - } -} +pub use datafusion_comet_utils::down_cast_any_ref; /// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or /// to apply timezone offset. diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index ee208ac74..23960c307 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -79,7 +79,6 @@ use crate::{ checkoverflow::CheckOverflow, correlation::Correlation, covariance::Covariance, - if_expr::IfExpr, negative, scalar_funcs::create_comet_physical_fun, stats::StatsType, @@ -108,7 +107,7 @@ use crate::{ }; use super::expressions::{create_named_struct::CreateNamedStruct, EvalMode}; -use datafusion_comet_spark_expr::abs::Abs; +use datafusion_comet_spark_expr::{Abs, IfExpr}; // For clippy error on type_complexity. type ExecResult = Result; diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index d10d04944..8bf76dff6 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -32,6 +32,8 @@ arrow-schema = { workspace = true } datafusion = { workspace = true } datafusion-common = { workspace = true } datafusion-functions = { workspace = true } +datafusion-physical-expr = { workspace = true } +datafusion-comet-utils = { workspace = true } [lib] name = "datafusion_comet_spark_expr" diff --git a/native/core/src/execution/datafusion/expressions/if_expr.rs b/native/spark-expr/src/if_expr.rs similarity index 95% rename from native/core/src/execution/datafusion/expressions/if_expr.rs rename to native/spark-expr/src/if_expr.rs index fa235cc66..c04494ec4 100644 --- a/native/core/src/execution/datafusion/expressions/if_expr.rs +++ b/native/spark-expr/src/if_expr.rs @@ -31,7 +31,7 @@ use datafusion::logical_expr::ColumnarValue; use datafusion_common::{cast::as_boolean_array, Result}; use datafusion_physical_expr::PhysicalExpr; -use crate::execution::datafusion::expressions::utils::down_cast_any_ref; +use datafusion_comet_utils::down_cast_any_ref; #[derive(Debug, Hash)] pub struct IfExpr { @@ -147,15 +147,6 @@ impl PartialEq for IfExpr { } } -/// Create an If expression -pub fn if_fn( - if_expr: Arc, - true_expr: Arc, - false_expr: Arc, -) -> Result> { - Ok(Arc::new(IfExpr::new(if_expr, true_expr, false_expr))) -} - #[cfg(test)] mod tests { use arrow::{array::StringArray, datatypes::*}; @@ -165,6 +156,15 @@ mod tests { use super::*; + /// Create an If expression + fn if_fn( + if_expr: Arc, + true_expr: Arc, + false_expr: Arc, + ) -> Result> { + Ok(Arc::new(IfExpr::new(if_expr, true_expr, false_expr))) + } + #[test] fn test_if_1() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 3873754be..c36e8855e 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -18,7 +18,11 @@ use std::error::Error; use std::fmt::{Display, Formatter}; -pub mod abs; +mod abs; +mod if_expr; + +pub use abs::Abs; +pub use if_expr::IfExpr; /// Spark supports three evaluation modes when evaluating expressions, which affect /// the behavior when processing input values that are invalid or would result in an diff --git a/native/utils/Cargo.toml b/native/utils/Cargo.toml new file mode 100644 index 000000000..05ddd3488 --- /dev/null +++ b/native/utils/Cargo.toml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-comet-utils" +description = "DataFusion Comet Utilities" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +[dependencies] +datafusion-physical-plan = { workspace = true } + +[lib] +name = "datafusion_comet_utils" +path = "src/lib.rs" diff --git a/native/utils/README.md b/native/utils/README.md new file mode 100644 index 000000000..513c6245e --- /dev/null +++ b/native/utils/README.md @@ -0,0 +1,22 @@ + + +# datafusion-comet-utils + +This crate provides utilities for use in the [Apache DataFusion Comet](https://github.com/apache/datafusion-comet/) project. \ No newline at end of file diff --git a/native/utils/src/lib.rs b/native/utils/src/lib.rs new file mode 100644 index 000000000..54ff55b46 --- /dev/null +++ b/native/utils/src/lib.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use datafusion_physical_plan::PhysicalExpr; + +/// A utility function from DataFusion. It is not exposed by DataFusion. +pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { + if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else { + any + } +}