From 202140673a1810c7da1150167c46287c1187ef44 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Thu, 5 Dec 2024 14:14:33 -0700
Subject: [PATCH] save

---
 .../core/src/execution/datafusion/planner.rs  |  2 +-
 .../execution/datafusion/schema_adapter.rs    | 51 ++++++++++++-------
 native/spark-expr/src/cast.rs                 | 18 ++++++-
 native/spark-expr/src/to_json.rs              |  1 +
 4 files changed, 51 insertions(+), 21 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index c40780c58..d376b7889 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -17,7 +17,6 @@
 
 //! Converts Spark physical plan to DataFusion physical plan
 
-use datafusion::physical_plan::filter::FilterExec;
 use super::expressions::EvalMode;
 use crate::execution::datafusion::expressions::comet_scalar_funcs::create_comet_physical_fun;
 use crate::execution::operators::{CopyMode, FilterExec as CometFilterExec};
@@ -56,6 +55,7 @@ use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf,
 use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::min_max::min_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
+use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::windows::BoundedWindowAggExec;
 use datafusion::physical_plan::InputOrderMode;
 use datafusion::{
diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs
index ff325f647..2009bf84b 100644
--- a/native/core/src/execution/datafusion/schema_adapter.rs
+++ b/native/core/src/execution/datafusion/schema_adapter.rs
@@ -22,7 +22,7 @@ use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
 use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit};
 use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
 use datafusion_comet_spark_expr::{spark_cast, EvalMode};
-use datafusion_common::plan_err;
+use datafusion_common::{plan_err, DataFusionError};
 use datafusion_expr::ColumnarValue;
 use std::sync::Arc;
 
@@ -93,7 +93,7 @@ impl SchemaAdapter for CometSchemaAdapter {
             if let Some((table_idx, table_field)) =
                 self.required_schema.fields().find(file_field.name())
             {
-                if spark_can_cast_types(file_field.data_type(), table_field.data_type()) {
+                if comet_can_cast_types(file_field.data_type(), table_field.data_type()) {
                     field_mappings[table_idx] = Some(projection.len());
                     projection.push(file_idx);
                 } else {
@@ -118,23 +118,6 @@ impl SchemaAdapter for CometSchemaAdapter {
     }
 }
 
-pub fn spark_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
-    // TODO add all Spark cast rules (they are currently implemented in
-    // org.apache.comet.expressions.CometCast#isSupported in JVM side)
-    match (from_type, to_type) {
-        (DataType::Struct(_), DataType::Struct(_)) => {
-            // workaround for struct casting
-            true
-        }
-        (_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false,
-        // Native cast invoked for unsupported cast from FixedSizeBinary(3) to Binary.
-        (DataType::FixedSizeBinary(_), _) => false,
-        // Native cast invoked for unsupported cast from UInt32 to Int64.
-        (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64, _) => false,
-        _ => can_cast_types(from_type, to_type),
-    }
-}
-
 // TODO SchemaMapping is mostly copied from DataFusion but calls spark_cast
 // instead of arrow cast - can we reduce the amount of code copied here and make
 // the DataFusion version more extensible?
@@ -214,6 +197,7 @@ impl SchemaMapper for SchemaMapping {
                             EvalMode::Legacy,
                             "UTC",
                             false,
+                            true,
                         )?
                         .into_array(batch_rows)
                     },
@@ -265,6 +249,7 @@ impl SchemaMapper for SchemaMapping {
                         EvalMode::Legacy,
                         "UTC",
                         false,
+                        true,
                     )?
                     .into_array(batch_col.len())
                     // and if that works, return the field and column.
@@ -283,3 +268,31 @@ impl SchemaMapper for SchemaMapping {
         Ok(record_batch)
     }
 }
+
+
+
+fn comet_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
+    // TODO this is just a quick hack to get tests passing
+    match (from_type, to_type) {
+        (DataType::Struct(_), DataType::Struct(_)) => {
+            // workaround for struct casting
+            true
+        }
+        // TODO this is maybe no longer needed
+        (_, DataType::Timestamp(TimeUnit::Nanosecond, _)) => false,
+        _ => can_cast_types(from_type, to_type),
+    }
+}
+
+pub fn comet_cast(
+    arg: ColumnarValue,
+    data_type: &DataType,
+    eval_mode: EvalMode,
+    timezone: &str,
+    allow_incompat: bool,
+) -> Result<ColumnarValue, DataFusionError> {
+    // TODO for now we are re-using the spark cast rules, with a hack to override
+    // unsupported cases and let those fall through to arrow. This is just a short term
+    // hack and we need to implement specific Parquet to Spark conversions here instead
+    spark_cast(arg, data_type, eval_mode, timezone, allow_incompat, true)
+}
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 8ef9ca291..7ce0c2a66 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -585,6 +585,7 @@ pub fn spark_cast(
     eval_mode: EvalMode,
     timezone: &str,
     allow_incompat: bool,
+    ugly_hack_for_poc: bool, // TODO we definitely don't want to do this
 ) -> DataFusionResult<ColumnarValue> {
     match arg {
         ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array(
@@ -593,6 +594,7 @@ pub fn spark_cast(
             eval_mode,
             timezone.to_owned(),
             allow_incompat,
+            ugly_hack_for_poc,
         )?)),
         ColumnarValue::Scalar(scalar) => {
             // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for
@@ -606,6 +608,7 @@ pub fn spark_cast(
                     eval_mode,
                     timezone.to_owned(),
                     allow_incompat,
+                    ugly_hack_for_poc,
                 )?,
                 0,
             )?;
@@ -620,6 +623,7 @@ fn cast_array(
     eval_mode: EvalMode,
     timezone: String,
     allow_incompat: bool,
+    ugly_hack_for_poc: bool,
 ) -> DataFusionResult<ArrayRef> {
     let array = array_with_timezone(array, timezone.clone(), Some(to_type))?;
     let from_type = array.data_type().clone();
@@ -643,6 +647,7 @@ fn cast_array(
                     eval_mode,
                     timezone,
                     allow_incompat,
+                    ugly_hack_for_poc,
                 )?,
             );
 
@@ -723,7 +728,9 @@ fn cast_array(
             timezone,
             allow_incompat,
         )?),
