feat: Implement Spark-compatible CAST between integer types (#340)
* handled cast for long to short

* handled cast for all overflow cases

* ran make format

* added check for overflow exception for Spark versions below 3.4

* added comments on why we do the overflow check.
  added a check before we fetch the sparkInvalidValue

* use -1 instead of 0; -1 indicates the provided character is not present

* ran mvn spotless:apply

* check for presence of ':' and adjust the asserts accordingly

* reusing existing test functions

* added one more check in assert when ':' is not present

* redo the compare logic as per Andy's suggestions.

---------

Co-authored-by: ganesh.maddula <[email protected]>
ganeshkumar269 and ganesh.maddula authored May 3, 2024
1 parent 2741ae7 commit b39ed88
Showing 3 changed files with 131 additions and 13 deletions.
9 changes: 9 additions & 0 deletions core/src/errors.rs
@@ -72,6 +72,15 @@ pub enum CometError {
to_type: String,
},

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
CastOverFlow {
value: String,
from_type: String,
to_type: String,
},

#[error(transparent)]
Arrow {
#[from]
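
As a quick illustration of what the new variant produces at runtime, here is a minimal, self-contained sketch (not part of the commit) that reproduces the rendered message, assuming the `#[error(...)]` attribute in this file comes from the `thiserror` crate:

use thiserror::Error;

#[derive(Error, Debug)]
enum DemoError {
    // Same format string as the CastOverFlow variant added above.
    #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
        due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
        set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
    CastOverFlow {
        value: String,
        from_type: String,
        to_type: String,
    },
}

fn main() {
    let err = DemoError::CastOverFlow {
        value: "2147483648L".to_string(),
        from_type: "BIGINT".to_string(),
        to_type: "INT".to_string(),
    };
    // Prints the full [CAST_OVERFLOW] message with the placeholders filled in.
    println!("{err}");
}
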
98 changes: 98 additions & 0 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -176,6 +176,62 @@ macro_rules! cast_float_to_string {
}};
}

// Spark-compatible narrowing between integer types: the Legacy arm narrows with a
// bit-truncating `as` cast, while the other arm uses `try_from` and reports an
// out-of-range value as CometError::CastOverFlow.
macro_rules! cast_int_to_int_macro {
(
$array: expr,
$eval_mode:expr,
$from_arrow_primitive_type: ty,
$to_arrow_primitive_type: ty,
$from_data_type: expr,
$to_native_type: ty,
$spark_from_data_type_name: expr,
$spark_to_data_type_name: expr
) => {{
let cast_array = $array
.as_any()
.downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
.unwrap();
let spark_int_literal_suffix = match $from_data_type {
&DataType::Int64 => "L",
&DataType::Int16 => "S",
&DataType::Int8 => "T",
_ => "",
};

let output_array = match $eval_mode {
EvalMode::Legacy => cast_array
.iter()
.map(|value| match value {
Some(value) => {
Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
_ => cast_array
.iter()
.map(|value| match value {
Some(value) => {
let res = <$to_native_type>::try_from(value);
if res.is_err() {
Err(CometError::CastOverFlow {
value: value.to_string() + spark_int_literal_suffix,
from_type: $spark_from_data_type_name.to_string(),
to_type: $spark_to_data_type_name.to_string(),
})
} else {
Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
}
}
_ => Ok(None),
})
.collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
}?;
let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
result
}};
}

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -218,6 +274,16 @@ impl Cast {
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
}
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
| (DataType::Int32, DataType::Int16)
| (DataType::Int32, DataType::Int8)
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
@@ -349,6 +415,38 @@ impl Cast {
cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}

fn spark_cast_int_to_int(
array: &dyn Array,
eval_mode: EvalMode,
from_type: &DataType,
to_type: &DataType,
) -> CometResult<ArrayRef> {
match (from_type, to_type) {
(DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
),
(DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
),
(DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
),
(DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
),
(DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
),
(DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
),
_ => unreachable!(
"{}",
format!("invalid integer type {to_type} in cast from {from_type}")
),
}
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
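
The two arms of `cast_int_to_int_macro` differ only in how each element is narrowed: Legacy mode uses a bit-truncating `as` cast, while the non-Legacy path uses `try_from` and converts the failure into `CastOverFlow`. A standalone sketch of that per-element logic over plain vectors instead of Arrow arrays (the helper name is hypothetical, not in the commit):

fn narrow_i64_to_i16(values: &[Option<i64>]) -> Result<Vec<Option<i16>>, String> {
    values
        .iter()
        .map(|v| match v {
            // try_from fails exactly when the value does not fit in the narrower type
            Some(v) => i16::try_from(*v)
                .map(Some)
                .map_err(|_| format!("value {v}L overflows SMALLINT")),
            None => Ok(None), // nulls pass through unchanged
        })
        .collect()
}

fn main() {
    assert_eq!(narrow_i64_to_i16(&[Some(42), None]), Ok(vec![Some(42), None]));
    // 40000 exceeds i16::MAX (32767), so the ANSI-style path reports an error,
    // whereas the Legacy arm would wrap the bits instead: 40_000i64 as i16 == -25536.
    assert!(narrow_i64_to_i16(&[Some(40_000)]).is_err());
}
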
37 changes: 24 additions & 13 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -166,7 +166,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateShorts(), DataTypes.BooleanType)
}

ignore("cast ShortType to ByteType") {
test("cast ShortType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateShorts(), DataTypes.ByteType)
}
@@ -210,12 +210,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateInts(), DataTypes.BooleanType)
}

ignore("cast IntegerType to ByteType") {
test("cast IntegerType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateInts(), DataTypes.ByteType)
}

ignore("cast IntegerType to ShortType") {
test("cast IntegerType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateInts(), DataTypes.ShortType)
}
@@ -256,17 +256,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateLongs(), DataTypes.BooleanType)
}

ignore("cast LongType to ByteType") {
test("cast LongType to ByteType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.ByteType)
}

ignore("cast LongType to ShortType") {
test("cast LongType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.ShortType)
}

ignore("cast LongType to IntegerType") {
test("cast LongType to IntegerType") {
// https://github.com/apache/datafusion-comet/issues/311
castTest(generateLongs(), DataTypes.IntegerType)
}
@@ -921,15 +921,26 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val cometMessage = cometException.getCause.getMessage
.replace("Execution error: ", "")
if (CometSparkSessionExtensions.isSpark34Plus) {
// for Spark 3.4 we expect to reproduce the error message exactly
assert(cometMessage == sparkMessage)
} else if (CometSparkSessionExtensions.isSpark33Plus) {
// for Spark 3.3 we just need to strip the prefix from the Comet message
// before comparing
val cometMessageModified = cometMessage
.replace("[CAST_INVALID_INPUT] ", "")
.replace("[CAST_OVERFLOW] ", "")
assert(cometMessageModified == sparkMessage)
} else {
- // Spark 3.2 and 3.3 have a different error message format so we can't do a direct
- // comparison between Spark and Comet.
- // Spark message is in format `invalid input syntax for type TYPE: VALUE`
- // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
- // We just check that the comet message contains the same invalid value as the Spark message
- val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
- assert(cometMessage.contains(sparkInvalidValue))
+ // for Spark 3.2 we just make sure we are seeing a similar type of error
+ if (sparkMessage.contains("causes overflow")) {
+ assert(cometMessage.contains("due to an overflow"))
+ } else {
+ // assume that this is an invalid input message in the form:
+ // `invalid input syntax for type numeric: -9223372036854775809`
+ // we just check that the Comet message contains the same literal value
+ val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
+ assert(cometMessage.contains(sparkInvalidValue))
+ }
}
}

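
For the pre-3.3 branch, the assertion keys on the literal value that follows ": " in Spark's message, with a fallback when no colon is present. A small sketch of that extraction, using hypothetical message strings shaped like the ones in the comments above (Rust, for consistency with the rest of this page):

fn main() {
    // Message shapes taken from the test comments; the exact strings are made up.
    let spark_message = "invalid input syntax for type numeric: -9223372036854775809";
    let comet_message =
        "The value '-9223372036854775809' of the type STRING cannot be cast to BIGINT";

    // Mirrors `sparkMessage.substring(sparkMessage.indexOf(':') + 2)` in the Scala test.
    let invalid_value = match spark_message.find(':') {
        Some(i) => &spark_message[i + 2..],
        None => spark_message, // no ':' in the message; compare on the whole string
    };
    assert_eq!(invalid_value, "-9223372036854775809");
    assert!(comet_message.contains(invalid_value));
}
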
