diff --git a/core/src/errors.rs b/core/src/errors.rs
index 5a410e9f1..a06c613ad 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -71,8 +71,7 @@ pub enum CometError {
         from_type: String,
         to_type: String,
     },
-    // Note that this message format is based on Spark 3.4 and is more detailed than the message
-    // returned by Spark 3.2 or 3.3
+
     #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
         due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
         set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 4c9a50ad8..f7cbb5ee2 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -28,7 +28,8 @@ use arrow::{
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
-use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, Int16Array, Int64Array, OffsetSizeTrait};
+use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray};
+use arrow_array::types::{Int16Type, Int32Type, Int64Type, Int8Type};
 use arrow_schema::{DataType, Schema};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
@@ -64,6 +65,62 @@ pub struct Cast {
     pub timezone: String,
 }
 
+macro_rules! cast_int_to_int_macro {
+    (
+        $array: expr,
+        $eval_mode: expr,
+        $from_arrow_primitive_type: ty,
+        $to_arrow_primitive_type: ty,
+        $from_data_type: expr,
+        $to_native_type: ty,
+        $spark_from_data_type_name: expr,
+        $spark_to_data_type_name: expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
+            .unwrap();
+        let spark_int_literal_suffix = match $from_data_type {
+            &DataType::Int64 => "L",
+            &DataType::Int16 => "S",
+            &DataType::Int8 => "T",
+            _ => "",
+        };
+
+        let output_array = match $eval_mode {
+            EvalMode::Legacy =>
+                cast_array.iter()
+                    .map(|value| match value {
+                        Some(value) => Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type)),
+                        _ => Ok(None),
+                    })
+                    .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
+            _ => {
+                cast_array.iter()
+                    .map(|value| match value {
+                        Some(value) => {
+                            let res = <$to_native_type>::try_from(value);
+                            if res.is_err() {
+                                Err(CometError::CastOverFlow {
+                                    value: value.to_string() + spark_int_literal_suffix,
+                                    from_type: $spark_from_data_type_name.to_string(),
+                                    to_type: $spark_to_data_type_name.to_string(),
+                                })
+                            } else {
+                                Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
+                            }
+                        }
+                        _ => Ok(None),
+                    })
+                    .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>()
+            }
+        }?;
+        let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
+        result
+    }};
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -103,56 +160,46 @@ impl Cast {
             (DataType::LargeUtf8, DataType::Boolean) => {
                 Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
             }
-            (DataType::Int64, DataType::Int16) if self.eval_mode != EvalMode::Try => {
-                // (DataType::Int64, DataType::Int16) => {
-                Self::spark_cast_int64_to_int16(&array, self.eval_mode)?
+            (DataType::Int64, DataType::Int32)
+            | (DataType::Int64, DataType::Int16)
+            | (DataType::Int64, DataType::Int8)
+            | (DataType::Int32, DataType::Int16)
+            | (DataType::Int32, DataType::Int8)
+            | (DataType::Int16, DataType::Int8)
+                if self.eval_mode != EvalMode::Try => {
+                Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
             }
             _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
         };
         let result = spark_cast(cast_result, from_type, to_type);
         Ok(result)
     }
 
-    fn spark_cast_int64_to_int16(
-        from: &dyn Array,
+
+    fn spark_cast_int_to_int(
+        array: &dyn Array,
         eval_mode: EvalMode,
+        from_type: &DataType,
+        to_type: &DataType,
     ) -> CometResult<ArrayRef> {
-        let array = from
-            .as_any()
-            .downcast_ref::<Int64Array>()
-            .unwrap();
-
-        let output_array = match eval_mode {
-            EvalMode::Legacy => {
-                array.iter()
-                    .map(|value| match value {
-                        Some(value) => Ok::<Option<i16>, CometError>(Some(value as i16)),
-                        _ => Ok(None)
-                    })
-                    .collect::<Result<Int16Array, _>>()?
-            },
-            _ => {
-                array.iter()
-                    .map(|value| match value {
-                        Some(value) => {
-                            let res = i16::try_from(value);
-                            if res.is_err() {
-                                Err(CometError::CastOverFlow {
-                                    value: value.to_string() + "L",
-                                    from_type: "BIGINT".to_string(),
-                                    to_type: "SMALLINT".to_string(),
-                                })
-                            } else {
-                                Ok::<Option<i16>, CometError>(Some(i16::try_from(value).unwrap()))
-                            }
-                        },
-                        _ => Ok(None)
-                    })
-                    .collect::<Result<Int16Array, _>>()?
-            }
-        };
-        Ok(Arc::new(output_array))
+        match (from_type, to_type) {
+            (DataType::Int64, DataType::Int32) =>
+                cast_int_to_int_macro!(array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"),
+            (DataType::Int64, DataType::Int16) =>
+                cast_int_to_int_macro!(array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT"),
+            (DataType::Int64, DataType::Int8) =>
+                cast_int_to_int_macro!(array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT"),
+            (DataType::Int32, DataType::Int16) =>
+                cast_int_to_int_macro!(array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT"),
+            (DataType::Int32, DataType::Int8) =>
+                cast_int_to_int_macro!(array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT"),
+            (DataType::Int16, DataType::Int8) =>
+                cast_int_to_int_macro!(array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT"),
+            _ => unreachable!(
+                "{}",
+                format!("invalid integer type {to_type} in cast from {from_type}")
+            ),
+        }
     }
 
     fn spark_cast_utf8_to_boolean<OffsetSize>(
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 2f42640a5..54b08ae0e 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -43,10 +43,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   private val datePattern = "0123456789/" + whitespaceChars
   private val timestampPattern = "0123456789/:T" + whitespaceChars
 
-//  ignore("cast long to short") {
-//    castTest(generateLongs, DataTypes.ShortType)
-//  }
-//
   ignore("cast float to bool") {
     castTest(generateFloats, DataTypes.BooleanType)
   }
@@ -106,26 +102,29 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(values.toDF("a"), DataTypes.DoubleType)
   }
 
-  // spotless:off
-  test("cast short to int"){
-
+  test("cast short to byte") {
+    castTest(generateShorts, DataTypes.ByteType)
   }
 
-  test("cast short to long"){
+  test("cast int to byte") {
+    castTest(generateInts, DataTypes.ByteType)
   }
 
-  test("cast int to short"){
+  test("cast int to short") {
+    castTest(generateInts, DataTypes.ShortType)
   }
 
-  test("cast int to long"){
+  test("cast long to byte") {
+    castTest(generateLongs, DataTypes.ByteType)
   }
 
-  test("cast long to short"){
+
+  test("cast long to short") {
     castTest(generateLongs, DataTypes.ShortType)
   }
 
-  test("cast long to int"){
+  test("cast long to int") {
+    castTest(generateLongs, DataTypes.IntegerType)
   }
-  // spotless:on
 
   private def generateFloats(): DataFrame = {
     val r = new Random(0)
@@ -137,6 +136,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     (Range(0, dataSize).map(_ => r.nextLong()) ++ Seq(Long.MaxValue, Long.MinValue)).toDF("a")
   }
 
+  private def generateInts(): DataFrame = {
+    val r = new Random(0)
+    (Range(0, dataSize).map(_ => r.nextInt()) ++ Seq(Int.MaxValue, Int.MinValue)).toDF("a")
+  }
+
+  private def generateShorts(): DataFrame = {
+    val r = new Random(0)
+    (Range(0, dataSize).map(_ => r.nextInt(Short.MaxValue).toShort) ++ Seq(
+      Short.MaxValue,
+      Short.MinValue)).toDF("a")
+  }
+
   private def generateString(r: Random, chars: String, maxLen: Int): String = {
     val len = r.nextInt(maxLen)
     Range(0, len).map(_ => chars.charAt(r.nextInt(chars.length))).mkString
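
For context, a minimal standalone sketch (not part of the patch) of the two behaviours that cast_int_to_int_macro encodes, shown for the Int64 -> Int16 case; the helper name cast_i64_to_i16 and the plain String error are hypothetical stand-ins for the macro expansion and CometError::CastOverFlow.

    // Legacy mode truncates like `value as i16`; non-legacy (ANSI) mode rejects overflow via try_from.
    fn cast_i64_to_i16(value: i64, ansi: bool) -> Result<i16, String> {
        if ansi {
            i16::try_from(value).map_err(|_| {
                format!(
                    "[CAST_OVERFLOW] The value {}L of the type \"BIGINT\" cannot be cast to \"SMALLINT\" due to an overflow.",
                    value
                )
            })
        } else {
            // Legacy: two's-complement truncation, matching Spark's non-ANSI cast
            Ok(value as i16)
        }
    }

    fn main() {
        assert_eq!(cast_i64_to_i16(32_767, true), Ok(32_767)); // fits, no overflow
        assert_eq!(cast_i64_to_i16(32_768, false), Ok(-32_768)); // legacy mode wraps
        assert!(cast_i64_to_i16(32_768, true).is_err()); // ANSI mode raises CAST_OVERFLOW
    }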