Commit

refactor
andygrove committed Dec 13, 2024
1 parent 0bc7e82 commit 5d912ff
Showing 2 changed files with 260 additions and 9 deletions.
253 changes: 253 additions & 0 deletions native/spark-expr/src/cast.rs
@@ -141,6 +141,259 @@ pub struct Cast {
pub cast_options: SparkCastOptions,
}

pub fn cast_supported(
from_type: &DataType,
to_type: &DataType,
options: &SparkCastOptions,
) -> bool {
    // Placeholder: treat every cast as supported until the Scala logic below is ported.
    true

    // TODO:
    // Convert the following Scala code to Rust and have the plugin's Scala code call this
    // logic via JNI so that the cast-support rules are not duplicated.

/*
def isSupported(
fromType: DataType,
toType: DataType,
timeZoneId: Option[String],
evalMode: CometEvalMode.Value): SupportLevel = {
if (fromType == toType) {
return Compatible()
}
(fromType, toType) match {
case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>
// https://github.com/apache/datafusion-comet/issues/378
toType match {
case DataTypes.TimestampType | DataTypes.DateType | DataTypes.StringType =>
Incompatible()
case _ =>
Unsupported
}
case (from: DecimalType, to: DecimalType) =>
if (to.precision < from.precision) {
// https://github.com/apache/datafusion/issues/13492
Incompatible(Some("Casting to smaller precision is not supported"))
} else {
Compatible()
}
case (DataTypes.StringType, _) =>
canCastFromString(toType, timeZoneId, evalMode)
case (_, DataTypes.StringType) =>
canCastToString(fromType, timeZoneId, evalMode)
case (DataTypes.TimestampType, _) =>
canCastFromTimestamp(toType)
case (_: DecimalType, _) =>
canCastFromDecimal(toType)
case (DataTypes.BooleanType, _) =>
canCastFromBoolean(toType)
case (DataTypes.ByteType, _) =>
canCastFromByte(toType)
case (DataTypes.ShortType, _) =>
canCastFromShort(toType)
case (DataTypes.IntegerType, _) =>
canCastFromInt(toType)
case (DataTypes.LongType, _) =>
canCastFromLong(toType)
case (DataTypes.FloatType, _) =>
canCastFromFloat(toType)
case (DataTypes.DoubleType, _) =>
canCastFromDouble(toType)
case (from_struct: StructType, to_struct: StructType) =>
from_struct.fields.zip(to_struct.fields).foreach { case (a, b) =>
isSupported(a.dataType, b.dataType, timeZoneId, evalMode) match {
case Compatible(_) =>
// all good
case other =>
return other
}
}
Compatible()
case _ => Unsupported
}
}
private def canCastFromString(
toType: DataType,
timeZoneId: Option[String],
evalMode: CometEvalMode.Value): SupportLevel = {
toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
DataTypes.LongType =>
Compatible()
case DataTypes.BinaryType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
// https://github.com/apache/datafusion-comet/issues/326
Incompatible(
Some(
"Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
"Does not support ANSI mode."))
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/325
Incompatible(
Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
"Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
case DataTypes.DateType =>
// https://github.com/apache/datafusion-comet/issues/327
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
case DataTypes.TimestampType if evalMode == "ANSI" =>
Incompatible(Some("ANSI mode not supported"))
case DataTypes.TimestampType =>
// https://github.com/apache/datafusion-comet/issues/328
Incompatible(Some("Not all valid formats are supported"))
case _ =>
Unsupported
}
}
private def canCastToString(
fromType: DataType,
timeZoneId: Option[String],
evalMode: CometEvalMode.Value): SupportLevel = {
fromType match {
case DataTypes.BooleanType => Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
DataTypes.LongType =>
Compatible()
case DataTypes.DateType => Compatible()
case DataTypes.TimestampType => Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Compatible(
Some(
"There can be differences in precision. " +
"For example, the input \"1.4E-45\" will produce 1.0E-45 " +
"instead of 1.4E-45"))
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/1068
Compatible(
Some(
"There can be formatting differences in some case due to Spark using " +
"scientific notation where Comet does not"))
case DataTypes.BinaryType =>
// https://github.com/apache/datafusion-comet/issues/377
Incompatible(Some("Only works for binary data representing valid UTF-8 strings"))
case StructType(fields) =>
for (field <- fields) {
isSupported(field.dataType, DataTypes.StringType, timeZoneId, evalMode) match {
case s: Incompatible =>
return s
case Unsupported =>
return Unsupported
case _ =>
}
}
Compatible()
case _ => Unsupported
}
}
private def canCastFromTimestamp(toType: DataType): SupportLevel = {
toType match {
case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType =>
// https://github.com/apache/datafusion-comet/issues/352
// this seems like an edge case that isn't important for us to support
Unsupported
case DataTypes.LongType =>
// https://github.com/apache/datafusion-comet/issues/352
Compatible()
case DataTypes.StringType => Compatible()
case DataTypes.DateType => Compatible()
case _: DecimalType => Compatible()
case _ => Unsupported
}
}
private def canCastFromBoolean(toType: DataType): SupportLevel = toType match {
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
DataTypes.FloatType | DataTypes.DoubleType =>
Compatible()
case _ => Unsupported
}
private def canCastFromByte(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case _ =>
Unsupported
}
private def canCastFromShort(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
Compatible()
case _ =>
Unsupported
}
private def canCastFromInt(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Compatible()
case _: DecimalType =>
Incompatible(Some("No overflow check"))
case _ =>
Unsupported
}
private def canCastFromLong(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType =>
Compatible()
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
Compatible()
case _: DecimalType =>
Incompatible(Some("No overflow check"))
case _ =>
Unsupported
}
private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case _: DecimalType => Compatible()
case _ => Unsupported
}
private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case _: DecimalType => Compatible()
case _ => Unsupported
}
private def canCastFromDecimal(toType: DataType): SupportLevel = toType match {
case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
DataTypes.IntegerType | DataTypes.LongType =>
Compatible()
case _ => Unsupported
}
}
*/
}

macro_rules! cast_utf8_to_int {
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
let len = $array.len();
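For reference, here is a minimal sketch of how one slice of the Scala rules quoted above might be ported to Rust on top of Arrow's DataType. The function name and the reduced rule set (only the decimal-precision and byte-source rules) are illustrative assumptions, not part of this commit:

use arrow_schema::DataType;

// Illustrative sketch only: mirrors the Decimal-to-Decimal precision rule and the
// canCastFromByte rule from the Scala reference; everything else reports as unsupported.
fn cast_supported_sketch(from_type: &DataType, to_type: &DataType) -> bool {
    if from_type == to_type {
        return true;
    }
    match (from_type, to_type) {
        // Casting to a smaller precision is not supported (apache/datafusion#13492).
        (DataType::Decimal128(from_p, _), DataType::Decimal128(to_p, _)) => to_p >= from_p,
        // A byte can cast to bool, wider integers, floats, and decimals.
        (
            DataType::Int8,
            DataType::Boolean
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::Float32
            | DataType::Float64
            | DataType::Decimal128(_, _),
        ) => true,
        _ => false,
    }
}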
16 changes: 7 additions & 9 deletions native/spark-expr/src/schema_adapter.rs
@@ -17,10 +17,10 @@

//! Custom schema adapter that uses Spark-compatible casts
use crate::cast::cast_supported;
use crate::{spark_cast, SparkCastOptions};
use arrow::compute::can_cast_types;
use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
use arrow_schema::{DataType, Schema, SchemaRef};
use arrow_schema::{Schema, SchemaRef};
use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
use datafusion_common::plan_err;
use datafusion_expr::ColumnarValue;
@@ -109,7 +109,11 @@ impl SchemaAdapter for SparkSchemaAdapter {
if let Some((table_idx, table_field)) =
self.required_schema.fields().find(file_field.name())
{
if spark_can_cast_types(file_field.data_type(), table_field.data_type()) {
if cast_supported(
file_field.data_type(),
table_field.data_type(),
&self.cast_options,
) {
field_mappings[table_idx] = Some(projection.len());
projection.push(file_idx);
} else {
@@ -279,9 +283,3 @@ impl SchemaMapper for SchemaMapping {
Ok(record_batch)
}
}

fn spark_can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
// TODO: implement Spark's logic for determining which casts are supported but for now
// just delegate to Arrow's rules
can_cast_types(from_type, to_type)
}

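A hypothetical call site for the new entry point, for illustration only; the crate path and the SparkCastOptions constructor arguments shown here are assumptions, not taken from this diff:

use arrow_schema::DataType;
use datafusion_comet_spark_expr::{cast_supported, EvalMode, SparkCastOptions};

fn main() {
    // Assumed constructor shape: (eval mode, session time zone, allow_incompat flag).
    let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    // With the placeholder implementation in this commit, every cast currently reports as supported.
    assert!(cast_supported(&DataType::Int32, &DataType::Int64, &options));
}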