-        _ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => {
+        _ if ugly_hack_for_poc
+            || is_datafusion_spark_compatible(from_type, to_type, allow_incompat) =>
+        {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
         }
@@ -840,6 +847,7 @@ fn cast_struct_to_struct(
                     eval_mode,
                     timezone.clone(),
                     allow_incompat,
+                    false,
                 )?;
                 cast_fields.push((Arc::clone(&to_fields[i]), cast_field));
             }
@@ -861,6 +869,7 @@ fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResu
                 EvalMode::Legacy,
                 timezone,
                 true,
+                false,
             )
             .and_then(|cv| cv.into_array(arr.len()))
         })
@@ -1505,6 +1514,7 @@ impl PhysicalExpr for Cast {
             self.eval_mode,
             &self.timezone,
             self.allow_incompat,
+            false,
         )
     }
 
@@ -2117,6 +2127,7 @@ mod tests {
             EvalMode::Legacy,
             timezone.clone(),
             false,
+            false,
         )?;
         assert_eq!(
             *result.data_type(),
@@ -2327,6 +2338,7 @@ mod tests {
             EvalMode::Legacy,
             "UTC".to_owned(),
             false,
+            false,
         );
         assert!(result.is_err())
     }
@@ -2340,6 +2352,7 @@ mod tests {
             EvalMode::Legacy,
             "Not a valid timezone".to_owned(),
             false,
+            false,
         );
         assert!(result.is_err())
     }
@@ -2364,6 +2377,7 @@ mod tests {
             EvalMode::Legacy,
             "UTC".to_owned(),
             false,
+            false,
         )
         .unwrap();
         let string_array = string_array.as_string::<i32>();
@@ -2400,6 +2414,7 @@ mod tests {
             EvalMode::Legacy,
             "UTC",
             false,
+            false,
         )
         .unwrap();
         if let ColumnarValue::Array(cast_array) = cast_array {
@@ -2433,6 +2448,7 @@ mod tests {
             EvalMode::Legacy,
             "UTC",
             false,
+            false,
         )
         .unwrap();
         if let ColumnarValue::Array(cast_array) = cast_array {
diff --git a/native/spark-expr/src/to_json.rs b/native/spark-expr/src/to_json.rs
index 7d38cbf1b..99f14a886 100644
--- a/native/spark-expr/src/to_json.rs
+++ b/native/spark-expr/src/to_json.rs
@@ -120,6 +120,7 @@ fn array_to_json_string(arr: &Arc<dyn Array>, timezone: &str) -> Result