From b39ed8823a6bf5ce7b624918c37cec40afcf7b36 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E0=B0=97=E0=B0=A3=E0=B1=87=E0=B0=B7=E0=B1=8D?=
Date: Sat, 4 May 2024 03:37:47 +0530
Subject: [PATCH] feat: Implement Spark-compatible CAST between integer types
 (#340)

* handled cast for long to short

* handled cast for all overflow cases

* ran make format

* added check for the overflow exception on Spark versions below 3.4

* added comments on why we do the overflow check; added a check before we
  fetch the sparkInvalidValue

* use -1 instead of 0; -1 indicates the provided character is not present

* ran mvn spotless:apply

* check for presence of ':' and assert accordingly

* reusing existing test functions

* added one more check in the assert for when ':' is not present

* redid the compare logic per Andy's suggestions

---------

Co-authored-by: ganesh.maddula
---
 core/src/errors.rs                            |  9 ++
 .../execution/datafusion/expressions/cast.rs  | 98 +++++++++++++++++++
 .../org/apache/comet/CometCastSuite.scala     | 37 ++++---
 3 files changed, 131 insertions(+), 13 deletions(-)

diff --git a/core/src/errors.rs b/core/src/errors.rs
index f02bd1969..a06c613ad 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -72,6 +72,15 @@ pub enum CometError {
         to_type: String,
     },
 
+    #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
+        due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
+        set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    CastOverFlow {
+        value: String,
+        from_type: String,
+        to_type: String,
+    },
+
     #[error(transparent)]
     Arrow {
         #[from]
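Note on the new error variant: the `#[error(...)]` attribute above is thiserror
derive syntax, so the CastOverFlow fields are interpolated into Spark's ANSI
overflow message at Display time. A minimal standalone sketch of how the
variant formats (not Comet's actual errors.rs; it assumes only the thiserror
crate as a dependency, and SketchError is a hypothetical stand-in):

    // Hypothetical stand-in for CometError, reduced to the new variant.
    use thiserror::Error;

    #[derive(Error, Debug)]
    enum SketchError {
        #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
            due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
            set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
        CastOverFlow {
            value: String,
            from_type: String,
            to_type: String,
        },
    }

    fn main() {
        let err = SketchError::CastOverFlow {
            value: "9223372036854775807L".to_string(),
            from_type: "BIGINT".to_string(),
            to_type: "INT".to_string(),
        };
        // Prints the Spark-style ANSI error text with the fields filled in.
        println!("{err}");
    }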
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 45859c5fb..a6e3adaca 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -176,6 +176,62 @@ macro_rules! cast_float_to_string {
     }};
 }
 
+macro_rules! cast_int_to_int_macro {
+    (
+        $array: expr,
+        $eval_mode:expr,
+        $from_arrow_primitive_type: ty,
+        $to_arrow_primitive_type: ty,
+        $from_data_type: expr,
+        $to_native_type: ty,
+        $spark_from_data_type_name: expr,
+        $spark_to_data_type_name: expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
+            .unwrap();
+        let spark_int_literal_suffix = match $from_data_type {
+            &DataType::Int64 => "L",
+            &DataType::Int16 => "S",
+            &DataType::Int8 => "T",
+            _ => "",
+        };
+
+        let output_array = match $eval_mode {
+            EvalMode::Legacy => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let res = <$to_native_type>::try_from(value);
+                        if res.is_err() {
+                            Err(CometError::CastOverFlow {
+                                value: value.to_string() + spark_int_literal_suffix,
+                                from_type: $spark_from_data_type_name.to_string(),
+                                to_type: $spark_to_data_type_name.to_string(),
+                            })
+                        } else {
+                            Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
+                        }
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
+        }?;
+        let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
+        result
+    }};
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -218,6 +274,16 @@ impl Cast {
             (DataType::Utf8, DataType::Timestamp(_, _)) => {
                 Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
             }
+            (DataType::Int64, DataType::Int32)
+            | (DataType::Int64, DataType::Int16)
+            | (DataType::Int64, DataType::Int8)
+            | (DataType::Int32, DataType::Int16)
+            | (DataType::Int32, DataType::Int8)
+            | (DataType::Int16, DataType::Int8)
+                if self.eval_mode != EvalMode::Try =>
+            {
+                Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
+            }
             (
                 DataType::Utf8,
                 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
@@ -349,6 +415,38 @@ impl Cast {
         cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
     }
 
+    fn spark_cast_int_to_int(
+        array: &dyn Array,
+        eval_mode: EvalMode,
+        from_type: &DataType,
+        to_type: &DataType,
+    ) -> CometResult<ArrayRef> {
+        match (from_type, to_type) {
+            (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
+            ),
+            (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"
+            ),
+            (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"
+            ),
+            (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!(
+                array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"
+            ),
+            (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"
+            ),
+            (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!(
+                array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"
+            ),
+            _ => unreachable!(
+                "{}",
+                format!("invalid integer type {to_type} in cast from {from_type}")
+            ),
+        }
+    }
+
     fn spark_cast_utf8_to_boolean<OffsetSize>(
         from: &dyn Array,
         eval_mode: EvalMode,
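Note on the cast semantics above: in Legacy mode the macro reproduces Spark's
non-ANSI behaviour by truncating with `as` (two's-complement wrapping), while
the other eval modes go through `try_from` and raise CastOverFlow, suffixing
the offending value with Spark's integer literal marker ("L" for BIGINT, "S"
for SMALLINT, "T" for TINYINT). A hedged sketch of the per-value logic the
macro expands to, specialized to BIGINT -> SMALLINT; the function name is
illustrative and the error message is abbreviated:

    // Sketch only: Comet operates on whole Arrow arrays, not single values.
    fn cast_i64_to_i16(value: i64, ansi_enabled: bool) -> Result<i16, String> {
        if ansi_enabled {
            // ANSI path: reject out-of-range values, tagging the value with
            // the BIGINT literal suffix "L" as the macro does.
            i16::try_from(value).map_err(|_| {
                format!("[CAST_OVERFLOW] The value {value}L of the type \"BIGINT\" cannot be cast to \"SMALLINT\" ...")
            })
        } else {
            // Legacy path: two's-complement truncation, matching the macro's
            // `value as $to_native_type` and Spark's non-ANSI CAST.
            Ok(value as i16)
        }
    }

    fn main() {
        assert_eq!(cast_i64_to_i16(32767, false), Ok(32767));
        assert_eq!(cast_i64_to_i16(32768, false), Ok(-32768)); // wraps silently
        assert!(cast_i64_to_i16(32768, true).is_err()); // overflow error
    }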
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 54b136791..483301e02 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -166,7 +166,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateShorts(), DataTypes.BooleanType)
   }
 
-  ignore("cast ShortType to ByteType") {
+  test("cast ShortType to ByteType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateShorts(), DataTypes.ByteType)
   }
@@ -210,12 +210,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.BooleanType)
   }
 
-  ignore("cast IntegerType to ByteType") {
+  test("cast IntegerType to ByteType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateInts(), DataTypes.ByteType)
   }
 
-  ignore("cast IntegerType to ShortType") {
+  test("cast IntegerType to ShortType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateInts(), DataTypes.ShortType)
   }
@@ -256,17 +256,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.BooleanType)
   }
 
-  ignore("cast LongType to ByteType") {
+  test("cast LongType to ByteType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateLongs(), DataTypes.ByteType)
   }
 
-  ignore("cast LongType to ShortType") {
+  test("cast LongType to ShortType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateLongs(), DataTypes.ShortType)
   }
 
-  ignore("cast LongType to IntegerType") {
+  test("cast LongType to IntegerType") {
     // https://github.com/apache/datafusion-comet/issues/311
     castTest(generateLongs(), DataTypes.IntegerType)
   }
@@ -921,15 +921,26 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           val cometMessage = cometException.getCause.getMessage
             .replace("Execution error: ", "")
           if (CometSparkSessionExtensions.isSpark34Plus) {
+            // for Spark 3.4 we expect to reproduce the error message exactly
             assert(cometMessage == sparkMessage)
+          } else if (CometSparkSessionExtensions.isSpark33Plus) {
+            // for Spark 3.3 we just need to strip the prefix from the Comet message
+            // before comparing
+            val cometMessageModified = cometMessage
+              .replace("[CAST_INVALID_INPUT] ", "")
+              .replace("[CAST_OVERFLOW] ", "")
+            assert(cometMessageModified == sparkMessage)
           } else {
-            // Spark 3.2 and 3.3 have a different error message format so we can't do a direct
-            // comparison between Spark and Comet.
-            // Spark message is in format `invalid input syntax for type TYPE: VALUE`
-            // Comet message is in format `The value 'VALUE' of the type FROM_TYPE cannot be cast to TO_TYPE`
-            // We just check that the comet message contains the same invalid value as the Spark message
-            val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
-            assert(cometMessage.contains(sparkInvalidValue))
+            // for Spark 3.2 we just make sure we are seeing a similar type of error
+            if (sparkMessage.contains("causes overflow")) {
+              assert(cometMessage.contains("due to an overflow"))
+            } else {
+              // assume that this is an invalid input message in the form:
+              // `invalid input syntax for type numeric: -9223372036854775809`
+              // we just check that the Comet message contains the same literal value
+              val sparkInvalidValue = sparkMessage.substring(sparkMessage.indexOf(':') + 2)
+              assert(cometMessage.contains(sparkInvalidValue))
+            }
           }
         }
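Note on the version-gated assertions: Spark 3.4+ must reproduce the error
message exactly, Spark 3.3 matches after stripping the [CAST_INVALID_INPUT] /
[CAST_OVERFLOW] error-class prefix, and Spark 3.2 only checks that both sides
report the same kind of failure. A sketch of that 3.2 fallback comparison,
transposed to Rust purely for illustration (messages_agree is a hypothetical
name, not part of the test suite):

    // Match overflow errors by keyword; otherwise extract the literal after
    // ": " from the Spark message and require the Comet message to contain it.
    fn messages_agree(spark_msg: &str, comet_msg: &str) -> bool {
        if spark_msg.contains("causes overflow") {
            comet_msg.contains("due to an overflow")
        } else if let Some(idx) = spark_msg.find(':') {
            // e.g. "invalid input syntax for type numeric: -9223372036854775809"
            spark_msg
                .get(idx + 2..)
                .map_or(false, |value| comet_msg.contains(value))
        } else {
            // no ':' means there is no literal value to compare on
            false
        }
    }

    fn main() {
        let spark = "invalid input syntax for type numeric: -9223372036854775809";
        let comet =
            "The value '-9223372036854775809' of the type STRING cannot be cast to BIGINT";
        assert!(messages_agree(spark, comet));
    }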