diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 40fe4515c..bbe452cf8 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -19,7 +19,7 @@
 use super::expressions::EvalMode;
 use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun;
-use crate::execution::operators::{CopyMode, FilterExec};
+use crate::execution::operators::{CopyMode, FilterExec as CometFilterExec};
 use crate::{
     errors::ExpressionError,
     execution::{
@@ -55,6 +55,7 @@ use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf,
 use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::min_max::min_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
+use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::windows::BoundedWindowAggExec;
 use datafusion::physical_plan::InputOrderMode;
 use datafusion::{
@@ -102,7 +103,7 @@ use datafusion_comet_proto::{
 };
 use datafusion_comet_spark_expr::{
     Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr,
-    ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
+    ListExtract, MinuteExpr, RLike, SecondExpr, SparkCastOptions, TimestampTruncExpr, ToJson,
 };
 use datafusion_common::config::TableParquetOptions;
 use datafusion_common::scalar::ScalarStructBuilder;
@@ -392,14 +393,11 @@ impl PhysicalPlanner {
             ExprStruct::Cast(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
-                let timezone = expr.timezone.clone();
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 Ok(Arc::new(Cast::new(
                     child,
                     datatype,
-                    eval_mode,
-                    timezone,
-                    expr.allow_incompat,
+                    SparkCastOptions::new(eval_mode, &expr.timezone, expr.allow_incompat),
                 )))
             }
             ExprStruct::Hour(expr) => {
@@ -767,24 +765,21 @@ impl PhysicalPlanner {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
                 // For some Decimal128 operations, we need wider internal digits.
                 // Cast left and right to Decimal256 and cast the result back to Decimal128
-                let left = Arc::new(Cast::new_without_timezone(
+                let left = Arc::new(Cast::new(
                     left,
                     DataType::Decimal256(p1, s1),
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 ));
-                let right = Arc::new(Cast::new_without_timezone(
+                let right = Arc::new(Cast::new(
                     right,
                     DataType::Decimal256(p2, s2),
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 ));
                 let child = Arc::new(BinaryExpr::new(left, op, right));
-                Ok(Arc::new(Cast::new_without_timezone(
+                Ok(Arc::new(Cast::new(
                     child,
                     data_type,
-                    EvalMode::Legacy,
-                    false,
+                    SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
                 )))
             }
             (
@@ -851,7 +846,11 @@ impl PhysicalPlanner {
                 let predicate =
                     self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;
 
-                Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?)))
+                if can_reuse_input_batch(&child) {
+                    Ok((scans, Arc::new(CometFilterExec::try_new(predicate, child)?)))
+                } else {
+                    Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?)))
+                }
             }
             OpStruct::HashAgg(agg) => {
                 assert!(children.len() == 1);
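Reviewer sketch (illustrative only, not part of the patch): the planner now funnels every Spark cast through the new SparkCastOptions bundle instead of passing eval mode, timezone, and the compatibility flag positionally, and it picks between Comet's FilterExec (reported as CometFilterExec) and DataFusion's FilterExec based on can_reuse_input_batch. A minimal sketch of the new cast construction pattern, using only the types this patch exports from datafusion_comet_spark_expr:

    use std::sync::Arc;
    use arrow_schema::DataType;
    use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
    use datafusion_physical_expr::PhysicalExpr;

    // Mirrors what PhysicalPlanner::create_expr now does for ExprStruct::Cast:
    // all cast behaviour travels in a single SparkCastOptions value.
    fn plan_cast(
        child: Arc<dyn PhysicalExpr>,
        to_type: DataType,
        session_tz: &str,
    ) -> Arc<dyn PhysicalExpr> {
        let options = SparkCastOptions::new(EvalMode::Legacy, session_tz, false);
        Arc::new(Cast::new(child, to_type, options))
    }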
diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs
index 664f92d4c..ce858f65b 100644
--- a/native/core/src/execution/datafusion/schema_adapter.rs
+++ b/native/core/src/execution/datafusion/schema_adapter.rs
@@ -19,9 +19,9 @@
 use arrow::compute::can_cast_types;
 use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
-use arrow_schema::{DataType, Schema, SchemaRef};
+use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit};
 use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
-use datafusion_comet_spark_expr::{spark_cast, EvalMode};
+use datafusion_comet_spark_expr::{spark_cast, EvalMode, SparkCastOptions};
 use datafusion_common::plan_err;
 use datafusion_expr::ColumnarValue;
 use std::sync::Arc;
@@ -38,11 +38,11 @@ impl SchemaAdapterFactory for CometSchemaAdapterFactory {
     /// schema.
     fn create(
         &self,
-        projected_table_schema: SchemaRef,
+        required_schema: SchemaRef,
         table_schema: SchemaRef,
     ) -> Box<dyn SchemaAdapter> {
         Box::new(CometSchemaAdapter {
-            projected_table_schema,
+            required_schema,
             table_schema,
         })
     }
 }
@@ -54,7 +54,7 @@ impl SchemaAdapterFactory for CometSchemaAdapterFactory {
 pub struct CometSchemaAdapter {
     /// The schema for the table, projected to include only the fields being output (projected) by the
     /// associated ParquetExec
-    projected_table_schema: SchemaRef,
+    required_schema: SchemaRef,
     /// The entire table schema for the table we're using this to adapt.
     ///
     /// This is used to evaluate any filters pushed down into the scan
@@ -69,7 +69,7 @@ impl SchemaAdapter for CometSchemaAdapter {
     ///
     /// Panics if index is not in range for the table schema
     fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
-        let field = self.projected_table_schema.field(index);
+        let field = self.required_schema.field(index);
         Some(file_schema.fields.find(field.name())?.0)
     }
 
@@ -87,42 +87,34 @@ impl SchemaAdapter for CometSchemaAdapter {
         file_schema: &Schema,
     ) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
         let mut projection = Vec::with_capacity(file_schema.fields().len());
-        let mut field_mappings = vec![None; self.projected_table_schema.fields().len()];
+        let mut field_mappings = vec![None; self.required_schema.fields().len()];
 
         for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
             if let Some((table_idx, table_field)) =
-                self.projected_table_schema.fields().find(file_field.name())
+                self.required_schema.fields().find(file_field.name())
             {
-                // workaround for struct casting
-                match (file_field.data_type(), table_field.data_type()) {
-                    // TODO need to use Comet cast logic to determine which casts are supported,
-                    // but for now just add a hack to support casting between struct types
-                    (DataType::Struct(_), DataType::Struct(_)) => {
-                        field_mappings[table_idx] = Some(projection.len());
-                        projection.push(file_idx);
-                    }
-                    _ => {
-                        if can_cast_types(file_field.data_type(), table_field.data_type()) {
-                            field_mappings[table_idx] = Some(projection.len());
-                            projection.push(file_idx);
-                        } else {
-                            return plan_err!(
-                                "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}",
-                                file_field.name(),
-                                file_field.data_type(),
-                                table_field.data_type()
-                            );
-                        }
-                    }
+                if comet_can_cast_types(file_field.data_type(), table_field.data_type()) {
+                    field_mappings[table_idx] = Some(projection.len());
+                    projection.push(file_idx);
+                } else {
+                    return plan_err!(
+                        "Cannot cast file schema field {} of type {:?} to required schema field of type {:?}",
+                        file_field.name(),
+                        file_field.data_type(),
+                        table_field.data_type()
+                    );
                 }
             }
         }
 
+        let mut cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        cast_options.is_adapting_schema = true;
         Ok((
             Arc::new(SchemaMapping {
-                projected_table_schema: Arc::<Schema>::clone(&self.projected_table_schema),
+                required_schema: Arc::<Schema>::clone(&self.required_schema),
                 field_mappings,
                 table_schema: Arc::<Schema>::clone(&self.table_schema),
+                cast_options
             }),
             projection,
         ))
@@ -161,7 +153,7 @@ impl SchemaAdapter for CometSchemaAdapter {
 pub struct SchemaMapping {
     /// The schema of the table. This is the expected schema after conversion
     /// and it should match the schema of the query result.
-    projected_table_schema: SchemaRef,
+    required_schema: SchemaRef,
     /// Mapping from field index in `projected_table_schema` to index in
     /// projected file_schema.
     ///
@@ -173,6 +165,8 @@ pub struct SchemaMapping {
     /// This contains all fields in the table, regardless of if they will be
     /// projected out or not.
     table_schema: SchemaRef,
+
+    cast_options: SparkCastOptions,
 }
 
 impl SchemaMapper for SchemaMapping {
@@ -185,7 +179,7 @@ impl SchemaMapper for SchemaMapping {
         let batch_cols = batch.columns().to_vec();
 
         let cols = self
-            .projected_table_schema
+            .required_schema
             // go through each field in the projected schema
             .fields()
             .iter()
@@ -204,10 +198,7 @@ impl SchemaMapper for SchemaMapping {
                     spark_cast(
                         ColumnarValue::Array(Arc::clone(&batch_cols[batch_idx])),
                         field.data_type(),
-                        // TODO need to pass in configs here
-                        EvalMode::Legacy,
-                        "UTC",
-                        false,
+                        &self.cast_options,
                     )?
                     .into_array(batch_rows)
                 },
@@ -218,7 +209,7 @@ impl SchemaMapper for SchemaMapping {
         // Necessary to handle empty batches
         let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
 
-        let schema = Arc::<Schema>::clone(&self.projected_table_schema);
+        let schema = Arc::<Schema>::clone(&self.required_schema);
         let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
         Ok(record_batch)
     }
@@ -255,10 +246,7 @@ impl SchemaMapper for SchemaMapping {
                     spark_cast(
                         ColumnarValue::Array(Arc::clone(batch_col)),
                         table_field.data_type(),
-                        // TODO need to pass in configs here
-                        EvalMode::Legacy,
-                        "UTC",
-                        false,
+                        &self.cast_options,
                     )?
                     .into_array(batch_col.len())
                     // and if that works, return the field and column.
@@ -277,3 +265,16 @@ impl SchemaMapper for SchemaMapping {
         Ok(record_batch)
     }
 }
+
+fn comet_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
+    // TODO this is just a quick hack to get tests passing
+    match (from_type, to_type) {
+        (DataType::Struct(_), DataType::Struct(_)) => {
+            // workaround for struct casting
+            true
+        }
+        // TODO this is maybe no longer needed
+        (_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false,
+        _ => can_cast_types(from_type, to_type),
+    }
+}
diff --git a/native/core/src/execution/operators/filter.rs b/native/core/src/execution/operators/filter.rs
index d9a54712d..18a094602 100644
--- a/native/core/src/execution/operators/filter.rs
+++ b/native/core/src/execution/operators/filter.rs
@@ -227,7 +227,7 @@ impl DisplayAs for FilterExec {
 
 impl ExecutionPlan for FilterExec {
     fn name(&self) -> &'static str {
-        "FilterExec"
+        "CometFilterExec"
    }
 
     /// Return a reference to Any that can be used for downcasting
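Editor's sketch (assumed usage, not in the patch): the schema adapter now builds one SparkCastOptions with is_adapting_schema enabled and reuses it for every column it maps, rather than hard-coding EvalMode::Legacy / "UTC" / false at each spark_cast call site. Adapting a single file column to the required schema type looks roughly like this:

    use std::sync::Arc;
    use arrow_array::ArrayRef;
    use arrow_schema::DataType;
    use datafusion_comet_spark_expr::{spark_cast, EvalMode, SparkCastOptions};
    use datafusion_expr::ColumnarValue;

    // Same options that SchemaMapping builds once in map_schema and then
    // applies to every mapped column.
    fn adapt_column(col: ArrayRef, required_type: &DataType) -> datafusion_common::Result<ArrayRef> {
        let mut cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        cast_options.is_adapting_schema = true;
        spark_cast(ColumnarValue::Array(Arc::clone(&col)), required_type, &cast_options)?
            .into_array(col.len())
    }

The is_adapting_schema flag matters because, as the cast.rs changes below show, it lets cast_array fall through to DataFusion's cast_with_options even for combinations that is_datafusion_spark_compatible would reject.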
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 8ef9ca291..17ab73b72 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -35,11 +35,18 @@ use arrow::{
 use arrow_array::builder::StringBuilder;
 use arrow_array::{DictionaryArray, StringArray, StructArray};
 use arrow_schema::{DataType, Field, Schema};
+use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
+use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
 use datafusion_common::{
     cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
 };
 use datafusion_expr::ColumnarValue;
 use datafusion_physical_expr::PhysicalExpr;
+use num::{
+    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
+    ToPrimitive,
+};
+use regex::Regex;
 use std::str::FromStr;
 use std::{
     any::Any,
@@ -49,14 +56,6 @@ use std::{
     sync::Arc,
 };
 
-use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
-use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
-use num::{
-    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
-    ToPrimitive,
-};
-use regex::Regex;
-
 use crate::timezone;
 use crate::utils::array_with_timezone;
@@ -138,14 +137,7 @@ impl TimeStampInfo {
 pub struct Cast {
     pub child: Arc<dyn PhysicalExpr>,
     pub data_type: DataType,
-    pub eval_mode: EvalMode,
-
-    /// When cast from/to timezone related types, we need timezone, which will be resolved with
-    /// session local timezone by an analyzer in Spark.
-    pub timezone: String,
-
-    /// Whether to allow casts that are known to be incompatible with Spark
-    pub allow_incompat: bool,
+    pub cast_options: SparkCastOptions,
 }
 
 macro_rules! cast_utf8_to_int {
@@ -547,31 +539,48 @@ impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
         data_type: DataType,
-        eval_mode: EvalMode,
-        timezone: String,
-        allow_incompat: bool,
+        cast_options: SparkCastOptions,
     ) -> Self {
         Self {
             child,
             data_type,
-            timezone,
+            cast_options,
+        }
+    }
+}
+
+/// Spark cast options
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+pub struct SparkCastOptions {
+    /// Spark evaluation mode
+    pub eval_mode: EvalMode,
+    /// When cast from/to timezone related types, we need timezone, which will be resolved with
+    /// session local timezone by an analyzer in Spark.
+    // TODO we should change timezone to Tz to avoid repeated parsing
+    pub timezone: String,
+    /// Allow casts that are supported but not guaranteed to be 100% compatible
+    pub allow_incompat: bool,
+    /// We also use the cast logic for adapting Parquet schemas, so this flag is used
+    /// for that use case
+    pub is_adapting_schema: bool,
+}
+
+impl SparkCastOptions {
+    pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self {
+        Self {
             eval_mode,
+            timezone: timezone.to_string(),
             allow_incompat,
+            is_adapting_schema: false,
         }
     }
 
-    pub fn new_without_timezone(
-        child: Arc<dyn PhysicalExpr>,
-        data_type: DataType,
-        eval_mode: EvalMode,
-        allow_incompat: bool,
-    ) -> Self {
+    pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self {
         Self {
-            child,
-            data_type,
-            timezone: "".to_string(),
             eval_mode,
+            timezone: "".to_string(),
             allow_incompat,
+            is_adapting_schema: false,
         }
     }
 }
@@ -582,33 +591,21 @@ impl Cast {
 pub fn spark_cast(
     arg: ColumnarValue,
     data_type: &DataType,
-    eval_mode: EvalMode,
-    timezone: &str,
-    allow_incompat: bool,
+    cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ColumnarValue> {
     match arg {
         ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
             array,
             data_type,
-            eval_mode,
-            timezone.to_owned(),
-            allow_incompat,
+            cast_options,
         )?)),
         ColumnarValue::Scalar(scalar) => {
             // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for
             // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it
             // here.
             let array = scalar.to_array()?;
-            let scalar = ScalarValue::try_from_array(
-                &cast_array(
-                    array,
-                    data_type,
-                    eval_mode,
-                    timezone.to_owned(),
-                    allow_incompat,
-                )?,
-                0,
-            )?;
+            let scalar =
+                ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?;
             Ok(ColumnarValue::Scalar(scalar))
         }
     }
 }
@@ -617,11 +614,9 @@ pub fn spark_cast(
 fn cast_array(
     array: ArrayRef,
     to_type: &DataType,
-    eval_mode: EvalMode,
-    timezone: String,
-    allow_incompat: bool,
+    cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ArrayRef> {
-    let array = array_with_timezone(array, timezone.clone(), Some(to_type))?;
+    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
     let from_type = array.data_type().clone();
 
     let array = match &from_type {
@@ -637,13 +632,7 @@ fn cast_array(
             let casted_dictionary = DictionaryArray::<Int32Type>::new(
                 dict_array.keys().clone(),
-                cast_array(
-                    Arc::clone(dict_array.values()),
-                    to_type,
-                    eval_mode,
-                    timezone,
-                    allow_incompat,
-                )?,
+                cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?,
             );
 
             let casted_result = match to_type {
@@ -655,6 +644,7 @@ fn cast_array(
         _ => array,
     };
     let from_type = array.data_type();
+    let eval_mode = cast_options.eval_mode;
 
     let cast_result = match (from_type, to_type) {
         (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
@@ -662,7 +652,7 @@ fn cast_array(
             spark_cast_utf8_to_boolean::<i64>(&array, eval_mode)
         }
         (DataType::Utf8, DataType::Timestamp(_, _)) => {
-            cast_string_to_timestamp(&array, to_type, eval_mode, &timezone)
+            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
         (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode),
         (DataType::Int64, DataType::Int32)
@@ -713,17 +703,17 @@ fn cast_array(
             spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
         }
         (DataType::Struct(_), DataType::Utf8) => {
-            Ok(casts_struct_to_string(array.as_struct(), &timezone)?)
+            Ok(casts_struct_to_string(array.as_struct(), cast_options)?)
         }
         (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
             array.as_struct(),
             from_type,
             to_type,
-            eval_mode,
-            timezone,
-            allow_incompat,
+            cast_options,
         )?),
-        _ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => {
+        _ if cast_options.is_adapting_schema
+            || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
+        {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
         }
@@ -826,9 +816,7 @@ fn cast_struct_to_struct(
     array: &StructArray,
     from_type: &DataType,
     to_type: &DataType,
-    eval_mode: EvalMode,
-    timezone: String,
-    allow_incompat: bool,
+    cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ArrayRef> {
     match (from_type, to_type) {
         (DataType::Struct(_), DataType::Struct(to_fields)) => {
@@ -837,9 +825,7 @@ fn cast_struct_to_struct(
                 let cast_field = cast_array(
                     Arc::clone(array.column(i)),
                     to_fields[i].data_type(),
-                    eval_mode,
-                    timezone.clone(),
-                    allow_incompat,
+                    cast_options,
                 )?;
                 cast_fields.push((Arc::clone(&to_fields[i]), cast_field));
             }
@@ -849,7 +835,10 @@ fn cast_struct_to_struct(
     }
 }
 
-fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResult<ArrayRef> {
+fn casts_struct_to_string(
+    array: &StructArray,
+    spark_cast_options: &SparkCastOptions,
+) -> DataFusionResult<ArrayRef> {
     // cast each field to a string
     let string_arrays: Vec<ArrayRef> = array
         .columns()
@@ -858,9 +847,7 @@ fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResu
             spark_cast(
                 ColumnarValue::Array(Arc::clone(arr)),
                 &DataType::Utf8,
-                EvalMode::Legacy,
-                timezone,
-                true,
+                spark_cast_options,
             )
             .and_then(|cv| cv.into_array(arr.len()))
         })
@@ -1465,7 +1452,7 @@ impl Display for Cast {
         write!(
             f,
             "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]",
-            self.data_type, self.timezone, self.child, &self.eval_mode
+            self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode
         )
     }
 }
@@ -1476,9 +1463,8 @@ impl PartialEq<dyn Any> for Cast {
             .downcast_ref::<Self>()
             .map(|x| {
                 self.child.eq(&x.child)
-                    && self.timezone.eq(&x.timezone)
+                    && self.cast_options.eq(&x.cast_options)
                     && self.data_type.eq(&x.data_type)
-                    && self.eval_mode.eq(&x.eval_mode)
             })
             .unwrap_or(false)
     }
@@ -1499,13 +1485,7 @@ impl PhysicalExpr for Cast {
     fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
         let arg = self.child.evaluate(batch)?;
-        spark_cast(
-            arg,
-            &self.data_type,
-            self.eval_mode,
-            &self.timezone,
-            self.allow_incompat,
-        )
+        spark_cast(arg, &self.data_type, &self.cast_options)
     }
 
     fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
@@ -1520,9 +1500,7 @@ impl PhysicalExpr for Cast {
             1 => Ok(Arc::new(Cast::new(
                 Arc::clone(&children[0]),
                 self.data_type.clone(),
-                self.eval_mode,
-                self.timezone.clone(),
-                self.allow_incompat,
+                self.cast_options.clone(),
             ))),
             _ => internal_err!("Cast should have exactly one child"),
         }
@@ -1532,9 +1510,7 @@ impl PhysicalExpr for Cast {
         let mut s = state;
         self.child.hash(&mut s);
         self.data_type.hash(&mut s);
-        self.timezone.hash(&mut s);
-        self.eval_mode.hash(&mut s);
-        self.allow_incompat.hash(&mut s);
+        self.cast_options.hash(&mut s);
         self.hash(&mut s);
     }
 }
@@ -2111,12 +2087,11 @@ mod tests {
         let timezone = "UTC".to_string();
         // test casting string dictionary array to timestamp array
+        let cast_options = SparkCastOptions::new(EvalMode::Legacy, timezone.clone(), false);
         let result = cast_array(
             dict_array,
             &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())),
-            EvalMode::Legacy,
-            timezone.clone(),
-            false,
+            &cast_options,
         )?;
         assert_eq!(
             *result.data_type(),
@@ -2321,12 +2296,11 @@ mod tests {
     fn test_cast_unsupported_timestamp_to_date() {
         // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported
        let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
+        let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC".to_string(), false);
         let result = cast_array(
             Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
             &DataType::Date32,
-            EvalMode::Legacy,
-            "UTC".to_owned(),
-            false,
+            &cast_options,
         );
         assert!(result.is_err())
     }
@@ -2334,12 +2308,12 @@ mod tests {
     #[test]
     fn test_cast_invalid_timezone() {
         let timestamps: PrimitiveArray<TimestampMicrosecondType> = vec![i64::MAX].into();
+        let cast_options =
+            SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone".to_string(), false);
         let result = cast_array(
             Arc::new(timestamps.with_timezone("Europe/Copenhagen")),
             &DataType::Date32,
-            EvalMode::Legacy,
-            "Not a valid timezone".to_owned(),
-            false,
+            &cast_options,
         );
         assert!(result.is_err())
     }
@@ -2361,9 +2335,7 @@ mod tests {
         let string_array = cast_array(
             c,
             &DataType::Utf8,
-            EvalMode::Legacy,
-            "UTC".to_owned(),
-            false,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC".to_owned(), false),
         )
         .unwrap();
         let string_array = string_array.as_string::<i32>();
@@ -2397,9 +2369,7 @@ mod tests {
         let cast_array = spark_cast(
             ColumnarValue::Array(c),
             &DataType::Struct(fields),
-            EvalMode::Legacy,
-            "UTC",
-            false,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
         )
         .unwrap();
         if let ColumnarValue::Array(cast_array) = cast_array {
@@ -2433,6 +2403,7 @@ mod tests {
             EvalMode::Legacy,
             "UTC",
             false,
+            false,
         )
         .unwrap();
         if let ColumnarValue::Array(cast_array) = cast_array {
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 614b48f2b..eb02cef84 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -34,7 +34,7 @@ pub mod timezone;
 mod to_json;
 pub mod utils;
 
-pub use cast::{spark_cast, Cast};
+pub use cast::{spark_cast, Cast, SparkCastOptions};
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
 pub use list::{GetArrayStructFields, ListExtract};
@@ -47,7 +47,7 @@ pub use to_json::ToJson;
 /// the behavior when processing input values that are invalid or would result in an
 /// error, such as divide by zero errors, and also affects behavior when converting
 /// between types.
-#[derive(Debug, Hash, PartialEq, Clone, Copy)]
+#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
 pub enum EvalMode {
     /// Legacy is the default behavior in Spark prior to Spark 4.0. This mode silently ignores
     /// or replaces errors during SQL operations. Operations resulting in errors (like
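Aside (hedged summary, not part of the diff): SparkCastOptions derives Hash, PartialEq, and Eq, which is why EvalMode gains the Eq derive above, and Cast's equality and hashing now delegate to the whole options struct. The two constructors differ only in how the timezone field is filled:

    use datafusion_comet_spark_expr::{EvalMode, SparkCastOptions};

    fn options_demo() {
        // Timezone-aware options, as used for timestamp/string casts.
        let with_tz = SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", true);
        assert_eq!(with_tz.timezone, "America/Los_Angeles");
        assert!(!with_tz.is_adapting_schema);

        // Timezone-free options, as used for the Decimal256 widening in planner.rs.
        let no_tz = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
        assert_eq!(no_tz.timezone, "");
    }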
diff --git a/native/spark-expr/src/to_json.rs b/native/spark-expr/src/to_json.rs
index 7d38cbf1b..1f68eb860 100644
--- a/native/spark-expr/src/to_json.rs
+++ b/native/spark-expr/src/to_json.rs
@@ -19,6 +19,7 @@
 // of the Spark-specific compatibility features that we need (including
 // being able to specify Spark-compatible cast from all types to string)
 
+use crate::cast::SparkCastOptions;
 use crate::{spark_cast, EvalMode};
 use arrow_array::builder::StringBuilder;
 use arrow_array::{Array, ArrayRef, RecordBatch, StringArray, StructArray};
@@ -117,9 +118,7 @@ fn array_to_json_string(arr: &Arc<dyn Array>, timezone: &str) -> Result
         spark_cast(
             ColumnarValue::Array(Arc::clone(arr)),
             &DataType::Utf8,
-            EvalMode::Legacy,
-            timezone,
-            true,
+            &SparkCastOptions::new(EvalMode::Legacy, timezone, true),
         )
diff --git a/native/spark-expr/src/utils.rs b/native/spark-expr/src/utils.rs
--- a/native/spark-expr/src/utils.rs
+++ b/native/spark-expr/src/utils.rs
                     timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
                 }
+                Some(DataType::Timestamp(_, None)) => {
+                    timestamp_ntz_to_timestamp(array, timezone.as_str(), None)
+                }
                 _ => {
                     // Not supported
                     panic!(
@@ -80,7 +84,7 @@ pub fn array_with_timezone(
                 }
             }
         }
-        DataType::Timestamp(_, Some(_)) => {
+        DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => {
            assert!(!timezone.is_empty());
             let array = as_primitive_array::<TimestampMicrosecondType>(&array);
             let array_with_timezone = array.clone().with_timezone(timezone.clone());
@@ -92,6 +96,18 @@ pub fn array_with_timezone(
                 _ => Ok(array),
             }
         }
+        DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => {
+            assert!(!timezone.is_empty());
+            let array = as_primitive_array::<TimestampMillisecondType>(&array);
+            let array_with_timezone = array.clone().with_timezone(timezone.clone());
+            let array = Arc::new(array_with_timezone) as ArrayRef;
+            match to_type {
+                Some(DataType::Utf8) | Some(DataType::Date32) => {
+                    pre_timestamp_cast(array, timezone)
+                }
+                _ => Ok(array),
+            }
+        }
         DataType::Dictionary(_, value_type)
             if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
         {
@@ -127,7 +143,7 @@ fn timestamp_ntz_to_timestamp(
 ) -> Result<ArrayRef, ArrowError> {
     assert!(!tz.is_empty());
     match array.data_type() {
-        DataType::Timestamp(_, None) => {
+        DataType::Timestamp(TimeUnit::Microsecond, None) => {
             let array = as_primitive_array::<TimestampMicrosecondType>(&array);
             let tz: Tz = tz.parse()?;
             let array: PrimitiveArray<TimestampMicrosecondType> = array.try_unary(|value| {
@@ -146,6 +162,25 @@ fn timestamp_ntz_to_timestamp(
             };
             Ok(Arc::new(array_with_tz))
         }
+        DataType::Timestamp(TimeUnit::Millisecond, None) => {
+            let array = as_primitive_array::<TimestampMillisecondType>(&array);
+            let tz: Tz = tz.parse()?;
+            let array: PrimitiveArray<TimestampMillisecondType> = array.try_unary(|value| {
+                as_datetime::<TimestampMillisecondType>(value)
+                    .ok_or_else(|| datetime_cast_err(value))
+                    .map(|local_datetime| {
+                        let datetime: DateTime<Tz> =
+                            tz.from_local_datetime(&local_datetime).unwrap();
+                        datetime.timestamp_millis()
+                    })
+            })?;
+            let array_with_tz = if let Some(to_tz) = to_timezone {
+                array.with_timezone(to_tz)
+            } else {
+                array
+            };
+            Ok(Arc::new(array_with_tz))
+        }
         _ => Ok(array),
     }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 29e50d734..b30ad1396 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2496,8 +2496,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         // Sink operators don't have children
         result.clearChildren()
 
-        val dataFilters = scan.dataFilters.map(exprToProto(_, scan.output))
-        nativeScanBuilder.addAllDataFilters(dataFilters.map(_.get).asJava)
+        // TODO remove flatMap and add error handling for unsupported data filters
+        val dataFilters = scan.dataFilters.flatMap(exprToProto(_, scan.output))
+        nativeScanBuilder.addAllDataFilters(dataFilters.asJava)
 
         // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD.
         scan.inputRDD match {
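Closing note (assumptions flagged): the utils.rs hunks extend timezone resolution to millisecond-precision timestamps by mirroring the existing microsecond arms, and the Scala change switches dataFilters from map(...).get to flatMap, so filters that cannot be converted to protobuf are now dropped instead of throwing (the TODO acknowledges this is temporary). A hypothetical call through the crate's public utils module, assuming array_with_timezone keeps the signature used above:

    use std::sync::Arc;
    use arrow_array::{ArrayRef, TimestampMillisecondArray};
    use arrow_schema::{DataType, TimeUnit};
    use datafusion_comet_spark_expr::utils::array_with_timezone;

    // Resolve a timezone-naive millisecond timestamp column against the session
    // timezone; the Millisecond arms handling this input were added by the hunks above.
    fn resolve_millis(values: Vec<i64>) -> ArrayRef {
        let array: ArrayRef = Arc::new(TimestampMillisecondArray::from(values));
        let to_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
        array_with_timezone(array, "UTC".to_string(), Some(&to_type))
            .expect("timezone resolution should succeed")
    }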