Commit 56f57f4

feat: Implement Spark-compatible CAST from floating-point/double to decimal (#384)

* support NumericValueOutOfRange error

* adding ansi checks and code refactor

* fmt fixes

* Remove redundant comment

* bug fix

* adding cast for float32 as well

* fix test case for spark 3.2 and 3.3

* return error only in ansi mode
vaibhawvipul authored May 9, 2024
1 parent 14494d3 commit 56f57f4
Showing 5 changed files with 114 additions and 11 deletions.
11 changes: 11 additions & 0 deletions core/src/errors.rs
@@ -72,6 +72,13 @@ pub enum CometError {
         to_type: String,
     },
 
+    #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
+    NumericValueOutOfRange {
+        value: String,
+        precision: u8,
+        scale: i8,
+    },
+
     #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
         due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
         set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
@@ -208,6 +215,10 @@ impl jni::errors::ToException for CometError {
                 class: "org/apache/spark/SparkException".to_string(),
                 msg: self.to_string(),
             },
+            CometError::NumericValueOutOfRange { .. } => Exception {
+                class: "org/apache/spark/SparkException".to_string(),
+                msg: self.to_string(),
+            },
            CometError::NumberIntFormat { source: s } => Exception {
                 class: "java/lang/NumberFormatException".to_string(),
                 msg: s.to_string(),
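As context for the new variant: a minimal sketch of the message it produces, assuming the thiserror-style #[error(...)] derive that the attributes above imply (the test name is hypothetical, not part of this commit):

#[test]
fn numeric_value_out_of_range_message() {
    // Display output comes from the #[error(...)] format string; the JNI
    // mapping above then wraps the same string in an
    // org.apache.spark.SparkException on the JVM side.
    let err = CometError::NumericValueOutOfRange {
        value: "12345.6789".to_string(),
        precision: 5,
        scale: 2,
    };
    assert!(err
        .to_string()
        .starts_with("[NUMERIC_VALUE_OUT_OF_RANGE] 12345.6789 cannot be represented as Decimal(5, 2)"));
}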
90 changes: 88 additions & 2 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -25,7 +25,10 @@ use std::{
 use crate::errors::{CometError, CometResult};
 use arrow::{
     compute::{cast_with_options, CastOptions},
-    datatypes::TimestampMicrosecondType,
+    datatypes::{
+        ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type,
+        TimestampMicrosecondType,
+    },
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
@@ -39,7 +42,7 @@ use chrono::{TimeZone, Timelike};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
 use datafusion_physical_expr::PhysicalExpr;
-use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
+use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive};
 use regex::Regex;
 
 use crate::execution::datafusion::expressions::utils::{
@@ -566,6 +569,12 @@ impl Cast {
             (DataType::Float32, DataType::LargeUtf8) => {
                 Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
             }
+            (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+                Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)?
+            }
+            (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+                Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)?
+            }
             (DataType::Float32, DataType::Int8)
             | (DataType::Float32, DataType::Int16)
             | (DataType::Float32, DataType::Int32)
@@ -650,6 +659,83 @@ impl Cast {
         Ok(cast_array)
     }
 
+    fn cast_float64_to_decimal128(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        Self::cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
+    }
+
+    fn cast_float32_to_decimal128(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        Self::cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
+    }
+
+    fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
+        array: &dyn Array,
+        precision: u8,
+        scale: i8,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef>
+    where
+        <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
+    {
+        let input = array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+        let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());
+
+        let mul = 10_f64.powi(scale as i32);
+
+        for i in 0..input.len() {
+            if input.is_null(i) {
+                cast_array.append_null();
+            } else {
+                let input_value = input.value(i).as_();
+                let value = (input_value * mul).round().to_i128();
+
+                match value {
+                    Some(v) => {
+                        if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
+                            if eval_mode == EvalMode::Ansi {
+                                return Err(CometError::NumericValueOutOfRange {
+                                    value: input_value.to_string(),
+                                    precision,
+                                    scale,
+                                });
+                            }
+                            cast_array.append_null();
+                        } else {
+                            cast_array.append_value(v);
+                        }
+                    }
+                    None => {
+                        if eval_mode == EvalMode::Ansi {
+                            return Err(CometError::NumericValueOutOfRange {
+                                value: input_value.to_string(),
+                                precision,
+                                scale,
+                            });
+                        } else {
+                            cast_array.append_null();
+                        }
+                    }
+                }
+            }
+        }
+
+        let res = Arc::new(
+            cast_array
+                .with_precision_and_scale(precision, scale)?
+                .finish(),
+        ) as ArrayRef;
+        Ok(res)
+    }
+
     fn spark_cast_float64_to_utf8<OffsetSize>(
         from: &dyn Array,
         _eval_mode: EvalMode,
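The core of the new cast is the rescaling step: multiply the float by 10^scale, round, and keep the result as the unscaled i128 that Decimal128 stores, then validate it against the target precision. A standalone sketch of that arithmetic, assuming only the num crate's ToPrimitive trait (the same one imported in the diff above); the helper name is hypothetical:

use num::ToPrimitive;

// Hypothetical helper mirroring the rescaling in cast_floating_point_to_decimal128:
// a Decimal128 value is an unscaled i128, so 123.456 at scale 2 becomes 12346 (123.46).
fn to_unscaled_i128(v: f64, scale: i8) -> Option<i128> {
    let mul = 10_f64.powi(scale as i32);
    // to_i128 yields None for NaN, infinities, and values outside i128's range;
    // the cast above maps None to NULL (legacy) or NUMERIC_VALUE_OUT_OF_RANGE (ANSI).
    (v * mul).round().to_i128()
}

fn main() {
    assert_eq!(to_unscaled_i128(123.456, 2), Some(12346));
    assert_eq!(to_unscaled_i128(-1.005, 2), Some(-100)); // -1.005 is stored as ~-1.00499..., so it rounds to -100, not -101
    assert_eq!(to_unscaled_i128(f64::NAN, 2), None);
}

f64::round rounds halfway cases away from zero, which lines up with the HALF_UP rounding Spark applies when changing decimal precision; exact halfway cases are rare anyway once a value has passed through binary floating point, as the -1.005 example shows.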
4 changes: 2 additions & 2 deletions docs/source/user-guide/compatibility.md
@@ -93,13 +93,15 @@ The following cast operations are generally compatible with Spark except for the differences noted here.
 | float | integer | |
 | float | long | |
 | float | double | |
+| float | decimal | |
 | float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
 | double | boolean | |
 | double | byte | |
 | double | short | |
 | double | integer | |
 | double | long | |
 | double | float | |
+| double | decimal | |
 | double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 |
 | decimal | byte | |
 | decimal | short | |
Expand Down Expand Up @@ -127,8 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a
|-|-|-|
| integer | decimal | No overflow check |
| long | decimal | No overflow check |
| float | decimal | No overflow check |
| double | decimal | No overflow check |
| string | timestamp | Not all valid formats are supported |
| binary | string | Only works for binary data representing valid UTF-8 strings |

4 changes: 2 additions & 2 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -229,15 +229,15 @@ object CometCast {
     case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |
         DataTypes.IntegerType | DataTypes.LongType =>
       Compatible()
-    case _: DecimalType => Incompatible(Some("No overflow check"))
+    case _: DecimalType => Compatible()
     case _ => Unsupported
   }
 
   private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
     case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType |
         DataTypes.IntegerType | DataTypes.LongType =>
       Compatible()
-    case _: DecimalType => Incompatible(Some("No overflow check"))
+    case _: DecimalType => Compatible()
     case _ => Unsupported
   }
 
16 changes: 11 additions & 5 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -340,8 +340,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateFloats(), DataTypes.DoubleType)
   }
 
-  ignore("cast FloatType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+  test("cast FloatType to DecimalType(10,2)") {
     castTest(generateFloats(), DataTypes.createDecimalType(10, 2))
   }
 
@@ -394,8 +393,7 @@
     castTest(generateDoubles(), DataTypes.FloatType)
   }
 
-  ignore("cast DoubleType to DecimalType(10,2)") {
-    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
+  test("cast DoubleType to DecimalType(10,2)") {
     castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
   }
 
@@ -1003,11 +1001,19 @@
       val cometMessageModified = cometMessage
         .replace("[CAST_INVALID_INPUT] ", "")
         .replace("[CAST_OVERFLOW] ", "")
-      assert(cometMessageModified == sparkMessage)
+        .replace("[NUMERIC_VALUE_OUT_OF_RANGE] ", "")
+
+      if (sparkMessage.contains("cannot be represented as")) {
+        assert(cometMessage.contains("cannot be represented as"))
+      } else {
+        assert(cometMessageModified == sparkMessage)
+      }
     } else {
       // for Spark 3.2 we just make sure we are seeing a similar type of error
       if (sparkMessage.contains("causes overflow")) {
         assert(cometMessage.contains("due to an overflow"))
+      } else if (sparkMessage.contains("cannot be represented as")) {
+        assert(cometMessage.contains("cannot be represented as"))
       } else {
         // assume that this is an invalid input message in the form:
         // `invalid input syntax for type numeric: -9223372036854775809`
