diff --git a/core/src/errors.rs b/core/src/errors.rs
index a06c613ad..04a1629d5 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -72,6 +72,13 @@ pub enum CometError {
         to_type: String,
     },
 
+    #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
+    NumericValueOutOfRange {
+        value: String,
+        precision: u8,
+        scale: i8,
+    },
+
     #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
         due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
         set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
@@ -208,6 +215,10 @@ impl jni::errors::ToException for CometError {
                 class: "org/apache/spark/SparkException".to_string(),
                 msg: self.to_string(),
             },
+            CometError::NumericValueOutOfRange { .. } => Exception {
+                class: "org/apache/spark/SparkException".to_string(),
+                msg: self.to_string(),
+            },
             CometError::NumberIntFormat { source: s } => Exception {
                 class: "java/lang/NumberFormatException".to_string(),
                 msg: s.to_string(),
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 2ad9c40dd..35ab23a76 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -25,7 +25,10 @@ use std::{
 use crate::errors::{CometError, CometResult};
 use arrow::{
     compute::{cast_with_options, CastOptions},
-    datatypes::TimestampMicrosecondType,
+    datatypes::{
+        ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type,
+        TimestampMicrosecondType,
+    },
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
@@ -39,7 +42,7 @@ use chrono::{TimeZone, Timelike};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
 use datafusion_physical_expr::PhysicalExpr;
-use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
+use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive};
 use regex::Regex;
 
 use crate::execution::datafusion::expressions::utils::{
@@ -566,6 +569,12 @@ impl Cast {
             (DataType::Float32, DataType::LargeUtf8) => {
                 Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
             }
+            (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+                Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)?
+            }
+            (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+                Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)?
+            }
             (DataType::Float32, DataType::Int8)
             | (DataType::Float32, DataType::Int16)
             | (DataType::Float32, DataType::Int32)
@@ -650,6 +659,84 @@ impl Cast {
         Ok(cast_array)
     }
 
+    fn cast_float64_to_decimal128(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        Self::cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
+    }
+
+    fn cast_float32_to_decimal128(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        Self::cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
+    }
+
+    fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef>
+    where
+        <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
+    {
+        let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+        let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());
+
+        let mul = 10_f64.powi(scale as i32);
+
+        for i in 0..input.len() {
+            if input.is_null(i) {
+                cast_array.append_null();
+            } else {
+                let input_value = input.value(i).as_();
+                let value = (input_value * mul).round().to_i128();
+
+                match value {
+                    Some(v) => {
+                        if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
+                            if eval_mode == EvalMode::Ansi {
+                                return Err(CometError::NumericValueOutOfRange {
+                                    value: input_value.to_string(),
+                                    precision,
+                                    scale,
+                                });
+                            } else {
+                                cast_array.append_null();
+                            }
+                        } else {
+                            cast_array.append_value(v);
+                        }
+                    }
+                    None => {
+                        if eval_mode == EvalMode::Ansi {
+                            return Err(CometError::NumericValueOutOfRange {
+                                value: input_value.to_string(),
+                                precision,
+                                scale,
+                            });
+                        } else {
+                            cast_array.append_null();
+                        }
+                    }
+                }
+            }
+        }
+
+        let res = Arc::new(
+            cast_array
+                .with_precision_and_scale(precision, scale)?
+                .finish(),
+        ) as ArrayRef;
+        Ok(res)
+    }
+
     fn spark_cast_float64_to_utf8<OffsetSize>(
         from: &dyn Array,
         _eval_mode: EvalMode,
diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md
index 2fd4b09b6..a4ed9289f 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -93,6 +93,7 @@ The following cast operations are generally compatible with Spark except for the
 | float | integer | |
 | float | long | |
 | float | double | |
+| float | decimal | |
 | float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
 | double | boolean | |
 | double | byte | |
@@ -100,6 +101,7 @@ The following cast operations are generally compatible with Spark except for the
 | double | integer | |
 | double | long | |
 | double | float | |
+| double | decimal | |
 | double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
 | decimal | byte | |
 | decimal | short | |
@@ -127,8 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a
 |-|-|-|
 | integer | decimal | No overflow check |
 | long | decimal | No overflow check |
-| float | decimal | No overflow check |
-| double | decimal | No overflow check |
 | string | timestamp | Not all valid formats are supported |
 | binary | string | Only works for binary data representing valid UTF-8 strings |
 
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 5c225e3b6..795bdb428 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -229,7 +229,7 @@ object CometCast {
       case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
           DataTypes.IntegerType | DataTypes.LongType =>
         Compatible()
-      case _: DecimalType => Incompatible(Some("No overflow check"))
+      case _: DecimalType => Compatible()
       case _ => Unsupported
     }
 
@@ -237,7 +237,7 @@ object CometCast {
       case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType |
           DataTypes.IntegerType | DataTypes.LongType =>
         Compatible()
-      case _: DecimalType => Incompatible(Some("No overflow check"))
+      case _: DecimalType => Compatible()
       case _ => Unsupported
     }
 
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 827f4238d..1881c561c 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -340,8 +340,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateFloats(), DataTypes.DoubleType)
   }
 
-  ignore("cast FloatType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+  test("cast FloatType to DecimalType(10,2)") {
     castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
   }
 
@@ -394,8 +393,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateDoubles(), DataTypes.FloatType)
   }
 
-  ignore("cast DoubleType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+  test("cast DoubleType to DecimalType(10,2)") {
     castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
   }
 
@@ -1003,11 +1001,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           val cometMessageModified = cometMessage
             .replace("[CAST_INVALID_INPUT] ", "")
             .replace("[CAST_OVERFLOW] ", "")
-          assert(cometMessageModified == sparkMessage)
+            .replace("[NUMERIC_VALUE_OUT_OF_RANGE] ", "")
+
+          if (sparkMessage.contains("cannot be represented as")) {
+            assert(cometMessage.contains("cannot be represented as"))
+          } else {
+            assert(cometMessageModified == sparkMessage)
+          }
         } else {
           // for Spark 3.2 we just make sure we are seeing a similar type of error
           if (sparkMessage.contains("causes overflow")) {
             assert(cometMessage.contains("due to an overflow"))
+          } else if (sparkMessage.contains("cannot be represented as")) {
+            assert(cometMessage.contains("cannot be represented as"))
           } else {
             // assume that this is an invalid input message in the form:
             // `invalid input syntax for type numeric: -9223372036854775809`
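
For reviewers, here is a minimal standalone sketch of the per-value rule that the new cast_floating_point_to_decimal128 implements: scale the float by 10^scale, round, then validate the unscaled integer against the target precision; values that cannot be represented become NULL in legacy mode and raise NUMERIC_VALUE_OUT_OF_RANGE in ANSI mode. The helper name float_to_unscaled and the main driver are illustrative assumptions, not Comet APIs; only Decimal128Type::validate_decimal_precision and the num traits come from the patch itself.

use arrow::datatypes::{Decimal128Type, DecimalType};
use num::ToPrimitive;

/// Mirrors the cast's per-value logic: Ok(Some(unscaled)) on success,
/// Ok(None) where legacy mode yields NULL, Err(..) where ANSI mode fails.
fn float_to_unscaled(v: f64, precision: u8, scale: i8, ansi: bool) -> Result<Option<i128>, String> {
    // Scale into the decimal's unscaled integer form, rounding half away from zero.
    let rounded = (v * 10_f64.powi(scale as i32)).round();
    match rounded.to_i128() {
        // Fits in i128 and within `precision` digits: keep the value.
        Some(u) if Decimal128Type::validate_decimal_precision(u, precision).is_ok() => Ok(Some(u)),
        // Out of range (to_i128 also returns None for NaN and infinity).
        _ if ansi => Err(format!(
            "[NUMERIC_VALUE_OUT_OF_RANGE] {v} cannot be represented as Decimal({precision}, {scale})"
        )),
        _ => Ok(None),
    }
}

fn main() {
    // 123.456 at Decimal(10, 2): 123.456 * 100 rounds to unscaled 12346, i.e. 123.46.
    assert_eq!(float_to_unscaled(123.456, 10, 2, false), Ok(Some(12346)));
    // 1e9 at Decimal(10, 2) needs 12 unscaled digits, exceeding precision 10:
    assert_eq!(float_to_unscaled(1e9, 10, 2, false), Ok(None)); // legacy: NULL
    assert!(float_to_unscaled(1e9, 10, 2, true).is_err()); // ANSI: error
}

This is also the behavior the re-enabled suite tests exercise: any extreme values among the generated floats and doubles that overflow Decimal(10, 2) should now surface as NULLs or as a NUMERIC_VALUE_OUT_OF_RANGE error matching Spark, rather than silently wrong decimals.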