diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index fd1f9166d8..7e8a96f28d 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -503,41 +503,37 @@ impl Cast { fn cast_array(&self, array: ArrayRef) -> DataFusionResult { let to_type = &self.data_type; let array = array_with_timezone(array, self.timezone.clone(), Some(to_type)); + let from_type = array.data_type().clone(); + + // unpack dictionary string arrays first + // TODO: we are unpacking a dictionary-encoded array and then performing + // the cast. We could potentially improve performance here by casting the + // dictionary values directly without unpacking the array first, although this + // would add more complexity to the code + let array = match &from_type { + DataType::Dictionary(key_type, value_type) + if key_type.as_ref() == &DataType::Int32 + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => + { + cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)? + } + _ => array, + }; let from_type = array.data_type(); + let cast_result = match (from_type, to_type) { (DataType::Utf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode) } (DataType::LargeUtf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode) } (DataType::Utf8, DataType::Timestamp(_, _)) => { - Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)? + Self::cast_string_to_timestamp(&array, to_type, self.eval_mode) } (DataType::Utf8, DataType::Date32) => { - Self::cast_string_to_date(&array, to_type, self.eval_mode)? - } - (DataType::Dictionary(key_type, value_type), DataType::Date32) - if key_type.as_ref() == &DataType::Int32 - && (value_type.as_ref() == &DataType::Utf8 - || value_type.as_ref() == &DataType::LargeUtf8) => - { - match value_type.as_ref() { - DataType::Utf8 => { - let unpacked_array = - cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; - Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? - } - DataType::LargeUtf8 => { - let unpacked_array = - cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; - Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? - } - dt => unreachable!( - "{}", - format!("invalid value type {dt} for dictionary-encoded string array") - ), - } + Self::cast_string_to_date(&array, to_type, self.eval_mode) } (DataType::Int64, DataType::Int32) | (DataType::Int64, DataType::Int16) @@ -547,61 +543,33 @@ impl Cast { | (DataType::Int16, DataType::Int8) if self.eval_mode != EvalMode::Try => { - Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)? + Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type) } ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode), ( DataType::LargeUtf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, - ( - DataType::Dictionary(key_type, value_type), - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) if key_type.as_ref() == &DataType::Int32 - && (value_type.as_ref() == &DataType::Utf8 - || value_type.as_ref() == &DataType::LargeUtf8) => - { - // TODO: we are unpacking a dictionary-encoded array and then performing - // the cast. We could potentially improve performance here by casting the - // dictionary values directly without unpacking the array first, although this - // would add more complexity to the code - match value_type.as_ref() { - DataType::Utf8 => { - let unpacked_array = - cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; - Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? - } - DataType::LargeUtf8 => { - let unpacked_array = - cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; - Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? - } - dt => unreachable!( - "{}", - format!("invalid value type {dt} for dictionary-encoded string array") - ), - } - } + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode), (DataType::Float64, DataType::Utf8) => { - Self::spark_cast_float64_to_utf8::(&array, self.eval_mode)? + Self::spark_cast_float64_to_utf8::(&array, self.eval_mode) } (DataType::Float64, DataType::LargeUtf8) => { - Self::spark_cast_float64_to_utf8::(&array, self.eval_mode)? + Self::spark_cast_float64_to_utf8::(&array, self.eval_mode) } (DataType::Float32, DataType::Utf8) => { - Self::spark_cast_float32_to_utf8::(&array, self.eval_mode)? + Self::spark_cast_float32_to_utf8::(&array, self.eval_mode) } (DataType::Float32, DataType::LargeUtf8) => { - Self::spark_cast_float32_to_utf8::(&array, self.eval_mode)? + Self::spark_cast_float32_to_utf8::(&array, self.eval_mode) } (DataType::Float32, DataType::Decimal128(precision, scale)) => { - Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)? + Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode) } (DataType::Float64, DataType::Decimal128(precision, scale)) => { - Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)? + Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode) } (DataType::Float32, DataType::Int8) | (DataType::Float32, DataType::Int16) @@ -622,14 +590,94 @@ impl Cast { self.eval_mode, from_type, to_type, - )? + ) + } + _ if Self::is_datafusion_spark_compatible(from_type, to_type) => { + // use DataFusion cast only when we know that it is compatible with Spark + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) } _ => { - // when we have no Spark-specific casting we delegate to DataFusion - cast_with_options(&array, to_type, &CAST_OPTIONS)? + // we should never reach this code because the Scala code should be checking + // for supported cast operations and falling back to Spark for anything that + // is not yet supported + Err(CometError::Internal(format!( + "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}" + ))) } }; - Ok(spark_cast(cast_result, from_type, to_type)) + Ok(spark_cast(cast_result?, from_type, to_type)) + } + + /// Determines if DataFusion supports the given cast in a way that is + /// compatible with Spark + fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { + if from_type == to_type { + return true; + } + match from_type { + DataType::Boolean => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + ), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + // note that the cast from Int32/Int64 -> Decimal128 here is actually + // not compatible with Spark (no overflow checks) but we have tests that + // rely on this cast working so we have to leave it here for now + matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Utf8 + ) + } + DataType::Float32 | DataType::Float64 => matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ), + DataType::Utf8 => matches!(to_type, DataType::Binary), + DataType::Date32 => matches!(to_type, DataType::Utf8), + DataType::Timestamp(_, _) => { + matches!( + to_type, + DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) + ) + } + DataType::Binary => { + // note that this is not completely Spark compatible because + // DataFusion only supports binary data containing valid UTF-8 strings + matches!(to_type, DataType::Utf8) + } + _ => false, + } } fn cast_string_to_int(