From bf6b4d45b1175801f8b6e2a25c28147a360ea2f3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Dec 2024 14:01:56 -0700 Subject: [PATCH 1/7] support more timestamp conversions --- native/spark-expr/src/utils.rs | 41 +++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/utils.rs b/native/spark-expr/src/utils.rs index db4ad1956..2fc8de974 100644 --- a/native/spark-expr/src/utils.rs +++ b/native/spark-expr/src/utils.rs @@ -19,7 +19,7 @@ use arrow_array::{ cast::as_primitive_array, types::{Int32Type, TimestampMicrosecondType}, }; -use arrow_schema::{ArrowError, DataType}; +use arrow_schema::{ArrowError, DataType, TimeUnit}; use std::sync::Arc; use crate::timezone::Tz; @@ -27,6 +27,7 @@ use arrow::{ array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray}, temporal_conversions::as_datetime, }; +use arrow_array::types::TimestampMillisecondType; use chrono::{DateTime, Offset, TimeZone}; /// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or @@ -70,6 +71,9 @@ pub fn array_with_timezone( Some(DataType::Timestamp(_, Some(_))) => { timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str())) } + Some(DataType::Timestamp(_, None)) => { + timestamp_ntz_to_timestamp(array, timezone.as_str(), None) + } _ => { // Not supported panic!( @@ -80,7 +84,7 @@ pub fn array_with_timezone( } } } - DataType::Timestamp(_, Some(_)) => { + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { assert!(!timezone.is_empty()); let array = as_primitive_array::(&array); let array_with_timezone = array.clone().with_timezone(timezone.clone()); @@ -92,6 +96,18 @@ pub fn array_with_timezone( _ => Ok(array), } } + DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => { + assert!(!timezone.is_empty()); + let array = as_primitive_array::(&array); + let array_with_timezone = array.clone().with_timezone(timezone.clone()); + let array = Arc::new(array_with_timezone) as ArrayRef; + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => { + pre_timestamp_cast(array, timezone) + } + _ => Ok(array), + } + } DataType::Dictionary(_, value_type) if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => { @@ -127,7 +143,7 @@ fn timestamp_ntz_to_timestamp( ) -> Result { assert!(!tz.is_empty()); match array.data_type() { - DataType::Timestamp(_, None) => { + DataType::Timestamp(TimeUnit::Microsecond, None) => { let array = as_primitive_array::(&array); let tz: Tz = tz.parse()?; let array: PrimitiveArray = array.try_unary(|value| { @@ -146,6 +162,25 @@ fn timestamp_ntz_to_timestamp( }; Ok(Arc::new(array_with_tz)) } + DataType::Timestamp(TimeUnit::Millisecond, None) => { + let array = as_primitive_array::(&array); + let tz: Tz = tz.parse()?; + let array: PrimitiveArray = array.try_unary(|value| { + as_datetime::(value) + .ok_or_else(|| datetime_cast_err(value)) + .map(|local_datetime| { + let datetime: DateTime = + tz.from_local_datetime(&local_datetime).unwrap(); + datetime.timestamp_millis() + }) + })?; + let array_with_tz = if let Some(to_tz) = to_timezone { + array.with_timezone(to_tz) + } else { + array + }; + Ok(Arc::new(array_with_tz)) + } _ => Ok(array), } } From d4d71bc2d35a601f28728f1bcaa4064182f4aa0e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Dec 2024 14:13:31 -0700 Subject: [PATCH 2/7] improve error handling --- .../core/src/execution/datafusion/planner.rs | 8 +-- .../execution/datafusion/schema_adapter.rs | 63 ++++++++++--------- 2 files changed, 36 insertions(+), 35 
deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index c5147d772..a2a75bf98 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1095,11 +1095,9 @@ impl PhysicalPlanner { table_parquet_options.global.pushdown_filters = true; table_parquet_options.global.reorder_filters = true; - let mut builder = ParquetExecBuilder::new(file_scan_config) - .with_table_parquet_options(table_parquet_options) - .with_schema_adapter_factory( - Arc::new(CometSchemaAdapterFactory::default()), - ); + let mut builder = ParquetExecBuilder::new(file_scan_config) + .with_table_parquet_options(table_parquet_options) + .with_schema_adapter_factory(Arc::new(CometSchemaAdapterFactory::default())); if let Some(filter) = test_data_filters { builder = builder.with_predicate(filter); diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs index 16d4b9d67..4573ba348 100644 --- a/native/core/src/execution/datafusion/schema_adapter.rs +++ b/native/core/src/execution/datafusion/schema_adapter.rs @@ -19,7 +19,7 @@ use arrow::compute::can_cast_types; use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions}; -use arrow_schema::{DataType, Schema, SchemaRef}; +use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit}; use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper}; use datafusion_comet_spark_expr::{spark_cast, EvalMode}; use datafusion_common::plan_err; @@ -38,11 +38,11 @@ impl SchemaAdapterFactory for CometSchemaAdapterFactory { /// schema. fn create( &self, - projected_table_schema: SchemaRef, + required_schema: SchemaRef, table_schema: SchemaRef, ) -> Box { Box::new(CometSchemaAdapter { - projected_table_schema, + required_schema, table_schema, }) } @@ -54,7 +54,7 @@ impl SchemaAdapterFactory for CometSchemaAdapterFactory { pub struct CometSchemaAdapter { /// The schema for the table, projected to include only the fields being output (projected) by the /// associated ParquetExec - projected_table_schema: SchemaRef, + required_schema: SchemaRef, /// The entire table schema for the table we're using this to adapt. 
/// /// This is used to evaluate any filters pushed down into the scan @@ -69,7 +69,7 @@ impl SchemaAdapter for CometSchemaAdapter { /// /// Panics if index is not in range for the table schema fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.projected_table_schema.field(index); + let field = self.required_schema.field(index); Some(file_schema.fields.find(field.name())?.0) } @@ -87,40 +87,29 @@ impl SchemaAdapter for CometSchemaAdapter { file_schema: &Schema, ) -> datafusion_common::Result<(Arc, Vec)> { let mut projection = Vec::with_capacity(file_schema.fields().len()); - let mut field_mappings = vec![None; self.projected_table_schema.fields().len()]; + let mut field_mappings = vec![None; self.required_schema.fields().len()]; for (file_idx, file_field) in file_schema.fields.iter().enumerate() { if let Some((table_idx, table_field)) = - self.projected_table_schema.fields().find(file_field.name()) + self.required_schema.fields().find(file_field.name()) { - // workaround for struct casting - match (file_field.data_type(), table_field.data_type()) { - // TODO need to use Comet cast logic to determine which casts are supported, - // but for now just add a hack to support casting between struct types - (DataType::Struct(_), DataType::Struct(_)) => { - field_mappings[table_idx] = Some(projection.len()); - projection.push(file_idx); - } - _ => { - if can_cast_types(file_field.data_type(), table_field.data_type()) { - field_mappings[table_idx] = Some(projection.len()); - projection.push(file_idx); - } else { - return plan_err!( - "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", - file_field.name(), - file_field.data_type(), - table_field.data_type() - ); - } - } + if spark_can_cast_types(file_field.data_type(), table_field.data_type()) { + field_mappings[table_idx] = Some(projection.len()); + projection.push(file_idx); + } else { + return plan_err!( + "Cannot cast file schema field {} of type {:?} to required schema field of type {:?}", + file_field.name(), + file_field.data_type(), + table_field.data_type() + ); } } } Ok(( Arc::new(SchemaMapping { - projected_table_schema: self.projected_table_schema.clone(), + projected_table_schema: self.required_schema.clone(), field_mappings, table_schema: self.table_schema.clone(), }), @@ -129,6 +118,19 @@ impl SchemaAdapter for CometSchemaAdapter { } } +pub fn spark_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + // TODO add all Spark cast rules (they are currently implemented in + // org.apache.comet.expressions.CometCast#isSupported in JVM side) + match (from_type, to_type) { + (DataType::Struct(_), DataType::Struct(_)) => { + // workaround for struct casting + true + } + (_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false, + _ => can_cast_types(from_type, to_type), + } +} + // TODO SchemaMapping is mostly copied from DataFusion but calls spark_cast // instead of arrow cast - can we reduce the amount of code copied here and make // the DataFusion version more extensible? @@ -259,7 +261,8 @@ impl SchemaMapper for SchemaMapping { EvalMode::Legacy, "UTC", false, - )?.into_array(batch_col.len()) + )? + .into_array(batch_col.len()) // and if that works, return the field and column. 
.map(|new_col| (new_col, table_field.clone())) }) From b6036f241a6d41fae53f3ce6cbeaa3f1a4fc2779 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 4 Dec 2024 15:05:10 -0700 Subject: [PATCH 3/7] rename projected_table_schema to required_schema --- native/core/src/execution/datafusion/schema_adapter.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs index 4573ba348..c5d714564 100644 --- a/native/core/src/execution/datafusion/schema_adapter.rs +++ b/native/core/src/execution/datafusion/schema_adapter.rs @@ -109,7 +109,7 @@ impl SchemaAdapter for CometSchemaAdapter { Ok(( Arc::new(SchemaMapping { - projected_table_schema: self.required_schema.clone(), + required_schema: self.required_schema.clone(), field_mappings, table_schema: self.table_schema.clone(), }), @@ -163,7 +163,7 @@ pub fn spark_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { pub struct SchemaMapping { /// The schema of the table. This is the expected schema after conversion /// and it should match the schema of the query result. - projected_table_schema: SchemaRef, + required_schema: SchemaRef, /// Mapping from field index in `projected_table_schema` to index in /// projected file_schema. /// @@ -187,7 +187,7 @@ impl SchemaMapper for SchemaMapping { let batch_cols = batch.columns().to_vec(); let cols = self - .projected_table_schema + .required_schema // go through each field in the projected schema .fields() .iter() @@ -220,7 +220,7 @@ impl SchemaMapper for SchemaMapping { // Necessary to handle empty batches let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - let schema = self.projected_table_schema.clone(); + let schema = self.required_schema.clone(); let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } From 0602e24c349c05e0b6980762ca769155af480f2a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 5 Dec 2024 10:40:29 -0700 Subject: [PATCH 4/7] Save --- native/core/src/execution/datafusion/schema_adapter.rs | 4 ++++ .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs index c5d714564..ff325f647 100644 --- a/native/core/src/execution/datafusion/schema_adapter.rs +++ b/native/core/src/execution/datafusion/schema_adapter.rs @@ -127,6 +127,10 @@ pub fn spark_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { true } (_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false, + // Native cast invoked for unsupported cast from FixedSizeBinary(3) to Binary. + (DataType::FixedSizeBinary(_), _) => false, + // Native cast invoked for unsupported cast from UInt32 to Int64. 
+ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64, _) => false, _ => can_cast_types(from_type, to_type), } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 29e50d734..b30ad1396 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2496,8 +2496,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // Sink operators don't have children result.clearChildren() - val dataFilters = scan.dataFilters.map(exprToProto(_, scan.output)) - nativeScanBuilder.addAllDataFilters(dataFilters.map(_.get).asJava) + // TODO remove flatMap and add error handling for unsupported data filters + val dataFilters = scan.dataFilters.flatMap(exprToProto(_, scan.output)) + nativeScanBuilder.addAllDataFilters(dataFilters.asJava) // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD. scan.inputRDD match { From 74add9c51a9ca5c07dc7cb31ed4758d06835a113 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 5 Dec 2024 13:55:52 -0700 Subject: [PATCH 5/7] save --- native/core/src/execution/datafusion/planner.rs | 9 +++++++-- native/core/src/execution/operators/filter.rs | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 40fe4515c..c40780c58 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -17,9 +17,10 @@ //! Converts Spark physical plan to DataFusion physical plan +use datafusion::physical_plan::filter::FilterExec; use super::expressions::EvalMode; use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun; -use crate::execution::operators::{CopyMode, FilterExec}; +use crate::execution::operators::{CopyMode, FilterExec as CometFilterExec}; use crate::{ errors::ExpressionError, execution::{ @@ -851,7 +852,11 @@ impl PhysicalPlanner { let predicate = self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?; - Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?))) + if can_reuse_input_batch(&child) { + Ok((scans, Arc::new(CometFilterExec::try_new(predicate, child)?))) + } else { + Ok((scans, Arc::new(FilterExec::try_new(predicate, child)?))) + } } OpStruct::HashAgg(agg) => { assert!(children.len() == 1); diff --git a/native/core/src/execution/operators/filter.rs b/native/core/src/execution/operators/filter.rs index d9a54712d..18a094602 100644 --- a/native/core/src/execution/operators/filter.rs +++ b/native/core/src/execution/operators/filter.rs @@ -227,7 +227,7 @@ impl DisplayAs for FilterExec { impl ExecutionPlan for FilterExec { fn name(&self) -> &'static str { - "FilterExec" + "CometFilterExec" } /// Return a reference to Any that can be used for downcasting From 202140673a1810c7da1150167c46287c1187ef44 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 5 Dec 2024 14:14:33 -0700 Subject: [PATCH 6/7] save --- .../core/src/execution/datafusion/planner.rs | 2 +- .../execution/datafusion/schema_adapter.rs | 51 ++++++++++++------- native/spark-expr/src/cast.rs | 18 ++++++- native/spark-expr/src/to_json.rs | 1 + 4 files changed, 51 insertions(+), 21 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index c40780c58..d376b7889 
100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -17,7 +17,6 @@ //! Converts Spark physical plan to DataFusion physical plan -use datafusion::physical_plan::filter::FilterExec; use super::expressions::EvalMode; use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun; use crate::execution::operators::{CopyMode, FilterExec as CometFilterExec}; @@ -56,6 +55,7 @@ use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::functions_aggregate::min_max::min_udaf; use datafusion::functions_aggregate::sum::sum_udaf; +use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::windows::BoundedWindowAggExec; use datafusion::physical_plan::InputOrderMode; use datafusion::{ diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs index ff325f647..2009bf84b 100644 --- a/native/core/src/execution/datafusion/schema_adapter.rs +++ b/native/core/src/execution/datafusion/schema_adapter.rs @@ -22,7 +22,7 @@ use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions}; use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit}; use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper}; use datafusion_comet_spark_expr::{spark_cast, EvalMode}; -use datafusion_common::plan_err; +use datafusion_common::{plan_err, DataFusionError}; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -93,7 +93,7 @@ impl SchemaAdapter for CometSchemaAdapter { if let Some((table_idx, table_field)) = self.required_schema.fields().find(file_field.name()) { - if spark_can_cast_types(file_field.data_type(), table_field.data_type()) { + if comet_can_cast_types(file_field.data_type(), table_field.data_type()) { field_mappings[table_idx] = Some(projection.len()); projection.push(file_idx); } else { @@ -118,23 +118,6 @@ impl SchemaAdapter for CometSchemaAdapter { } } -pub fn spark_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { - // TODO add all Spark cast rules (they are currently implemented in - // org.apache.comet.expressions.CometCast#isSupported in JVM side) - match (from_type, to_type) { - (DataType::Struct(_), DataType::Struct(_)) => { - // workaround for struct casting - true - } - (_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false, - // Native cast invoked for unsupported cast from FixedSizeBinary(3) to Binary. - (DataType::FixedSizeBinary(_), _) => false, - // Native cast invoked for unsupported cast from UInt32 to Int64. - (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64, _) => false, - _ => can_cast_types(from_type, to_type), - } -} - // TODO SchemaMapping is mostly copied from DataFusion but calls spark_cast // instead of arrow cast - can we reduce the amount of code copied here and make // the DataFusion version more extensible? @@ -214,6 +197,7 @@ impl SchemaMapper for SchemaMapping { EvalMode::Legacy, "UTC", false, + true, )? .into_array(batch_rows) }, @@ -265,6 +249,7 @@ impl SchemaMapper for SchemaMapping { EvalMode::Legacy, "UTC", false, + true, )? .into_array(batch_col.len()) // and if that works, return the field and column. 
@@ -283,3 +268,31 @@ impl SchemaMapper for SchemaMapping { Ok(record_batch) } } + + + +fn comet_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + // TODO this is just a quick hack to get tests passing + match (from_type, to_type) { + (DataType::Struct(_), DataType::Struct(_)) => { + // workaround for struct casting + true + } + // TODO this is maybe no longer needed + (_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false, + _ => can_cast_types(from_type, to_type), + } +} + +pub fn comet_cast( + arg: ColumnarValue, + data_type: &DataType, + eval_mode: EvalMode, + timezone: &str, + allow_incompat: bool, +) -> Result { + // TODO for now we are re-using the spark cast rules, with a hack to override + // unsupported cases and let those fall through to arrow. This is just a short term + // hack and we need to implement specific Parquet to Spark conversions here instead + spark_cast(arg, data_type, eval_mode, timezone, allow_incompat, true) +} diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs index 8ef9ca291..7ce0c2a66 100644 --- a/native/spark-expr/src/cast.rs +++ b/native/spark-expr/src/cast.rs @@ -585,6 +585,7 @@ pub fn spark_cast( eval_mode: EvalMode, timezone: &str, allow_incompat: bool, + ugly_hack_for_poc: bool, // TODO we definitely don't want to do this ) -> DataFusionResult { match arg { ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array( @@ -593,6 +594,7 @@ pub fn spark_cast( eval_mode, timezone.to_owned(), allow_incompat, + ugly_hack_for_poc, )?)), ColumnarValue::Scalar(scalar) => { // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for @@ -606,6 +608,7 @@ pub fn spark_cast( eval_mode, timezone.to_owned(), allow_incompat, + ugly_hack_for_poc, )?, 0, )?; @@ -620,6 +623,7 @@ fn cast_array( eval_mode: EvalMode, timezone: String, allow_incompat: bool, + ugly_hack_for_poc: bool, ) -> DataFusionResult { let array = array_with_timezone(array, timezone.clone(), Some(to_type))?; let from_type = array.data_type().clone(); @@ -643,6 +647,7 @@ fn cast_array( eval_mode, timezone, allow_incompat, + ugly_hack_for_poc, )?, ); @@ -723,7 +728,9 @@ fn cast_array( timezone, allow_incompat, )?), - _ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => { + _ if ugly_hack_for_poc + || is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => + { // use DataFusion cast only when we know that it is compatible with Spark Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
} @@ -840,6 +847,7 @@ fn cast_struct_to_struct( eval_mode, timezone.clone(), allow_incompat, + false, )?; cast_fields.push((Arc::clone(&to_fields[i]), cast_field)); } @@ -861,6 +869,7 @@ fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResu EvalMode::Legacy, timezone, true, + false, ) .and_then(|cv| cv.into_array(arr.len())) }) @@ -1505,6 +1514,7 @@ impl PhysicalExpr for Cast { self.eval_mode, &self.timezone, self.allow_incompat, + false, ) } @@ -2117,6 +2127,7 @@ mod tests { EvalMode::Legacy, timezone.clone(), false, + false, )?; assert_eq!( *result.data_type(), @@ -2327,6 +2338,7 @@ mod tests { EvalMode::Legacy, "UTC".to_owned(), false, + false, ); assert!(result.is_err()) } @@ -2340,6 +2352,7 @@ mod tests { EvalMode::Legacy, "Not a valid timezone".to_owned(), false, + false, ); assert!(result.is_err()) } @@ -2364,6 +2377,7 @@ mod tests { EvalMode::Legacy, "UTC".to_owned(), false, + false, ) .unwrap(); let string_array = string_array.as_string::(); @@ -2400,6 +2414,7 @@ mod tests { EvalMode::Legacy, "UTC", false, + false, ) .unwrap(); if let ColumnarValue::Array(cast_array) = cast_array { @@ -2433,6 +2448,7 @@ mod tests { EvalMode::Legacy, "UTC", false, + false, ) .unwrap(); if let ColumnarValue::Array(cast_array) = cast_array { diff --git a/native/spark-expr/src/to_json.rs b/native/spark-expr/src/to_json.rs index 7d38cbf1b..99f14a886 100644 --- a/native/spark-expr/src/to_json.rs +++ b/native/spark-expr/src/to_json.rs @@ -120,6 +120,7 @@ fn array_to_json_string(arr: &Arc, timezone: &str) -> Result Date: Fri, 6 Dec 2024 08:02:45 -0700 Subject: [PATCH 7/7] code cleanup --- .../core/src/execution/datafusion/planner.rs | 22 +- .../execution/datafusion/schema_adapter.rs | 36 +--- native/spark-expr/src/cast.rs | 189 +++++++----------- native/spark-expr/src/lib.rs | 4 +- native/spark-expr/src/to_json.rs | 6 +- 5 files changed, 93 insertions(+), 164 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index d376b7889..bbe452cf8 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -103,7 +103,7 @@ use datafusion_comet_proto::{ }; use datafusion_comet_spark_expr::{ Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr, - ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson, + ListExtract, MinuteExpr, RLike, SecondExpr, SparkCastOptions, TimestampTruncExpr, ToJson, }; use datafusion_common::config::TableParquetOptions; use datafusion_common::scalar::ScalarStructBuilder; @@ -393,14 +393,11 @@ impl PhysicalPlanner { ExprStruct::Cast(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - let timezone = expr.timezone.clone(); let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; Ok(Arc::new(Cast::new( child, datatype, - eval_mode, - timezone, - expr.allow_incompat, + SparkCastOptions::new(eval_mode, &expr.timezone, expr.allow_incompat), ))) } ExprStruct::Hour(expr) => { @@ -768,24 +765,21 @@ impl PhysicalPlanner { let data_type = return_type.map(to_arrow_datatype).unwrap(); // For some Decimal128 operations, we need wider internal digits. 
// Cast left and right to Decimal256 and cast the result back to Decimal128 - let left = Arc::new(Cast::new_without_timezone( + let left = Arc::new(Cast::new( left, DataType::Decimal256(p1, s1), - EvalMode::Legacy, - false, + SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), )); - let right = Arc::new(Cast::new_without_timezone( + let right = Arc::new(Cast::new( right, DataType::Decimal256(p2, s2), - EvalMode::Legacy, - false, + SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), )); let child = Arc::new(BinaryExpr::new(left, op, right)); - Ok(Arc::new(Cast::new_without_timezone( + Ok(Arc::new(Cast::new( child, data_type, - EvalMode::Legacy, - false, + SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), ))) } ( diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs index 2009bf84b..cf467cd1f 100644 --- a/native/core/src/execution/datafusion/schema_adapter.rs +++ b/native/core/src/execution/datafusion/schema_adapter.rs @@ -21,8 +21,8 @@ use arrow::compute::can_cast_types; use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions}; use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit}; use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper}; -use datafusion_comet_spark_expr::{spark_cast, EvalMode}; -use datafusion_common::{plan_err, DataFusionError}; +use datafusion_comet_spark_expr::{spark_cast, EvalMode, SparkCastOptions}; +use datafusion_common::plan_err; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -107,11 +107,14 @@ impl SchemaAdapter for CometSchemaAdapter { } } + let mut cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false); + cast_options.is_adapting_schema = true; Ok(( Arc::new(SchemaMapping { required_schema: self.required_schema.clone(), field_mappings, table_schema: self.table_schema.clone(), + cast_options, }), projection, )) @@ -162,6 +165,8 @@ pub struct SchemaMapping { /// This contains all fields in the table, regardless of if they will be /// projected out or not. table_schema: SchemaRef, + + cast_options: SparkCastOptions, } impl SchemaMapper for SchemaMapping { @@ -193,11 +198,7 @@ impl SchemaMapper for SchemaMapping { spark_cast( ColumnarValue::Array(Arc::clone(&batch_cols[batch_idx])), field.data_type(), - // TODO need to pass in configs here - EvalMode::Legacy, - "UTC", - false, - true, + &self.cast_options, )? .into_array(batch_rows) }, @@ -245,11 +246,7 @@ impl SchemaMapper for SchemaMapping { spark_cast( ColumnarValue::Array(Arc::clone(batch_col)), table_field.data_type(), - // TODO need to pass in configs here - EvalMode::Legacy, - "UTC", - false, - true, + &self.cast_options, )? .into_array(batch_col.len()) // and if that works, return the field and column. @@ -269,8 +266,6 @@ impl SchemaMapper for SchemaMapping { } } - - fn comet_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { // TODO this is just a quick hack to get tests passing match (from_type, to_type) { @@ -283,16 +278,3 @@ fn comet_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { _ => can_cast_types(from_type, to_type), } } - -pub fn comet_cast( - arg: ColumnarValue, - data_type: &DataType, - eval_mode: EvalMode, - timezone: &str, - allow_incompat: bool, -) -> Result { - // TODO for now we are re-using the spark cast rules, with a hack to override - // unsupported cases and let those fall through to arrow. 
This is just a short term - // hack and we need to implement specific Parquet to Spark conversions here instead - spark_cast(arg, data_type, eval_mode, timezone, allow_incompat, true) -} diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs index 7ce0c2a66..17ab73b72 100644 --- a/native/spark-expr/src/cast.rs +++ b/native/spark-expr/src/cast.rs @@ -35,11 +35,18 @@ use arrow::{ use arrow_array::builder::StringBuilder; use arrow_array::{DictionaryArray, StringArray, StructArray}; use arrow_schema::{DataType, Field, Schema}; +use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; +use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; use datafusion_common::{ cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue, }; use datafusion_expr::ColumnarValue; use datafusion_physical_expr::PhysicalExpr; +use num::{ + cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, + ToPrimitive, +}; +use regex::Regex; use std::str::FromStr; use std::{ any::Any, @@ -49,14 +56,6 @@ use std::{ sync::Arc, }; -use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; -use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; -use num::{ - cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, - ToPrimitive, -}; -use regex::Regex; - use crate::timezone; use crate::utils::array_with_timezone; @@ -138,14 +137,7 @@ impl TimeStampInfo { pub struct Cast { pub child: Arc, pub data_type: DataType, - pub eval_mode: EvalMode, - - /// When cast from/to timezone related types, we need timezone, which will be resolved with - /// session local timezone by an analyzer in Spark. - pub timezone: String, - - /// Whether to allow casts that are known to be incompatible with Spark - pub allow_incompat: bool, + pub cast_options: SparkCastOptions, } macro_rules! cast_utf8_to_int { @@ -547,31 +539,48 @@ impl Cast { pub fn new( child: Arc, data_type: DataType, - eval_mode: EvalMode, - timezone: String, - allow_incompat: bool, + cast_options: SparkCastOptions, ) -> Self { Self { child, data_type, - timezone, + cast_options, + } + } +} + +/// Spark cast options +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct SparkCastOptions { + /// Spark evaluation mode + pub eval_mode: EvalMode, + /// When cast from/to timezone related types, we need timezone, which will be resolved with + /// session local timezone by an analyzer in Spark. 
+ // TODO we should change timezone to Tz to avoid repeated parsing + pub timezone: String, + /// Allow casts that are supported but not guaranteed to be 100% compatible + pub allow_incompat: bool, + /// We also use the cast logic for adapting Parquet schemas, so this flag is used + /// for that use case + pub is_adapting_schema: bool, +} + +impl SparkCastOptions { + pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self { + Self { eval_mode, + timezone: timezone.to_string(), allow_incompat, + is_adapting_schema: false, } } - pub fn new_without_timezone( - child: Arc, - data_type: DataType, - eval_mode: EvalMode, - allow_incompat: bool, - ) -> Self { + pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self { Self { - child, - data_type, - timezone: "".to_string(), eval_mode, + timezone: "".to_string(), allow_incompat, + is_adapting_schema: false, } } } @@ -582,36 +591,21 @@ impl Cast { pub fn spark_cast( arg: ColumnarValue, data_type: &DataType, - eval_mode: EvalMode, - timezone: &str, - allow_incompat: bool, - ugly_hack_for_poc: bool, // TODO we definitely don't want to do this + cast_options: &SparkCastOptions, ) -> DataFusionResult { match arg { ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array( array, data_type, - eval_mode, - timezone.to_owned(), - allow_incompat, - ugly_hack_for_poc, + cast_options, )?)), ColumnarValue::Scalar(scalar) => { // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it // here. let array = scalar.to_array()?; - let scalar = ScalarValue::try_from_array( - &cast_array( - array, - data_type, - eval_mode, - timezone.to_owned(), - allow_incompat, - ugly_hack_for_poc, - )?, - 0, - )?; + let scalar = + ScalarValue::try_from_array(&cast_array(array, data_type, cast_options)?, 0)?; Ok(ColumnarValue::Scalar(scalar)) } } @@ -620,12 +614,9 @@ pub fn spark_cast( fn cast_array( array: ArrayRef, to_type: &DataType, - eval_mode: EvalMode, - timezone: String, - allow_incompat: bool, - ugly_hack_for_poc: bool, + cast_options: &SparkCastOptions, ) -> DataFusionResult { - let array = array_with_timezone(array, timezone.clone(), Some(to_type))?; + let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?; let from_type = array.data_type().clone(); let array = match &from_type { @@ -641,14 +632,7 @@ fn cast_array( let casted_dictionary = DictionaryArray::::new( dict_array.keys().clone(), - cast_array( - Arc::clone(dict_array.values()), - to_type, - eval_mode, - timezone, - allow_incompat, - ugly_hack_for_poc, - )?, + cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?, ); let casted_result = match to_type { @@ -660,6 +644,7 @@ fn cast_array( _ => array, }; let from_type = array.data_type(); + let eval_mode = cast_options.eval_mode; let cast_result = match (from_type, to_type) { (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), @@ -667,7 +652,7 @@ fn cast_array( spark_cast_utf8_to_boolean::(&array, eval_mode) } (DataType::Utf8, DataType::Timestamp(_, _)) => { - cast_string_to_timestamp(&array, to_type, eval_mode, &timezone) + cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone) } (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode), (DataType::Int64, DataType::Int32) @@ -718,18 +703,16 @@ fn cast_array( spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, 
from_type, to_type) } (DataType::Struct(_), DataType::Utf8) => { - Ok(casts_struct_to_string(array.as_struct(), &timezone)?) + Ok(casts_struct_to_string(array.as_struct(), cast_options)?) } (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct( array.as_struct(), from_type, to_type, - eval_mode, - timezone, - allow_incompat, + cast_options, )?), - _ if ugly_hack_for_poc - || is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => + _ if cast_options.is_adapting_schema + || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) => { // use DataFusion cast only when we know that it is compatible with Spark Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) @@ -833,9 +816,7 @@ fn cast_struct_to_struct( array: &StructArray, from_type: &DataType, to_type: &DataType, - eval_mode: EvalMode, - timezone: String, - allow_incompat: bool, + cast_options: &SparkCastOptions, ) -> DataFusionResult { match (from_type, to_type) { (DataType::Struct(_), DataType::Struct(to_fields)) => { @@ -844,10 +825,7 @@ fn cast_struct_to_struct( let cast_field = cast_array( Arc::clone(array.column(i)), to_fields[i].data_type(), - eval_mode, - timezone.clone(), - allow_incompat, - false, + cast_options, )?; cast_fields.push((Arc::clone(&to_fields[i]), cast_field)); } @@ -857,7 +835,10 @@ fn cast_struct_to_struct( } } -fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResult { +fn casts_struct_to_string( + array: &StructArray, + spark_cast_options: &SparkCastOptions, +) -> DataFusionResult { // cast each field to a string let string_arrays: Vec = array .columns() @@ -866,10 +847,7 @@ fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResu spark_cast( ColumnarValue::Array(Arc::clone(arr)), &DataType::Utf8, - EvalMode::Legacy, - timezone, - true, - false, + spark_cast_options, ) .and_then(|cv| cv.into_array(arr.len())) }) @@ -1474,7 +1452,7 @@ impl Display for Cast { write!( f, "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]", - self.data_type, self.timezone, self.child, &self.eval_mode + self.data_type, self.cast_options.timezone, self.child, &self.cast_options.eval_mode ) } } @@ -1485,9 +1463,8 @@ impl PartialEq for Cast { .downcast_ref::() .map(|x| { self.child.eq(&x.child) - && self.timezone.eq(&x.timezone) + && self.cast_options.eq(&x.cast_options) && self.data_type.eq(&x.data_type) - && self.eval_mode.eq(&x.eval_mode) }) .unwrap_or(false) } @@ -1508,14 +1485,7 @@ impl PhysicalExpr for Cast { fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { let arg = self.child.evaluate(batch)?; - spark_cast( - arg, - &self.data_type, - self.eval_mode, - &self.timezone, - self.allow_incompat, - false, - ) + spark_cast(arg, &self.data_type, &self.cast_options) } fn children(&self) -> Vec<&Arc> { @@ -1530,9 +1500,7 @@ impl PhysicalExpr for Cast { 1 => Ok(Arc::new(Cast::new( Arc::clone(&children[0]), self.data_type.clone(), - self.eval_mode, - self.timezone.clone(), - self.allow_incompat, + self.cast_options.clone(), ))), _ => internal_err!("Cast should have exactly one child"), } @@ -1542,9 +1510,7 @@ impl PhysicalExpr for Cast { let mut s = state; self.child.hash(&mut s); self.data_type.hash(&mut s); - self.timezone.hash(&mut s); - self.eval_mode.hash(&mut s); - self.allow_incompat.hash(&mut s); + self.cast_options.hash(&mut s); self.hash(&mut s); } } @@ -2121,13 +2087,11 @@ mod tests { let timezone = "UTC".to_string(); // test casting string dictionary array to timestamp array + let cast_options = 
SparkCastOptions::new(EvalMode::Legacy, timezone.clone(), false); let result = cast_array( dict_array, &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())), - EvalMode::Legacy, - timezone.clone(), - false, - false, + &cast_options, )?; assert_eq!( *result.data_type(), @@ -2332,13 +2296,11 @@ mod tests { fn test_cast_unsupported_timestamp_to_date() { // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported let timestamps: PrimitiveArray = vec![i64::MAX].into(); + let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC".to_string(), false); let result = cast_array( Arc::new(timestamps.with_timezone("Europe/Copenhagen")), &DataType::Date32, - EvalMode::Legacy, - "UTC".to_owned(), - false, - false, + &cast_options, ); assert!(result.is_err()) } @@ -2346,13 +2308,12 @@ mod tests { #[test] fn test_cast_invalid_timezone() { let timestamps: PrimitiveArray = vec![i64::MAX].into(); + let cast_options = + SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone".to_string(), false); let result = cast_array( Arc::new(timestamps.with_timezone("Europe/Copenhagen")), &DataType::Date32, - EvalMode::Legacy, - "Not a valid timezone".to_owned(), - false, - false, + &cast_options, ); assert!(result.is_err()) } @@ -2374,10 +2335,7 @@ mod tests { let string_array = cast_array( c, &DataType::Utf8, - EvalMode::Legacy, - "UTC".to_owned(), - false, - false, + &SparkCastOptions::new(EvalMode::Legacy, "UTC".to_owned(), false), ) .unwrap(); let string_array = string_array.as_string::(); @@ -2411,10 +2369,7 @@ mod tests { let cast_array = spark_cast( ColumnarValue::Array(c), &DataType::Struct(fields), - EvalMode::Legacy, - "UTC", - false, - false, + &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), ) .unwrap(); if let ColumnarValue::Array(cast_array) = cast_array { diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 614b48f2b..eb02cef84 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -34,7 +34,7 @@ pub mod timezone; mod to_json; pub mod utils; -pub use cast::{spark_cast, Cast}; +pub use cast::{spark_cast, Cast, SparkCastOptions}; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; pub use list::{GetArrayStructFields, ListExtract}; @@ -47,7 +47,7 @@ pub use to_json::ToJson; /// the behavior when processing input values that are invalid or would result in an /// error, such as divide by zero errors, and also affects behavior when converting /// between types. -#[derive(Debug, Hash, PartialEq, Clone, Copy)] +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] pub enum EvalMode { /// Legacy is the default behavior in Spark prior to Spark 4.0. This mode silently ignores /// or replaces errors during SQL operations. Operations resulting in errors (like diff --git a/native/spark-expr/src/to_json.rs b/native/spark-expr/src/to_json.rs index 99f14a886..1f68eb860 100644 --- a/native/spark-expr/src/to_json.rs +++ b/native/spark-expr/src/to_json.rs @@ -19,6 +19,7 @@ // of the Spark-specific compatibility features that we need (including // being able to specify Spark-compatible cast from all types to string) +use crate::cast::SparkCastOptions; use crate::{spark_cast, EvalMode}; use arrow_array::builder::StringBuilder; use arrow_array::{Array, ArrayRef, RecordBatch, StringArray, StructArray}; @@ -117,10 +118,7 @@ fn array_to_json_string(arr: &Arc, timezone: &str) -> Result
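
For reference, a minimal usage sketch of the cast API as refactored by this series. It is illustrative only and not part of any patch: the spark_cast, SparkCastOptions, and EvalMode names come from the diffs above, while the example function, its input data, and the Utf8-to-Int32 cast are assumptions chosen purely to show the calling convention.

// Illustrative sketch (not part of the patches): exercises the refactored
// spark_cast entry point with the new SparkCastOptions struct. The input
// array and target type below are assumptions for demonstration only.
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, StringArray};
use arrow_schema::DataType;
use datafusion_comet_spark_expr::{spark_cast, EvalMode, SparkCastOptions};
use datafusion_expr::ColumnarValue;

fn cast_strings_to_int32() -> datafusion_common::Result<ArrayRef> {
    // Options are built once and passed by reference; is_adapting_schema
    // stays false here because that flag is only set by the Parquet schema
    // adapter introduced in patch 7.
    let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);

    let input: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), Some("42"), None]));
    let num_rows = input.len();

    // spark_cast returns a ColumnarValue; convert it back to an ArrayRef.
    spark_cast(ColumnarValue::Array(input), &DataType::Int32, &cast_options)?
        .into_array(num_rows)
}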