From 5d912ffcf1ead3a548cac2aabfc5ecc2778fdd3e Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Fri, 13 Dec 2024 02:14:53 -0700
Subject: [PATCH] refactor

---
 native/spark-expr/src/cast.rs           | 253 ++++++++++++++++++++++++
 native/spark-expr/src/schema_adapter.rs |  16 +-
 2 files changed, 260 insertions(+), 9 deletions(-)

diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index f62d0220c..4a3a994a6 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -141,6 +141,259 @@ pub struct Cast {
     pub cast_options: SparkCastOptions,
 }
 
+pub fn cast_supported(
+    from_type: &DataType,
+    to_type: &DataType,
+    options: &SparkCastOptions,
+) -> bool {
+    true
+
+    // TODO:
+    // convert the following Scala code to Rust
+    // have the plugin Scala code call this logic via JNI so that we do not duplicate it
+
+    /*
+    def isSupported(
+        fromType: DataType,
+        toType: DataType,
+        timeZoneId: Option[String],
+        evalMode: CometEvalMode.Value): SupportLevel = {
+
+      if (fromType == toType) {
+        return Compatible()
+      }
+
+      (fromType, toType) match {
+        case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>
+          // https://github.com/apache/datafusion-comet/issues/378
+          toType match {
+            case DataTypes.TimestampType | DataTypes.DateType | DataTypes.StringType =>
+              Incompatible()
+            case _ =>
+              Unsupported
+          }
+        case (from: DecimalType, to: DecimalType) =>
+          if (to.precision < from.precision) {
+            // https://github.com/apache/datafusion/issues/13492
+            Incompatible(Some("Casting to smaller precision is not supported"))
+          } else {
+            Compatible()
+          }
+        case (DataTypes.StringType, _) =>
+          canCastFromString(toType, timeZoneId, evalMode)
+        case (_, DataTypes.StringType) =>
+          canCastToString(fromType, timeZoneId, evalMode)
+        case (DataTypes.TimestampType, _) =>
+          canCastFromTimestamp(toType)
+        case (_: DecimalType, _) =>
+          canCastFromDecimal(toType)
+        case (DataTypes.BooleanType, _) =>
+          canCastFromBoolean(toType)
+        case (DataTypes.ByteType, _) =>
+          canCastFromByte(toType)
+        case (DataTypes.ShortType, _) =>
+          canCastFromShort(toType)
+        case (DataTypes.IntegerType, _) =>
+          canCastFromInt(toType)
+        case (DataTypes.LongType, _) =>
+          canCastFromLong(toType)
+        case (DataTypes.FloatType, _) =>
+          canCastFromFloat(toType)
+        case (DataTypes.DoubleType, _) =>
+          canCastFromDouble(toType)
+        case (from_struct: StructType, to_struct: StructType) =>
+          from_struct.fields.zip(to_struct.fields).foreach { case (a, b) =>
+            isSupported(a.dataType, b.dataType, timeZoneId, evalMode) match {
+              case Compatible(_) =>
+              // all good
+              case other =>
+                return other
+            }
+          }
+          Compatible()
+        case _ => Unsupported
+      }
+    }
+
+    private def canCastFromString(
+        toType: DataType,
+        timeZoneId: Option[String],
+        evalMode: CometEvalMode.Value): SupportLevel = {
+      toType match {
+        case DataTypes.BooleanType =>
+          Compatible()
+        case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
+            DataTypes.LongType =>
+          Compatible()
+        case DataTypes.BinaryType =>
+          Compatible()
+        case DataTypes.FloatType | DataTypes.DoubleType =>
+          // https://github.com/apache/datafusion-comet/issues/326
+          Incompatible(
+            Some(
+              "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
+                "Does not support ANSI mode."))
+        case _: DecimalType =>
+          // https://github.com/apache/datafusion-comet/issues/325
+          Incompatible(
+            Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
+              "Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
+        case DataTypes.DateType =>
+          // https://github.com/apache/datafusion-comet/issues/327
+          Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
+        case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
+          Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
+        case DataTypes.TimestampType if evalMode == "ANSI" =>
+          Incompatible(Some("ANSI mode not supported"))
+        case DataTypes.TimestampType =>
+          // https://github.com/apache/datafusion-comet/issues/328
+          Incompatible(Some("Not all valid formats are supported"))
+        case _ =>
+          Unsupported
+      }
+    }
+
+    private def canCastToString(
+        fromType: DataType,
+        timeZoneId: Option[String],
+        evalMode: CometEvalMode.Value): SupportLevel = {
+      fromType match {
+        case DataTypes.BooleanType => Compatible()
+        case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
+            DataTypes.LongType =>
+          Compatible()
+        case DataTypes.DateType => Compatible()
+        case DataTypes.TimestampType => Compatible()
+        case DataTypes.FloatType | DataTypes.DoubleType =>
+          Compatible(
+            Some(
+              "There can be differences in precision. " +
+                "For example, the input \"1.4E-45\" will produce 1.0E-45 " +
+                "instead of 1.4E-45"))
+        case _: DecimalType =>
+          // https://github.com/apache/datafusion-comet/issues/1068
+          Compatible(
+            Some(
+              "There can be formatting differences in some case due to Spark using " +
+                "scientific notation where Comet does not"))
+        case DataTypes.BinaryType =>
+          // https://github.com/apache/datafusion-comet/issues/377
+          Incompatible(Some("Only works for binary data representing valid UTF-8 strings"))
+        case StructType(fields) =>
+          for (field <- fields) {
+            isSupported(field.dataType, DataTypes.StringType, timeZoneId, evalMode) match {
+              case s: Incompatible =>
+                return s
+              case Unsupported =>
+                return Unsupported
+              case _ =>
+            }
+          }
+          Compatible()
+        case _ => Unsupported
+      }
+    }
+
+    private def canCastFromTimestamp(toType: DataType): SupportLevel = {
+      toType match {
+        case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
+            DataTypes.IntegerType =>
+          // https://github.com/apache/datafusion-comet/issues/352
+          // this seems like an edge case that isn't important for us to support
+          Unsupported
+        case DataTypes.LongType =>
+          // https://github.com/apache/datafusion-comet/issues/352
+          Compatible()
+        case DataTypes.StringType => Compatible()
+        case DataTypes.DateType => Compatible()
+        case _: DecimalType => Compatible()
+        case _ => Unsupported
+      }
+    }
+
+    private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
+      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
+          DataTypes.FloatType | DataTypes.DoubleType =>
+        Compatible()
+      case _ => Unsupported
+    }
+
+    private def canCastFromByte(toType: DataType): SupportLevel = toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+        Compatible()
+      case _ =>
+        Unsupported
+    }
+
+    private def canCastFromShort(toType: DataType): SupportLevel = toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+        Compatible()
+      case _ =>
+        Unsupported
+    }
+
+    private def canCastFromInt(toType: DataType): SupportLevel = toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType =>
+        Compatible()
+      case _: DecimalType =>
+        Incompatible(Some("No overflow check"))
+      case _ =>
+        Unsupported
+    }
+
+    private def canCastFromLong(toType: DataType): SupportLevel = toType match {
+      case DataTypes.BooleanType =>
+        Compatible()
+      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
+        Compatible()
+      case DataTypes.FloatType | DataTypes.DoubleType =>
+        Compatible()
+      case _: DecimalType =>
+        Incompatible(Some("No overflow check"))
+      case _ =>
+        Unsupported
+    }
+
+    private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
+      case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
+          DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case _: DecimalType => Compatible()
+      case _ => Unsupported
+    }
+
+    private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
+      case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType |
+          DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case _: DecimalType => Compatible()
+      case _ => Unsupported
+    }
+
+    private def canCastFromDecimal(toType: DataType): SupportLevel = toType match {
+      case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
+          DataTypes.IntegerType | DataTypes.LongType =>
+        Compatible()
+      case _ => Unsupported
+    }
+
+    }
+
+    */
+}
+
 macro_rules! cast_utf8_to_int {
     ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
         let len = $array.len();
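Note (reviewer sketch, not part of the patch): the TODO above asks for the quoted Scala to be ported to Rust. A minimal sketch of the shape that port could take, assuming a Rust SupportLevel enum mirroring the Scala ADT (all names below are illustrative, not existing Comet APIs):

    use arrow_schema::DataType;

    /// Assumed Rust counterpart of the Scala SupportLevel ADT.
    pub enum SupportLevel {
        Compatible(Option<String>),
        Incompatible(Option<String>),
        Unsupported,
    }

    /// Possible top-level dispatch, mirroring the Scala `isSupported`.
    /// Only two arms are shown; the rest would follow the same pattern.
    pub fn is_supported(from_type: &DataType, to_type: &DataType) -> SupportLevel {
        use DataType::*;
        if from_type == to_type {
            return SupportLevel::Compatible(None);
        }
        match (from_type, to_type) {
            // https://github.com/apache/datafusion/issues/13492
            (Decimal128(from_precision, _), Decimal128(to_precision, _))
                if to_precision < from_precision =>
            {
                SupportLevel::Incompatible(Some(
                    "Casting to smaller precision is not supported".to_string(),
                ))
            }
            (Utf8, _) => can_cast_from_string(to_type),
            // ... remaining arms mirror the Scala match one-for-one ...
            _ => SupportLevel::Unsupported,
        }
    }

    /// Stub standing in for a full port of the Scala `canCastFromString`.
    fn can_cast_from_string(to_type: &DataType) -> SupportLevel {
        match to_type {
            DataType::Boolean
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::Binary => SupportLevel::Compatible(None),
            _ => SupportLevel::Unsupported,
        }
    }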
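The bool that cast_supported returns could then be derived from the SupportLevel plus the cast options. Again a sketch: it assumes SparkCastOptions carries an allow-incompatible flag, here called allow_incompat.

    /// Sketch: collapse a SupportLevel into the bool returned by cast_supported.
    /// Incompatible casts are accepted only when the (assumed) allow_incompat
    /// option is enabled; unsupported casts are always rejected.
    fn is_cast_allowed(level: SupportLevel, allow_incompat: bool) -> bool {
        match level {
            SupportLevel::Compatible(_) => true,
            SupportLevel::Incompatible(_) => allow_incompat,
            SupportLevel::Unsupported => false,
        }
    }

Returning false here makes the schema adapter in the next file treat the column as uncastable (the else branch in the hunk below), rather than silently producing wrong results.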
diff --git a/native/spark-expr/src/schema_adapter.rs b/native/spark-expr/src/schema_adapter.rs
index 64c088e14..da432b747 100644
--- a/native/spark-expr/src/schema_adapter.rs
+++ b/native/spark-expr/src/schema_adapter.rs
@@ -17,10 +17,10 @@
 
 //! Custom schema adapter that uses Spark-compatible casts
 
+use crate::cast::cast_supported;
 use crate::{spark_cast, SparkCastOptions};
-use arrow::compute::can_cast_types;
 use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
-use arrow_schema::{DataType, Schema, SchemaRef};
+use arrow_schema::{Schema, SchemaRef};
 use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
 use datafusion_common::plan_err;
 use datafusion_expr::ColumnarValue;
@@ -109,7 +109,11 @@ impl SchemaAdapter for SparkSchemaAdapter {
         if let Some((table_idx, table_field)) =
             self.required_schema.fields().find(file_field.name())
         {
-            if spark_can_cast_types(file_field.data_type(), table_field.data_type()) {
+            if cast_supported(
+                file_field.data_type(),
+                table_field.data_type(),
+                &self.cast_options,
+            ) {
                 field_mappings[table_idx] = Some(projection.len());
                 projection.push(file_idx);
             } else {
@@ -279,9 +283,3 @@ impl SchemaMapper for SchemaMapping {
         Ok(record_batch)
     }
 }
-
-fn spark_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
-    // TODO: implement Spark's logic for determining which casts are supported but for now
-    // just delegate to Arrow's rules
-    can_cast_types(from_type, to_type)
-}
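Finally, a usage sketch showing how the adapter-facing check would line up with the quoted Scala rules, using the illustrative helpers from the earlier sketches (string to boolean is listed as Compatible in canCastFromString):

    fn main() {
        use arrow_schema::DataType;

        // A file column read as Utf8 that the table schema wants as Boolean:
        // Compatible per the quoted Scala rules, so the cast is allowed even
        // without opting in to incompatible casts.
        let level = is_supported(&DataType::Utf8, &DataType::Boolean);
        assert!(is_cast_allowed(level, false));

        // Boolean -> Binary is not covered by any rule, so it should be
        // rejected regardless of options.
        let level = is_supported(&DataType::Boolean, &DataType::Binary);
        assert!(!is_cast_allowed(level, false));
    }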