diff --git a/native/Cargo.lock b/native/Cargo.lock index f73f28629..649e137f0 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -861,7 +861,6 @@ dependencies = [ "async-trait", "brotli", "bytes", - "chrono", "crc32fast", "criterion", "datafusion", diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index c252fad6d..90ead502f 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -64,7 +64,6 @@ bytes = "1.5.0" tempfile = "3.8.0" ahash = { version = "0.8", default-features = false } itertools = "0.11.0" -chrono = { workspace = true } paste = "1.0.14" datafusion-common = { workspace = true } datafusion = { workspace = true } diff --git a/native/core/benches/cast_from_string.rs b/native/core/benches/cast_from_string.rs index 9a9ab18cc..efc7987c5 100644 --- a/native/core/benches/cast_from_string.rs +++ b/native/core/benches/cast_from_string.rs @@ -17,8 +17,8 @@ use arrow_array::{builder::StringBuilder, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; -use comet::execution::datafusion::expressions::{cast::Cast, EvalMode}; use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::{Cast, EvalMode}; use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use std::sync::Arc; diff --git a/native/core/benches/cast_numeric.rs b/native/core/benches/cast_numeric.rs index 35f24ce53..f9ed1fae2 100644 --- a/native/core/benches/cast_numeric.rs +++ b/native/core/benches/cast_numeric.rs @@ -17,8 +17,8 @@ use arrow_array::{builder::Int32Builder, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; -use comet::execution::datafusion::expressions::{cast::Cast, EvalMode}; use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::{Cast, EvalMode}; use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use std::sync::Arc; diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs index f6fb26b6a..3c0a5b26b 100644 --- a/native/core/src/execution/datafusion/expressions/mod.rs +++ b/native/core/src/execution/datafusion/expressions/mod.rs @@ -18,7 +18,6 @@ //! Native DataFusion expressions pub mod bitwise_not; -pub use datafusion_comet_spark_expr::cast; pub mod checkoverflow; mod normalize_nan; pub mod scalar_funcs; @@ -37,7 +36,6 @@ pub mod stddev; pub mod strings; pub mod subquery; pub mod sum_decimal; -pub mod temporal; pub mod unbound; mod utils; pub mod variance; diff --git a/native/core/src/execution/datafusion/expressions/utils.rs b/native/core/src/execution/datafusion/expressions/utils.rs index d253b251f..540fca86b 100644 --- a/native/core/src/execution/datafusion/expressions/utils.rs +++ b/native/core/src/execution/datafusion/expressions/utils.rs @@ -16,4 +16,4 @@ // under the License. // re-export for legacy reasons -pub use datafusion_comet_spark_expr::utils::{array_with_timezone, down_cast_any_ref}; +pub use datafusion_comet_spark_expr::utils::down_cast_any_ref; diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 23960c307..7e6383059 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -75,7 +75,6 @@ use crate::{ avg_decimal::AvgDecimal, bitwise_not::BitwiseNotExpr, bloom_filter_might_contain::BloomFilterMightContain, - cast::Cast, checkoverflow::CheckOverflow, correlation::Correlation, covariance::Covariance, @@ -86,7 +85,6 @@ use crate::{ strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExec, SubstringExec}, subquery::Subquery, sum_decimal::SumDecimal, - temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec}, unbound::UnboundColumn, variance::Variance, NormalizeNaNAndZero, @@ -107,7 +105,9 @@ use crate::{ }; use super::expressions::{create_named_struct::CreateNamedStruct, EvalMode}; -use datafusion_comet_spark_expr::{Abs, IfExpr}; +use datafusion_comet_spark_expr::{ + Abs, Cast, DateTruncExec, HourExec, IfExpr, MinuteExec, SecondExec, TimestampTruncExec, +}; // For clippy error on type_complexity. type ExecResult = Result; diff --git a/native/core/src/execution/kernels/mod.rs b/native/core/src/execution/kernels/mod.rs index 76d4e1807..675dcd489 100644 --- a/native/core/src/execution/kernels/mod.rs +++ b/native/core/src/execution/kernels/mod.rs @@ -21,4 +21,3 @@ mod hash; pub use hash::hash; pub(crate) mod strings; -pub(crate) mod temporal; diff --git a/native/spark-expr/src/kernels/mod.rs b/native/spark-expr/src/kernels/mod.rs new file mode 100644 index 000000000..88aa34b1a --- /dev/null +++ b/native/spark-expr/src/kernels/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Kernels + +pub(crate) mod temporal; diff --git a/native/core/src/execution/kernels/temporal.rs b/native/spark-expr/src/kernels/temporal.rs similarity index 95% rename from native/core/src/execution/kernels/temporal.rs rename to native/spark-expr/src/kernels/temporal.rs index 9cf35af1a..6f2474e8d 100644 --- a/native/core/src/execution/kernels/temporal.rs +++ b/native/spark-expr/src/kernels/temporal.rs @@ -32,24 +32,18 @@ use arrow_array::{ use arrow_schema::TimeUnit; -use crate::errors::ExpressionError; +use crate::SparkError; // Copied from arrow_arith/temporal.rs macro_rules! return_compute_error_with { ($msg:expr, $param:expr) => { - return { - Err(ExpressionError::ArrowError(format!( - "{}: {:?}", - $msg, $param - ))) - } + return { Err(SparkError::Internal(format!("{}: {:?}", $msg, $param))) } }; } // The number of days between the beginning of the proleptic gregorian calendar (0001-01-01) // and the beginning of the Unix Epoch (1970-01-01) const DAYS_TO_UNIX_EPOCH: i32 = 719_163; -const MICROS_TO_UNIX_EPOCH: i64 = 62_167_132_800 * 1_000_000; // Copied from arrow_arith/temporal.rs with modification to the output datatype // Transforms a array of NaiveDate to an array of Date32 after applying an operation @@ -102,7 +96,7 @@ fn as_timestamp_tz_with_op, T: ArrowTemporalT mut builder: PrimitiveBuilder, tz: &str, op: F, -) -> Result +) -> Result where F: Fn(DateTime) -> i64, i64: From, @@ -113,7 +107,7 @@ where Some(value) => match as_datetime_with_timezone::(value.into(), tz) { Some(time) => builder.append_value(op(time)), _ => { - return Err(ExpressionError::ArrowError( + return Err(SparkError::Internal( "Unable to read value as datetime".to_string(), )); } @@ -129,7 +123,7 @@ fn as_timestamp_tz_with_op_single( builder: &mut PrimitiveBuilder, tz: &Tz, op: F, -) -> Result<(), ExpressionError> +) -> Result<(), SparkError> where F: Fn(DateTime) -> i64, i64: From, @@ -138,7 +132,7 @@ where Some(value) => match as_datetime_with_timezone::(value.into(), *tz) { Some(time) => builder.append_value(op(time)), _ => { - return Err(ExpressionError::ArrowError( + return Err(SparkError::Internal( "Unable to read value as datetime".to_string(), )); } @@ -256,7 +250,7 @@ fn trunc_date_to_microsec(dt: T) -> Option { /// array is an array of Date32 values. The array may be a dictionary array. /// /// format is a scalar string specifying the format to apply to the timestamp value. -pub fn date_trunc_dyn(array: &dyn Array, format: String) -> Result { +pub(crate) fn date_trunc_dyn(array: &dyn Array, format: String) -> Result { match array.data_type().clone() { DataType::Dictionary(_, _) => { downcast_dictionary_array!( @@ -279,10 +273,10 @@ pub fn date_trunc_dyn(array: &dyn Array, format: String) -> Result( +pub(crate) fn date_trunc( array: &PrimitiveArray, format: String, -) -> Result +) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From, @@ -311,7 +305,7 @@ where builder, |dt| as_days_from_unix_epoch(trunc_date_to_week(dt)), )), - _ => Err(ExpressionError::ArrowError(format!( + _ => Err(SparkError::Internal(format!( "Unsupported format: {:?} for function 'date_trunc'", format ))), @@ -331,10 +325,10 @@ where /// /// format is an array of strings specifying the format to apply to the corresponding date value. /// The array may be a dictionary array. -pub fn date_trunc_array_fmt_dyn( +pub(crate) fn date_trunc_array_fmt_dyn( array: &dyn Array, formats: &dyn Array, -) -> Result { +) -> Result { match (array.data_type().clone(), formats.data_type().clone()) { (DataType::Dictionary(_, v), DataType::Dictionary(_, f)) => { if !matches!(*v, DataType::Date32) { @@ -403,7 +397,7 @@ pub fn date_trunc_array_fmt_dyn( .expect("Unexpected value type in formats"), ) .map(|a| Arc::new(a) as ArrayRef), - (dt, fmt) => Err(ExpressionError::ArrowError(format!( + (dt, fmt) => Err(SparkError::Internal(format!( "Unsupported datatype: {:}, format: {:?} for function 'date_trunc'", dt, fmt ))), @@ -434,7 +428,7 @@ macro_rules! date_trunc_array_fmt_helper { "WEEK" => Ok(as_datetime_with_op_single(val, &mut builder, |dt| { as_days_from_unix_epoch(trunc_date_to_week(dt)) })), - _ => Err(ExpressionError::ArrowError(format!( + _ => Err(SparkError::Internal(format!( "Unsupported format: {:?} for function 'date_trunc'", $formats.value(index) ))), @@ -454,7 +448,7 @@ macro_rules! date_trunc_array_fmt_helper { fn date_trunc_array_fmt_plain_plain( array: &Date32Array, formats: &StringArray, -) -> Result +) -> Result where { let data_type = array.data_type(); @@ -464,7 +458,7 @@ where fn date_trunc_array_fmt_plain_dict( array: &Date32Array, formats: &TypedDictionaryArray, -) -> Result +) -> Result where K: ArrowDictionaryKeyType, { @@ -475,7 +469,7 @@ where fn date_trunc_array_fmt_dict_plain( array: &TypedDictionaryArray, formats: &StringArray, -) -> Result +) -> Result where K: ArrowDictionaryKeyType, { @@ -486,7 +480,7 @@ where fn date_trunc_array_fmt_dict_dict( array: &TypedDictionaryArray, formats: &TypedDictionaryArray, -) -> Result +) -> Result where K: ArrowDictionaryKeyType, F: ArrowDictionaryKeyType, @@ -503,7 +497,10 @@ where /// timezone or no timezone. The array may be a dictionary array. /// /// format is a scalar string specifying the format to apply to the timestamp value. -pub fn timestamp_trunc_dyn(array: &dyn Array, format: String) -> Result { +pub(crate) fn timestamp_trunc_dyn( + array: &dyn Array, + format: String, +) -> Result { match array.data_type().clone() { DataType::Dictionary(_, _) => { downcast_dictionary_array!( @@ -526,10 +523,10 @@ pub fn timestamp_trunc_dyn(array: &dyn Array, format: String) -> Result( +pub(crate) fn timestamp_trunc( array: &PrimitiveArray, format: String, -) -> Result +) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From, @@ -589,7 +586,7 @@ where as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt)) }) } - _ => Err(ExpressionError::ArrowError(format!( + _ => Err(SparkError::Internal(format!( "Unsupported format: {:?} for function 'timestamp_trunc'", format ))), @@ -611,10 +608,10 @@ where /// /// format is an array of strings specifying the format to apply to the corresponding timestamp /// value. The array may be a dictionary array. -pub fn timestamp_trunc_array_fmt_dyn( +pub(crate) fn timestamp_trunc_array_fmt_dyn( array: &dyn Array, formats: &dyn Array, -) -> Result { +) -> Result { match (array.data_type().clone(), formats.data_type().clone()) { (DataType::Dictionary(_, _), DataType::Dictionary(_, _)) => { downcast_dictionary_array!( @@ -669,7 +666,7 @@ pub fn timestamp_trunc_array_fmt_dyn( dt => return_compute_error_with!("timestamp_trunc does not support", dt), ) } - (dt, fmt) => Err(ExpressionError::ArrowError(format!( + (dt, fmt) => Err(SparkError::Internal(format!( "Unsupported datatype: {:}, format: {:?} for function 'timestamp_trunc'", dt, fmt ))), @@ -740,7 +737,7 @@ macro_rules! timestamp_trunc_array_fmt_helper { as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt)) }) } - _ => Err(ExpressionError::ArrowError(format!( + _ => Err(SparkError::Internal(format!( "Unsupported format: {:?} for function 'timestamp_trunc'", $formats.value(index) ))), @@ -762,7 +759,7 @@ macro_rules! timestamp_trunc_array_fmt_helper { fn timestamp_trunc_array_fmt_plain_plain( array: &PrimitiveArray, formats: &StringArray, -) -> Result +) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From, @@ -773,7 +770,7 @@ where fn timestamp_trunc_array_fmt_plain_dict( array: &PrimitiveArray, formats: &TypedDictionaryArray, -) -> Result +) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From, @@ -786,7 +783,7 @@ where fn timestamp_trunc_array_fmt_dict_plain( array: &TypedDictionaryArray>, formats: &StringArray, -) -> Result +) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From, @@ -799,7 +796,7 @@ where fn timestamp_trunc_array_fmt_dict_dict( array: &TypedDictionaryArray>, formats: &TypedDictionaryArray, -) -> Result +) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From, @@ -812,7 +809,7 @@ where #[cfg(test)] mod tests { - use crate::execution::kernels::temporal::{ + use crate::kernels::temporal::{ date_trunc, date_trunc_array_fmt_dyn, timestamp_trunc, timestamp_trunc_array_fmt_dyn, }; use arrow_array::{ diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 3c726f52a..5168e0e80 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -16,16 +16,20 @@ // under the License. mod abs; -pub mod cast; +mod cast; mod error; mod if_expr; +mod kernels; +mod temporal; pub mod timezone; pub mod utils; pub use abs::Abs; +pub use cast::Cast; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; +pub use temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec}; /// Spark supports three evaluation modes when evaluating expressions, which affect /// the behavior when processing input values that are invalid or would result in an diff --git a/native/core/src/execution/datafusion/expressions/temporal.rs b/native/spark-expr/src/temporal.rs similarity index 98% rename from native/core/src/execution/datafusion/expressions/temporal.rs rename to native/spark-expr/src/temporal.rs index 69fbb7910..ea30d3383 100644 --- a/native/core/src/execution/datafusion/expressions/temporal.rs +++ b/native/spark-expr/src/temporal.rs @@ -31,12 +31,10 @@ use datafusion::logical_expr::ColumnarValue; use datafusion_common::{DataFusionError, ScalarValue::Utf8}; use datafusion_physical_expr::PhysicalExpr; -use crate::execution::{ - datafusion::expressions::utils::{array_with_timezone, down_cast_any_ref}, - kernels::temporal::{ - date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, - timestamp_trunc_dyn, - }, +use crate::utils::{array_with_timezone, down_cast_any_ref}; + +use crate::kernels::temporal::{ + date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn, }; #[derive(Debug, Hash)]