diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 83f86dbee..33c4924cb 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -100,7 +100,8 @@ use datafusion_comet_proto::{
 };
 use datafusion_comet_spark_expr::{
     ArrayInsert, Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField,
-    HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
+    HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr, SparkCastOptions,
+    TimestampTruncExpr, ToJson,
 };
 use datafusion_common::scalar::ScalarStructBuilder;
 use datafusion_common::{
@@ -388,14 +389,11 @@ impl PhysicalPlanner {
             ExprStruct::Cast(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
-                let timezone = expr.timezone.clone();
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 Ok(Arc::new(Cast::new(
                     child,
                     datatype,
-                    eval_mode,
-                    timezone,
-                    expr.allow_incompat,
+                    SparkCastOptions::new(eval_mode, &expr.timezone, expr.allow_incompat),
                 )))
             }
             ExprStruct::Hour(expr) => {
@@ -806,24 +804,21 @@ impl PhysicalPlanner {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
                 // For some Decimal128 operations, we need wider internal digits.
                 // Cast left and right to Decimal256 and cast the result back to Decimal128
-                let left = Arc::new(Cast::new_without_timezone(
+                let left = Arc::new(Cast::new(
                     left,
                     DataType::Decimal256(p1, s1),
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 ));
-                let right = Arc::new(Cast::new_without_timezone(
+                let right = Arc::new(Cast::new(
                     right,
                     DataType::Decimal256(p2, s2),
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 ));
                 let child = Arc::new(BinaryExpr::new(left, op, right));
-                Ok(Arc::new(Cast::new_without_timezone(
+                Ok(Arc::new(Cast::new(
                     child,
                     data_type,
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 )))
             }
             (
diff --git a/native/spark-expr/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs
index 056ada2eb..c6b0bcf39 100644
--- a/native/spark-expr/benches/cast_from_string.rs
+++ b/native/spark-expr/benches/cast_from_string.rs
@@ -18,36 +18,18 @@
 use arrow_array::{builder::StringBuilder, RecordBatch};
 use arrow_schema::{DataType, Field, Schema};
 use criterion::{criterion_group, criterion_main, Criterion};
-use datafusion_comet_spark_expr::{Cast, EvalMode};
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
 use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
 use std::sync::Arc;

 fn criterion_benchmark(c: &mut Criterion) {
     let batch = create_utf8_batch();
     let expr = Arc::new(Column::new("a", 0));
-    let timezone = "".to_string();
-    let cast_string_to_i8 = Cast::new(
-        expr.clone(),
-        DataType::Int8,
-        EvalMode::Legacy,
-        timezone.clone(),
-        false,
-    );
-    let cast_string_to_i16 = Cast::new(
-        expr.clone(),
-        DataType::Int16,
-        EvalMode::Legacy,
-        timezone.clone(),
-        false,
-    );
-    let cast_string_to_i32 = Cast::new(
-        expr.clone(),
-        DataType::Int32,
-        EvalMode::Legacy,
-        timezone.clone(),
-        false,
-    );
-    let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone, false);
+    let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false);
+    let cast_string_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
+    let cast_string_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
+    let cast_string_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
+    let cast_string_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options);

     let mut group = c.benchmark_group("cast_string_to_int");
     group.bench_function("cast_string_to_i8", |b| {
diff --git a/native/spark-expr/benches/cast_numeric.rs b/native/spark-expr/benches/cast_numeric.rs
index 15ef1a5a2..8ec8b2f89 100644
--- a/native/spark-expr/benches/cast_numeric.rs
+++ b/native/spark-expr/benches/cast_numeric.rs
@@ -18,29 +18,17 @@
 use arrow_array::{builder::Int32Builder, RecordBatch};
 use arrow_schema::{DataType, Field, Schema};
 use criterion::{criterion_group, criterion_main, Criterion};
-use datafusion_comet_spark_expr::{Cast, EvalMode};
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
 use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
 use std::sync::Arc;

 fn criterion_benchmark(c: &mut Criterion) {
     let batch = create_int32_batch();
     let expr = Arc::new(Column::new("a", 0));
-    let timezone = "".to_string();
-    let cast_i32_to_i8 = Cast::new(
-        expr.clone(),
-        DataType::Int8,
-        EvalMode::Legacy,
-        timezone.clone(),
-        false,
-    );
-    let cast_i32_to_i16 = Cast::new(
-        expr.clone(),
-        DataType::Int16,
-        EvalMode::Legacy,
-        timezone.clone(),
-        false,
-    );
-    let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone, false);
+    let spark_cast_options = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
+    let cast_i32_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
+    let cast_i32_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
+    let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options);

     let mut group = c.benchmark_group("cast_int_to_int");
     group.bench_function("cast_i32_to_i8", |b| {
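The two benchmark diffs above converge on the same call-site pattern: construct one `SparkCastOptions` value and clone it per cast, instead of threading `eval_mode`, `timezone`, and `allow_incompat` through every `Cast::new` call. A minimal sketch of that pattern (not part of the patch; imports and values mirror the benchmark files):

```rust
use std::sync::Arc;

use arrow_schema::DataType;
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
use datafusion_physical_expr::expressions::Column;

fn build_casts() {
    let expr = Arc::new(Column::new("a", 0));

    // One options value is created up front and cheaply cloned per cast,
    // replacing the old per-call (eval_mode, timezone, allow_incompat) triple.
    let options = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);

    let _cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, options.clone());
    let _cast_to_i64 = Cast::new(expr, DataType::Int64, options);
}
```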
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 13263a595..f62d0220c 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -138,14 +138,7 @@ impl TimeStampInfo {
 pub struct Cast {
     pub child: Arc<dyn PhysicalExpr>,
     pub data_type: DataType,
-    pub eval_mode: EvalMode,
-
-    /// When cast from/to timezone related types, we need timezone, which will be resolved with
-    /// session local timezone by an analyzer in Spark.
-    pub timezone: String,
-
-    /// Whether to allow casts that are known to be incompatible with Spark
-    pub allow_incompat: bool,
+    pub cast_options: SparkCastOptions,
 }

 macro_rules! cast_utf8_to_int {
@@ -547,30 +540,41 @@ impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
         data_type: DataType,
-        eval_mode: EvalMode,
-        timezone: String,
-        allow_incompat: bool,
+        cast_options: SparkCastOptions,
     ) -> Self {
         Self {
             child,
             data_type,
-            timezone,
+            cast_options,
+        }
+    }
+}
+
+/// Spark cast options
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+pub struct SparkCastOptions {
+    /// Spark evaluation mode
+    pub eval_mode: EvalMode,
+    /// When cast from/to timezone related types, we need timezone, which will be resolved with
+    /// session local timezone by an analyzer in Spark.
+    pub timezone: String,
+    /// Allow casts that are supported but not guaranteed to be 100% compatible
+    pub allow_incompat: bool,
+}
+
+impl SparkCastOptions {
+    pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self {
+        Self {
             eval_mode,
+            timezone: timezone.to_string(),
             allow_incompat,
         }
     }

-    pub fn new_without_timezone(
-        child: Arc<dyn PhysicalExpr>,
-        data_type: DataType,
-        eval_mode: EvalMode,
-        allow_incompat: bool,
-    ) -> Self {
+    pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self {
         Self {
-            child,
-            data_type,
-            timezone: "".to_string(),
             eval_mode,
+            timezone: "".to_string(),
             allow_incompat,
         }
     }
@@ -582,33 +586,21 @@ pub fn spark_cast(
     arg: ColumnarValue,
     data_type: &DataType,
-    eval_mode: EvalMode,
-    timezone: &str,
-    allow_incompat: bool,
+    cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ColumnarValue> {
     match arg {
         ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
             array,
             data_type,
-            eval_mode,
-            timezone.to_owned(),
-            allow_incompat,
+            cast_options,
         )?)),
         ColumnarValue::Scalar(scalar) => {
             // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for
             // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it
             // here.
             let array = scalar.to_array()?;
-            let scalar = ScalarValue::try_from_array(
-                &cast_array(
-                    array,
-                    data_type,
-                    eval_mode,
-                    timezone.to_owned(),
-                    allow_incompat,
-                )?,
-                0,
-            )?;
+            let scalar =
+                ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?;
             Ok(ColumnarValue::Scalar(scalar))
         }
     }
 }
@@ -617,12 +609,11 @@ fn cast_array(
     array: ArrayRef,
     to_type: &DataType,
-    eval_mode: EvalMode,
-    timezone: String,
-    allow_incompat: bool,
+    cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ArrayRef> {
-    let array = array_with_timezone(array, timezone.clone(), Some(to_type))?;
+    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
     let from_type = array.data_type().clone();
+
     let array = match &from_type {
         DataType::Dictionary(key_type, value_type)
             if key_type.as_ref() == &DataType::Int32
@@ -636,13 +627,7 @@ fn cast_array(
             let casted_dictionary = DictionaryArray::<Int32Type>::new(
                 dict_array.keys().clone(),
-                cast_array(
-                    Arc::clone(dict_array.values()),
-                    to_type,
-                    eval_mode,
-                    timezone,
-                    allow_incompat,
-                )?,
+                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
             );

             let casted_result = match to_type {
@@ -654,6 +639,7 @@ fn cast_array(
         _ => array,
     };
     let from_type = array.data_type();
+    let eval_mode = cast_options.eval_mode;
     let cast_result = match (from_type, to_type) {
         (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
@@ -661,7 +647,7 @@ fn cast_array(
             spark_cast_utf8_to_boolean::<i64>(&array, eval_mode)
         }
         (DataType::Utf8, DataType::Timestamp(_, _)) => {
-            cast_string_to_timestamp(&array, to_type, eval_mode, &timezone)
+            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
         (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode),
         (DataType::Int64, DataType::Int32)
@@ -712,17 +698,15 @@ fn cast_array(
             spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
         }
         (DataType::Struct(_), DataType::Utf8) => {
-            Ok(casts_struct_to_string(array.as_struct(), &timezone)?)
+            Ok(casts_struct_to_string(array.as_struct(), cast_options)?)
         }
         (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
             array.as_struct(),
             from_type,
             to_type,
-            eval_mode,
-            timezone,
-            allow_incompat,
+            cast_options,
         )?),
-        _ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => {
+        _ if is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) => {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
         }
@@ -825,9 +809,7 @@ fn cast_struct_to_struct(
     array: &StructArray,
     from_type: &DataType,
     to_type: &DataType,
-    eval_mode: EvalMode,
-    timezone: String,
-    allow_incompat: bool,
+    cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ArrayRef> {
     match (from_type, to_type) {
         (DataType::Struct(_), DataType::Struct(to_fields)) => {
@@ -836,9 +818,7 @@ fn cast_struct_to_struct(
             let cast_field = cast_array(
                 Arc::clone(array.column(i)),
                 to_fields[i].data_type(),
-                eval_mode,
-                timezone.clone(),
-                allow_incompat,
+                cast_options,
             )?;
             cast_fields.push((Arc::clone(&to_fields[i]), cast_field));
         }
@@ -848,7 +828,10 @@
-fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResult<ArrayRef> {
+fn casts_struct_to_string(
+    array: &StructArray,
+    spark_cast_options: &SparkCastOptions,
+) -> DataFusionResult<ArrayRef> {
     // cast each field to a string
     let string_arrays: Vec<ArrayRef> = array
         .columns()
         .iter()
@@ -857,9 +840,7 @@ fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResult<ArrayRef> {
             spark_cast(
                 ColumnarValue::Array(Arc::clone(arr)),
                 &DataType::Utf8,
-                EvalMode::Legacy,
-                timezone,
-                true,
+                spark_cast_options,
             )
             .and_then(|cv| cv.into_array(arr.len()))
         })
@@ -1464,7 +1445,7 @@ impl Display for Cast {
         write!(
             f,
             "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
-            self.data_type, self.timezone, self.child, &self.eval_mode
+            self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode
         )
     }
 }
@@ -1475,9 +1456,8 @@ impl PartialEq<dyn Any> for Cast {
             .downcast_ref::<Self>()
             .map(|x| {
                 self.child.eq(&x.child)
-                    && self.timezone.eq(&x.timezone)
+                    && self.cast_options.eq(&x.cast_options)
                     && self.data_type.eq(&x.data_type)
-                    && self.eval_mode.eq(&x.eval_mode)
             })
             .unwrap_or(false)
     }
@@ -1498,13 +1478,7 @@ impl PhysicalExpr for Cast {
     fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
         let arg = self.child.evaluate(batch)?;
-        spark_cast(
-            arg,
-            &self.data_type,
-            self.eval_mode,
-            &self.timezone,
-            self.allow_incompat,
-        )
+        spark_cast(arg, &self.data_type, &self.cast_options)
     }

     fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
@@ -1519,9 +1493,7 @@
             1 => Ok(Arc::new(Cast::new(
                 Arc::clone(&children[0]),
                 self.data_type.clone(),
-                self.eval_mode,
-                self.timezone.clone(),
-                self.allow_incompat,
+                self.cast_options.clone(),
             ))),
             _ => internal_err!("Cast should have exactly one child"),
         }
@@ -1531,9 +1503,7 @@
         let mut s = state;
         self.child.hash(&mut s);
         self.data_type.hash(&mut s);
-        self.timezone.hash(&mut s);
-        self.eval_mode.hash(&mut s);
-        self.allow_incompat.hash(&mut s);
+        self.cast_options.hash(&mut s);
         self.hash(&mut s);
     }
 }
@@ -2110,12 +2080,11 @@ mod tests {
         let timezone = "UTC".to_string();
         // test casting string dictionary array to timestamp array
+        let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false);
         let result = cast_array(
             dict_array,
             &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
-            EvalMode::Legacy,
-            timezone.clone(),
-            false,
+            &cast_options,
         )?;
         assert_eq!(
             *result.data_type(),
@@ -2320,12 +2289,11 @@ mod tests {
     fn test_cast_unsupported_timestamp_to_date() {
         // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported
         let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
+        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
         let result = cast_array(
             Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
             &DataType::Date32,
-            EvalMode::Legacy,
-            "UTC".to_owned(),
-            false,
+            &cast_options,
         );
         assert!(result.is_err())
     }
@@ -2333,12 +2301,11 @@
     #[test]
     fn test_cast_invalid_timezone() {
         let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
+        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false);
         let result = cast_array(
             Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
             &DataType::Date32,
-            EvalMode::Legacy,
-            "Not a valid timezone".to_owned(),
-            false,
+            &cast_options,
         );
         assert!(result.is_err())
     }
@@ -2360,9 +2327,7 @@ mod tests {
         let string_array = cast_array(
             c,
             &DataType::Utf8,
-            EvalMode::Legacy,
-            "UTC".to_owned(),
-            false,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
         )
         .unwrap();
         let string_array = string_array.as_string::<i32>();
@@ -2396,9 +2361,7 @@ mod tests {
         let cast_array = spark_cast(
             ColumnarValue::Array(c),
             &DataType::Struct(fields),
-            EvalMode::Legacy,
-            "UTC",
-            false,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
         )
         .unwrap();
         if let ColumnarValue::Array(cast_array) = cast_array {
@@ -2429,9 +2392,7 @@ mod tests {
         let cast_array = spark_cast(
             ColumnarValue::Array(c),
             &DataType::Struct(fields),
-            EvalMode::Legacy,
-            "UTC",
-            false,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
         )
         .unwrap();
         if let ColumnarValue::Array(cast_array) = cast_array {
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 3ec2e886b..c227b3a02 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -34,7 +34,7 @@ pub mod timezone;
 mod to_json;
 pub mod utils;

-pub use cast::{spark_cast, Cast};
+pub use cast::{spark_cast, Cast, SparkCastOptions};
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
 pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
@@ -47,7 +47,7 @@ pub use to_json::ToJson;
 /// the behavior when processing input values that are invalid or would result in an
 /// error, such as divide by zero errors, and also affects behavior when converting
 /// between types.
-#[derive(Debug, Hash, PartialEq, Clone, Copy)]
+#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
 pub enum EvalMode {
     /// Legacy is the default behavior in Spark prior to Spark 4.0. This mode silently ignores
     /// or replaces errors during SQL operations. Operations resulting in errors (like
diff --git a/native/spark-expr/src/to_json.rs b/native/spark-expr/src/to_json.rs
index 7d38cbf1b..1f68eb860 100644
--- a/native/spark-expr/src/to_json.rs
+++ b/native/spark-expr/src/to_json.rs
@@ -19,6 +19,7 @@
 // of the Spark-specific compatibility features that we need (including
 // being able to specify Spark-compatible cast from all types to string)

+use crate::cast::SparkCastOptions;
 use crate::{spark_cast, EvalMode};
 use arrow_array::builder::StringBuilder;
 use arrow_array::{Array, ArrayRef, RecordBatch, StringArray, StructArray};
@@ -117,9 +118,7 @@ fn array_to_json_string(arr: &Arc<dyn Array>, timezone: &str) -> Result
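Taken together, the public surface after this change is `SparkCastOptions` (with its two constructors) threaded through `Cast::new` and `spark_cast`; the `Eq` added to `EvalMode` in lib.rs is what allows `SparkCastOptions` to derive `Eq` and `Hash`, so `Cast` can compare and hash a single field. A short usage sketch of the new entry point (illustrative values; only the constructors and signatures shown in the hunks above are assumed):

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array};
use arrow_schema::DataType;
use datafusion_comet_spark_expr::{spark_cast, EvalMode, SparkCastOptions};
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;

fn demo() -> datafusion_common::Result<()> {
    // The timezone is only consulted for timezone-sensitive casts; "UTC" here is illustrative.
    let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);

    // Array input: the common path through cast_array.
    let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let _cast = spark_cast(ColumnarValue::Array(arr), &DataType::Int64, &options)?;

    // Scalar input is converted to an array and back, per the comment in spark_cast
    // about scalar subqueries that Spark does not fold.
    let scalar = ColumnarValue::Scalar(ScalarValue::Int32(Some(7)));
    let _cast = spark_cast(scalar, &DataType::Int64, &options)?;

    Ok(())
}
```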