Skip to content

Commit

Permalink
fix: Only delegate to DataFusion cast when we know that it is compati…
Browse files Browse the repository at this point in the history
…ble with Spark (#461)

* only delegate to DataFusion cast when we know that it is compatible with Spark

* add more supported casts

* improve support for dictionary-encoded string arrays

* clippy

* fix merge conflict

* fix a regression

* fix a regression

* fix a regression

* fix regression

* fix regression

* fix regression

* remove TODO comment now that issue has been filed

* remove cast int32/int64 -> decimal from datafusion compatible list

* Revert "remove cast int32/int64 -> decimal from datafusion compatible list"

This reverts commit 340e000.

* add comment
andygrove authored May 25, 2024
1 parent 93af704 commit 79431f8
Showing 1 changed file with 115 additions and 67 deletions.
182 changes: 115 additions & 67 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
@@ -503,41 +503,37 @@ impl Cast {
fn cast_array(&self, array: ArrayRef) -> DataFusionResult<ArrayRef> {
let to_type = &self.data_type;
let array = array_with_timezone(array, self.timezone.clone(), Some(to_type));
let from_type = array.data_type().clone();

// unpack dictionary string arrays first
// TODO: we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
let array = match &from_type {
DataType::Dictionary(key_type, value_type)
if key_type.as_ref() == &DataType::Int32
&& (value_type.as_ref() == &DataType::Utf8
|| value_type.as_ref() == &DataType::LargeUtf8) =>
{
cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)?
}
_ => array,
};
let from_type = array.data_type();

let cast_result = match (from_type, to_type) {
(DataType::Utf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode)?
Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode)
}
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)
}
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)
}
(DataType::Utf8, DataType::Date32) => {
Self::cast_string_to_date(&array, to_type, self.eval_mode)?
}
(DataType::Dictionary(key_type, value_type), DataType::Date32)
if key_type.as_ref() == &DataType::Int32
&& (value_type.as_ref() == &DataType::Utf8
|| value_type.as_ref() == &DataType::LargeUtf8) =>
{
match value_type.as_ref() {
DataType::Utf8 => {
let unpacked_array =
cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?;
Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)?
}
DataType::LargeUtf8 => {
let unpacked_array =
cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?;
Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)?
}
dt => unreachable!(
"{}",
format!("invalid value type {dt} for dictionary-encoded string array")
),
}
Self::cast_string_to_date(&array, to_type, self.eval_mode)
}
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
@@ -547,61 +543,33 @@ impl Cast {
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
) => Self::cast_string_to_int::<i32>(to_type, &array, self.eval_mode)?,
) => Self::cast_string_to_int::<i32>(to_type, &array, self.eval_mode),
(
DataType::LargeUtf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
) => Self::cast_string_to_int::<i64>(to_type, &array, self.eval_mode)?,
(
DataType::Dictionary(key_type, value_type),
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
) if key_type.as_ref() == &DataType::Int32
&& (value_type.as_ref() == &DataType::Utf8
|| value_type.as_ref() == &DataType::LargeUtf8) =>
{
// TODO: we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
match value_type.as_ref() {
DataType::Utf8 => {
let unpacked_array =
cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?;
Self::cast_string_to_int::<i32>(to_type, &unpacked_array, self.eval_mode)?
}
DataType::LargeUtf8 => {
let unpacked_array =
cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?;
Self::cast_string_to_int::<i64>(to_type, &unpacked_array, self.eval_mode)?
}
dt => unreachable!(
"{}",
format!("invalid value type {dt} for dictionary-encoded string array")
),
}
}
) => Self::cast_string_to_int::<i64>(to_type, &array, self.eval_mode),
(DataType::Float64, DataType::Utf8) => {
Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)?
Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)
}
(DataType::Float64, DataType::LargeUtf8) => {
Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode)?
Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode)
}
(DataType::Float32, DataType::Utf8) => {
Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode)?
Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode)
}
(DataType::Float32, DataType::LargeUtf8) => {
Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)
}
(DataType::Float32, DataType::Decimal128(precision, scale)) => {
Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)?
Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)
}
(DataType::Float64, DataType::Decimal128(precision, scale)) => {
Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)?
Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)
}
(DataType::Float32, DataType::Int8)
| (DataType::Float32, DataType::Int16)
@@ -622,14 +590,94 @@ impl Cast {
self.eval_mode,
from_type,
to_type,
)?
)
}
_ if Self::is_datafusion_spark_compatible(from_type, to_type) => {
// use DataFusion cast only when we know that it is compatible with Spark
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
}
_ => {
// when we have no Spark-specific casting we delegate to DataFusion
cast_with_options(&array, to_type, &CAST_OPTIONS)?
// we should never reach this code because the Scala code should be checking
// for supported cast operations and falling back to Spark for anything that
// is not yet supported
Err(CometError::Internal(format!(
"Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
)))
}
};
Ok(spark_cast(cast_result, from_type, to_type))
Ok(spark_cast(cast_result?, from_type, to_type))
}

/// Determines if DataFusion supports the given cast in a way that is
/// compatible with Spark
fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
if from_type == to_type {
return true;
}
match from_type {
DataType::Boolean => matches!(
to_type,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Utf8
),
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
// note that the cast from Int32/Int64 -> Decimal128 here is actually
// not compatible with Spark (no overflow checks) but we have tests that
// rely on this cast working so we have to leave it here for now
matches!(
to_type,
DataType::Boolean
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Utf8
)
}
DataType::Float32 | DataType::Float64 => matches!(
to_type,
DataType::Boolean
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
),
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
to_type,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
),
DataType::Utf8 => matches!(to_type, DataType::Binary),
DataType::Date32 => matches!(to_type, DataType::Utf8),
DataType::Timestamp(_, _) => {
matches!(
to_type,
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
)
}
DataType::Binary => {
// note that this is not completely Spark compatible because
// DataFusion only supports binary data containing valid UTF-8 strings
matches!(to_type, DataType::Utf8)
}
_ => false,
}
}

fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(

0 comments on commit 79431f8

Please sign in to comment.