Skip to content

Commit

Permalink
chore: Move temporal kernels and expressions to spark-expr crate (#660)
Browse files Browse the repository at this point in the history
* Move temporal expressions to spark-expr crate

* reduce public api

* reduce public api

* update imports in benchmarks

* fmt

* remove unused dep
  • Loading branch information
andygrove authored Jul 15, 2024
1 parent c434872 commit ab1d30a
Show file tree
Hide file tree
Showing 12 changed files with 69 additions and 55 deletions.
1 change: 0 additions & 1 deletion native/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion native/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ bytes = "1.5.0"
tempfile = "3.8.0"
ahash = { version = "0.8", default-features = false }
itertools = "0.11.0"
chrono = { workspace = true }
paste = "1.0.14"
datafusion-common = { workspace = true }
datafusion = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion native/core/benches/cast_from_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

use arrow_array::{builder::StringBuilder, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use comet::execution::datafusion::expressions::{cast::Cast, EvalMode};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion_comet_spark_expr::{Cast, EvalMode};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
use std::sync::Arc;

Expand Down
2 changes: 1 addition & 1 deletion native/core/benches/cast_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

use arrow_array::{builder::Int32Builder, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use comet::execution::datafusion::expressions::{cast::Cast, EvalMode};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion_comet_spark_expr::{Cast, EvalMode};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
use std::sync::Arc;

Expand Down
2 changes: 0 additions & 2 deletions native/core/src/execution/datafusion/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
//! Native DataFusion expressions
pub mod bitwise_not;
pub use datafusion_comet_spark_expr::cast;
pub mod checkoverflow;
mod normalize_nan;
pub mod scalar_funcs;
Expand All @@ -37,7 +36,6 @@ pub mod stddev;
pub mod strings;
pub mod subquery;
pub mod sum_decimal;
pub mod temporal;
pub mod unbound;
mod utils;
pub mod variance;
Expand Down
2 changes: 1 addition & 1 deletion native/core/src/execution/datafusion/expressions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
// under the License.

// re-export for legacy reasons
pub use datafusion_comet_spark_expr::utils::{array_with_timezone, down_cast_any_ref};
pub use datafusion_comet_spark_expr::utils::down_cast_any_ref;
6 changes: 3 additions & 3 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ use crate::{
avg_decimal::AvgDecimal,
bitwise_not::BitwiseNotExpr,
bloom_filter_might_contain::BloomFilterMightContain,
cast::Cast,
checkoverflow::CheckOverflow,
correlation::Correlation,
covariance::Covariance,
Expand All @@ -86,7 +85,6 @@ use crate::{
strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExec, SubstringExec},
subquery::Subquery,
sum_decimal::SumDecimal,
temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec},
unbound::UnboundColumn,
variance::Variance,
NormalizeNaNAndZero,
Expand All @@ -107,7 +105,9 @@ use crate::{
};

use super::expressions::{create_named_struct::CreateNamedStruct, EvalMode};
use datafusion_comet_spark_expr::{Abs, IfExpr};
use datafusion_comet_spark_expr::{
Abs, Cast, DateTruncExec, HourExec, IfExpr, MinuteExec, SecondExec, TimestampTruncExec,
};

// For clippy error on type_complexity.
type ExecResult<T> = Result<T, ExecutionError>;
Expand Down
1 change: 0 additions & 1 deletion native/core/src/execution/kernels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,3 @@ mod hash;
pub use hash::hash;

pub(crate) mod strings;
pub(crate) mod temporal;
20 changes: 20 additions & 0 deletions native/spark-expr/src/kernels/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Kernels used by this crate; currently only the `temporal` module.
pub(crate) mod temporal;
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,18 @@ use arrow_array::{

use arrow_schema::TimeUnit;

use crate::errors::ExpressionError;
use crate::SparkError;

// Copied from arrow_arith/temporal.rs
macro_rules! return_compute_error_with {
($msg:expr, $param:expr) => {
return {
Err(ExpressionError::ArrowError(format!(
"{}: {:?}",
$msg, $param
)))
}
return { Err(SparkError::Internal(format!("{}: {:?}", $msg, $param))) }
};
}

// The number of days between the beginning of the proleptic Gregorian calendar (0001-01-01)
// and the beginning of the Unix epoch (1970-01-01)
const DAYS_TO_UNIX_EPOCH: i32 = 719_163;
const MICROS_TO_UNIX_EPOCH: i64 = 62_167_132_800 * 1_000_000;

// Copied from arrow_arith/temporal.rs with modification to the output datatype
// Transforms an array of NaiveDate to an array of Date32 after applying an operation
Expand Down Expand Up @@ -102,7 +96,7 @@ fn as_timestamp_tz_with_op<A: ArrayAccessor<Item = T::Native>, T: ArrowTemporalT
mut builder: PrimitiveBuilder<TimestampMicrosecondType>,
tz: &str,
op: F,
) -> Result<TimestampMicrosecondArray, ExpressionError>
) -> Result<TimestampMicrosecondArray, SparkError>
where
F: Fn(DateTime<Tz>) -> i64,
i64: From<T::Native>,
Expand All @@ -113,7 +107,7 @@ where
Some(value) => match as_datetime_with_timezone::<T>(value.into(), tz) {
Some(time) => builder.append_value(op(time)),
_ => {
return Err(ExpressionError::ArrowError(
return Err(SparkError::Internal(
"Unable to read value as datetime".to_string(),
));
}
Expand All @@ -129,7 +123,7 @@ fn as_timestamp_tz_with_op_single<T: ArrowTemporalType, F>(
builder: &mut PrimitiveBuilder<TimestampMicrosecondType>,
tz: &Tz,
op: F,
) -> Result<(), ExpressionError>
) -> Result<(), SparkError>
where
F: Fn(DateTime<Tz>) -> i64,
i64: From<T::Native>,
Expand All @@ -138,7 +132,7 @@ where
Some(value) => match as_datetime_with_timezone::<T>(value.into(), *tz) {
Some(time) => builder.append_value(op(time)),
_ => {
return Err(ExpressionError::ArrowError(
return Err(SparkError::Internal(
"Unable to read value as datetime".to_string(),
));
}
Expand Down Expand Up @@ -256,7 +250,7 @@ fn trunc_date_to_microsec<T: Timelike>(dt: T) -> Option<T> {
/// array is an array of Date32 values. The array may be a dictionary array.
///
/// format is a scalar string specifying the format to apply to the date value.
pub fn date_trunc_dyn(array: &dyn Array, format: String) -> Result<ArrayRef, ExpressionError> {
pub(crate) fn date_trunc_dyn(array: &dyn Array, format: String) -> Result<ArrayRef, SparkError> {
match array.data_type().clone() {
DataType::Dictionary(_, _) => {
downcast_dictionary_array!(
Expand All @@ -279,10 +273,10 @@ pub fn date_trunc_dyn(array: &dyn Array, format: String) -> Result<ArrayRef, Exp
}
}

pub fn date_trunc<T>(
pub(crate) fn date_trunc<T>(
array: &PrimitiveArray<T>,
format: String,
) -> Result<Date32Array, ExpressionError>
) -> Result<Date32Array, SparkError>
where
T: ArrowTemporalType + ArrowNumericType,
i64: From<T::Native>,
Expand Down Expand Up @@ -311,7 +305,7 @@ where
builder,
|dt| as_days_from_unix_epoch(trunc_date_to_week(dt)),
)),
_ => Err(ExpressionError::ArrowError(format!(
_ => Err(SparkError::Internal(format!(
"Unsupported format: {:?} for function 'date_trunc'",
format
))),
Expand All @@ -331,10 +325,10 @@ where
///
/// format is an array of strings specifying the format to apply to the corresponding date value.
/// The array may be a dictionary array.
pub fn date_trunc_array_fmt_dyn(
pub(crate) fn date_trunc_array_fmt_dyn(
array: &dyn Array,
formats: &dyn Array,
) -> Result<ArrayRef, ExpressionError> {
) -> Result<ArrayRef, SparkError> {
match (array.data_type().clone(), formats.data_type().clone()) {
(DataType::Dictionary(_, v), DataType::Dictionary(_, f)) => {
if !matches!(*v, DataType::Date32) {
Expand Down Expand Up @@ -403,7 +397,7 @@ pub fn date_trunc_array_fmt_dyn(
.expect("Unexpected value type in formats"),
)
.map(|a| Arc::new(a) as ArrayRef),
(dt, fmt) => Err(ExpressionError::ArrowError(format!(
(dt, fmt) => Err(SparkError::Internal(format!(
"Unsupported datatype: {:}, format: {:?} for function 'date_trunc'",
dt, fmt
))),
Expand Down Expand Up @@ -434,7 +428,7 @@ macro_rules! date_trunc_array_fmt_helper {
"WEEK" => Ok(as_datetime_with_op_single(val, &mut builder, |dt| {
as_days_from_unix_epoch(trunc_date_to_week(dt))
})),
_ => Err(ExpressionError::ArrowError(format!(
_ => Err(SparkError::Internal(format!(
"Unsupported format: {:?} for function 'date_trunc'",
$formats.value(index)
))),
Expand All @@ -454,7 +448,7 @@ macro_rules! date_trunc_array_fmt_helper {
fn date_trunc_array_fmt_plain_plain(
array: &Date32Array,
formats: &StringArray,
) -> Result<Date32Array, ExpressionError>
) -> Result<Date32Array, SparkError>
where
{
let data_type = array.data_type();
Expand All @@ -464,7 +458,7 @@ where
fn date_trunc_array_fmt_plain_dict<K>(
array: &Date32Array,
formats: &TypedDictionaryArray<K, StringArray>,
) -> Result<Date32Array, ExpressionError>
) -> Result<Date32Array, SparkError>
where
K: ArrowDictionaryKeyType,
{
Expand All @@ -475,7 +469,7 @@ where
fn date_trunc_array_fmt_dict_plain<K>(
array: &TypedDictionaryArray<K, Date32Array>,
formats: &StringArray,
) -> Result<Date32Array, ExpressionError>
) -> Result<Date32Array, SparkError>
where
K: ArrowDictionaryKeyType,
{
Expand All @@ -486,7 +480,7 @@ where
fn date_trunc_array_fmt_dict_dict<K, F>(
array: &TypedDictionaryArray<K, Date32Array>,
formats: &TypedDictionaryArray<F, StringArray>,
) -> Result<Date32Array, ExpressionError>
) -> Result<Date32Array, SparkError>
where
K: ArrowDictionaryKeyType,
F: ArrowDictionaryKeyType,
Expand All @@ -503,7 +497,10 @@ where
/// timezone or no timezone. The array may be a dictionary array.
///
/// format is a scalar string specifying the format to apply to the timestamp value.
pub fn timestamp_trunc_dyn(array: &dyn Array, format: String) -> Result<ArrayRef, ExpressionError> {
pub(crate) fn timestamp_trunc_dyn(
array: &dyn Array,
format: String,
) -> Result<ArrayRef, SparkError> {
match array.data_type().clone() {
DataType::Dictionary(_, _) => {
downcast_dictionary_array!(
Expand All @@ -526,10 +523,10 @@ pub fn timestamp_trunc_dyn(array: &dyn Array, format: String) -> Result<ArrayRef
}
}

pub fn timestamp_trunc<T>(
pub(crate) fn timestamp_trunc<T>(
array: &PrimitiveArray<T>,
format: String,
) -> Result<TimestampMicrosecondArray, ExpressionError>
) -> Result<TimestampMicrosecondArray, SparkError>
where
T: ArrowTemporalType + ArrowNumericType,
i64: From<T::Native>,
Expand Down Expand Up @@ -589,7 +586,7 @@ where
as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt))
})
}
_ => Err(ExpressionError::ArrowError(format!(
_ => Err(SparkError::Internal(format!(
"Unsupported format: {:?} for function 'timestamp_trunc'",
format
))),
Expand All @@ -611,10 +608,10 @@ where
///
/// format is an array of strings specifying the format to apply to the corresponding timestamp
/// value. The array may be a dictionary array.
pub fn timestamp_trunc_array_fmt_dyn(
pub(crate) fn timestamp_trunc_array_fmt_dyn(
array: &dyn Array,
formats: &dyn Array,
) -> Result<ArrayRef, ExpressionError> {
) -> Result<ArrayRef, SparkError> {
match (array.data_type().clone(), formats.data_type().clone()) {
(DataType::Dictionary(_, _), DataType::Dictionary(_, _)) => {
downcast_dictionary_array!(
Expand Down Expand Up @@ -669,7 +666,7 @@ pub fn timestamp_trunc_array_fmt_dyn(
dt => return_compute_error_with!("timestamp_trunc does not support", dt),
)
}
(dt, fmt) => Err(ExpressionError::ArrowError(format!(
(dt, fmt) => Err(SparkError::Internal(format!(
"Unsupported datatype: {:}, format: {:?} for function 'timestamp_trunc'",
dt, fmt
))),
Expand Down Expand Up @@ -740,7 +737,7 @@ macro_rules! timestamp_trunc_array_fmt_helper {
as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt))
})
}
_ => Err(ExpressionError::ArrowError(format!(
_ => Err(SparkError::Internal(format!(
"Unsupported format: {:?} for function 'timestamp_trunc'",
$formats.value(index)
))),
Expand All @@ -762,7 +759,7 @@ macro_rules! timestamp_trunc_array_fmt_helper {
fn timestamp_trunc_array_fmt_plain_plain<T>(
array: &PrimitiveArray<T>,
formats: &StringArray,
) -> Result<TimestampMicrosecondArray, ExpressionError>
) -> Result<TimestampMicrosecondArray, SparkError>
where
T: ArrowTemporalType + ArrowNumericType,
i64: From<T::Native>,
Expand All @@ -773,7 +770,7 @@ where
fn timestamp_trunc_array_fmt_plain_dict<T, K>(
array: &PrimitiveArray<T>,
formats: &TypedDictionaryArray<K, StringArray>,
) -> Result<TimestampMicrosecondArray, ExpressionError>
) -> Result<TimestampMicrosecondArray, SparkError>
where
T: ArrowTemporalType + ArrowNumericType,
i64: From<T::Native>,
Expand All @@ -786,7 +783,7 @@ where
fn timestamp_trunc_array_fmt_dict_plain<T, K>(
array: &TypedDictionaryArray<K, PrimitiveArray<T>>,
formats: &StringArray,
) -> Result<TimestampMicrosecondArray, ExpressionError>
) -> Result<TimestampMicrosecondArray, SparkError>
where
T: ArrowTemporalType + ArrowNumericType,
i64: From<T::Native>,
Expand All @@ -799,7 +796,7 @@ where
fn timestamp_trunc_array_fmt_dict_dict<T, K, F>(
array: &TypedDictionaryArray<K, PrimitiveArray<T>>,
formats: &TypedDictionaryArray<F, StringArray>,
) -> Result<TimestampMicrosecondArray, ExpressionError>
) -> Result<TimestampMicrosecondArray, SparkError>
where
T: ArrowTemporalType + ArrowNumericType,
i64: From<T::Native>,
Expand All @@ -812,7 +809,7 @@ where

#[cfg(test)]
mod tests {
use crate::execution::kernels::temporal::{
use crate::kernels::temporal::{
date_trunc, date_trunc_array_fmt_dyn, timestamp_trunc, timestamp_trunc_array_fmt_dyn,
};
use arrow_array::{
Expand Down
Loading

0 comments on commit ab1d30a

Please sign in to comment.