-
Notifications
You must be signed in to change notification settings - Fork 166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Only delegate to DataFusion cast when we know that it is compatible with Spark #461
Changes from 13 commits
3a1c387
60354bd
35585aa
dc4f99c
f8a6f94
2b21917
c2c2546
7f1951a
c0f913b
34edf00
54043ae
2f22e06
bc18fae
340e000
7d804f5
3fb2258
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -503,41 +503,37 @@ impl Cast { | |
fn cast_array(&self, array: ArrayRef) -> DataFusionResult<ArrayRef> { | ||
let to_type = &self.data_type; | ||
let array = array_with_timezone(array, self.timezone.clone(), Some(to_type)); | ||
let from_type = array.data_type().clone(); | ||
|
||
// unpack dictionary string arrays first | ||
// TODO: we are unpacking a dictionary-encoded array and then performing | ||
// the cast. We could potentially improve performance here by casting the | ||
// dictionary values directly without unpacking the array first, although this | ||
// would add more complexity to the code | ||
let array = match &from_type { | ||
DataType::Dictionary(key_type, value_type) | ||
if key_type.as_ref() == &DataType::Int32 | ||
&& (value_type.as_ref() == &DataType::Utf8 | ||
|| value_type.as_ref() == &DataType::LargeUtf8) => | ||
{ | ||
cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)? | ||
} | ||
_ => array, | ||
}; | ||
let from_type = array.data_type(); | ||
|
||
let cast_result = match (from_type, to_type) { | ||
(DataType::Utf8, DataType::Boolean) => { | ||
Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode)? | ||
Self::spark_cast_utf8_to_boolean::<i32>(&array, self.eval_mode) | ||
} | ||
(DataType::LargeUtf8, DataType::Boolean) => { | ||
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)? | ||
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode) | ||
} | ||
(DataType::Utf8, DataType::Timestamp(_, _)) => { | ||
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)? | ||
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode) | ||
} | ||
(DataType::Utf8, DataType::Date32) => { | ||
Self::cast_string_to_date(&array, to_type, self.eval_mode)? | ||
} | ||
(DataType::Dictionary(key_type, value_type), DataType::Date32) | ||
if key_type.as_ref() == &DataType::Int32 | ||
&& (value_type.as_ref() == &DataType::Utf8 | ||
|| value_type.as_ref() == &DataType::LargeUtf8) => | ||
{ | ||
match value_type.as_ref() { | ||
DataType::Utf8 => { | ||
let unpacked_array = | ||
cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; | ||
Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? | ||
} | ||
DataType::LargeUtf8 => { | ||
let unpacked_array = | ||
cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; | ||
Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? | ||
} | ||
dt => unreachable!( | ||
"{}", | ||
format!("invalid value type {dt} for dictionary-encoded string array") | ||
), | ||
} | ||
Self::cast_string_to_date(&array, to_type, self.eval_mode) | ||
} | ||
(DataType::Int64, DataType::Int32) | ||
| (DataType::Int64, DataType::Int16) | ||
|
@@ -547,61 +543,33 @@ impl Cast { | |
| (DataType::Int16, DataType::Int8) | ||
if self.eval_mode != EvalMode::Try => | ||
{ | ||
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)? | ||
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type) | ||
} | ||
( | ||
DataType::Utf8, | ||
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, | ||
) => Self::cast_string_to_int::<i32>(to_type, &array, self.eval_mode)?, | ||
) => Self::cast_string_to_int::<i32>(to_type, &array, self.eval_mode), | ||
( | ||
DataType::LargeUtf8, | ||
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, | ||
) => Self::cast_string_to_int::<i64>(to_type, &array, self.eval_mode)?, | ||
( | ||
DataType::Dictionary(key_type, value_type), | ||
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, | ||
) if key_type.as_ref() == &DataType::Int32 | ||
&& (value_type.as_ref() == &DataType::Utf8 | ||
|| value_type.as_ref() == &DataType::LargeUtf8) => | ||
{ | ||
// TODO: we are unpacking a dictionary-encoded array and then performing | ||
// the cast. We could potentially improve performance here by casting the | ||
// dictionary values directly without unpacking the array first, although this | ||
// would add more complexity to the code | ||
match value_type.as_ref() { | ||
DataType::Utf8 => { | ||
let unpacked_array = | ||
cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; | ||
Self::cast_string_to_int::<i32>(to_type, &unpacked_array, self.eval_mode)? | ||
} | ||
DataType::LargeUtf8 => { | ||
let unpacked_array = | ||
cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; | ||
Self::cast_string_to_int::<i64>(to_type, &unpacked_array, self.eval_mode)? | ||
} | ||
dt => unreachable!( | ||
"{}", | ||
format!("invalid value type {dt} for dictionary-encoded string array") | ||
), | ||
} | ||
} | ||
) => Self::cast_string_to_int::<i64>(to_type, &array, self.eval_mode), | ||
(DataType::Float64, DataType::Utf8) => { | ||
Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)? | ||
Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode) | ||
} | ||
(DataType::Float64, DataType::LargeUtf8) => { | ||
Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode)? | ||
Self::spark_cast_float64_to_utf8::<i64>(&array, self.eval_mode) | ||
} | ||
(DataType::Float32, DataType::Utf8) => { | ||
Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode)? | ||
Self::spark_cast_float32_to_utf8::<i32>(&array, self.eval_mode) | ||
} | ||
(DataType::Float32, DataType::LargeUtf8) => { | ||
Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)? | ||
Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode) | ||
} | ||
(DataType::Float32, DataType::Decimal128(precision, scale)) => { | ||
Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)? | ||
Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode) | ||
} | ||
(DataType::Float64, DataType::Decimal128(precision, scale)) => { | ||
Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)? | ||
Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode) | ||
} | ||
(DataType::Float32, DataType::Int8) | ||
| (DataType::Float32, DataType::Int16) | ||
|
@@ -622,14 +590,89 @@ impl Cast { | |
self.eval_mode, | ||
from_type, | ||
to_type, | ||
)? | ||
) | ||
} | ||
_ if Self::is_datafusion_spark_compatible(from_type, to_type) => { | ||
// use DataFusion cast only when we know that it is compatible with Spark | ||
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) | ||
} | ||
_ => { | ||
// when we have no Spark-specific casting we delegate to DataFusion | ||
cast_with_options(&array, to_type, &CAST_OPTIONS)? | ||
// we should never reach this code because the Scala code should be checking | ||
// for supported cast operations and falling back to Spark for anything that | ||
// is not yet supported | ||
Err(CometError::Internal(format!( | ||
"Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}" | ||
))) | ||
} | ||
}; | ||
Ok(spark_cast(cast_result, from_type, to_type)) | ||
Ok(spark_cast(cast_result?, from_type, to_type)) | ||
} | ||
|
||
/// Determines if DataFusion supports the given cast in a way that is | ||
/// compatible with Spark | ||
fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { | ||
if from_type == to_type { | ||
return true; | ||
} | ||
match from_type { | ||
DataType::Boolean => matches!( | ||
to_type, | ||
DataType::Int8 | ||
| DataType::Int16 | ||
| DataType::Int32 | ||
| DataType::Int64 | ||
| DataType::Float32 | ||
| DataType::Float64 | ||
| DataType::Utf8 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So right now, there is not There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current code says that DataFusion is compatible with Spark for all int types -> decimal: DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => matches!(
to_type,
DataType::Boolean
...
| DataType::Decimal128(_, _) However, this is not actually correct: DataFusion does not perform overflow checks for int32 -> decimal and int64 -> decimal casts, so it is not compatible with Spark for those casts. I will look at removing those. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Removing that case causes a test failure:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test relies on a cast that we do not yet support and enables |
||
), | ||
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => matches!( | ||
to_type, | ||
DataType::Boolean | ||
| DataType::Int8 | ||
| DataType::Int16 | ||
| DataType::Int32 | ||
| DataType::Int64 | ||
| DataType::Float32 | ||
| DataType::Float64 | ||
| DataType::Decimal128(_, _) | ||
| DataType::Utf8 | ||
), | ||
DataType::Float32 | DataType::Float64 => matches!( | ||
to_type, | ||
DataType::Boolean | ||
| DataType::Int8 | ||
| DataType::Int16 | ||
| DataType::Int32 | ||
| DataType::Int64 | ||
| DataType::Float32 | ||
| DataType::Float64 | ||
Comment on lines
+645
to
+653
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that is correct. |
||
), | ||
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( | ||
viirya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
to_type, | ||
DataType::Int8 | ||
| DataType::Int16 | ||
| DataType::Int32 | ||
| DataType::Int64 | ||
| DataType::Float32 | ||
| DataType::Float64 | ||
| DataType::Decimal128(_, _) | ||
| DataType::Decimal256(_, _) | ||
), | ||
DataType::Utf8 => matches!(to_type, DataType::Binary), | ||
DataType::Date32 => matches!(to_type, DataType::Utf8), | ||
DataType::Timestamp(_, _) => { | ||
matches!( | ||
to_type, | ||
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) | ||
) | ||
} | ||
DataType::Binary => { | ||
// note that this is not completely Spark compatible because | ||
// DataFusion only supports binary data containing valid UTF-8 strings | ||
matches!(to_type, DataType::Utf8) | ||
} | ||
_ => false, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Casting to narrower type like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Casting from DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => matches!(
to_type,
DataType::Boolean
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Utf8
), |
||
} | ||
} | ||
|
||
fn cast_string_to_int<OffsetSize: OffsetSizeTrait>( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We were previously unpacking dictionary-encoded string arrays only for the string-to-int and string-to-date casts. I just moved the unpacking earlier in `cast_array` so that we don't have to handle dictionary-encoded input separately for each cast from string.