diff --git a/native/Cargo.lock b/native/Cargo.lock
index 9bf8247d0..605af92ee 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -914,6 +914,7 @@ dependencies = [
  "datafusion-common",
  "datafusion-functions",
  "datafusion-physical-expr",
+ "thiserror",
 ]
 
 [[package]]
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 53afed85a..944b6e28a 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
@@ -48,6 +48,7 @@ datafusion-physical-expr-common = { git = "https://github.com/apache/datafusion.
 datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false }
 datafusion-comet-spark-expr = { path = "spark-expr", version = "0.1.0" }
 datafusion-comet-utils = { path = "utils", version = "0.1.0" }
+thiserror = "1"
 
 [profile.release]
 debug = true
diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index be135d4e9..50c1ce2b3 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -48,7 +48,7 @@ tokio = { version = "1", features = ["rt-multi-thread"] }
 async-trait = "0.1"
 log = "0.4"
 log4rs = "1.2.0"
-thiserror = "1"
+thiserror = { workspace = true }
 serde = { version = "1", features = ["derive"] }
 lazy_static = "1.4.0"
 prost = "0.12.1"
diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs
index 8c02a72d1..ff89e77d2 100644
--- a/native/core/src/errors.rs
+++ b/native/core/src/errors.rs
@@ -38,6 +38,7 @@ use std::{
 use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort};
 
 use crate::execution::operators::ExecutionError;
+use datafusion_comet_spark_expr::SparkError;
 use jni::objects::{GlobalRef, JThrowable};
 use jni::JNIEnv;
 use lazy_static::lazy_static;
@@ -62,36 +63,10 @@ pub enum CometError {
     #[error("Comet Internal Error: {0}")]
     Internal(String),
 
-    // Note that this message format is based on Spark 3.4 and is more detailed than the message
-    // returned by Spark 3.3
-    #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
-    because it is malformed. Correct the value as per the syntax, or change its target type. \
-    Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
-    set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
-    CastInvalidValue {
-        value: String,
-        from_type: String,
-        to_type: String,
-    },
-
-    #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
-    NumericValueOutOfRange {
-        value: String,
-        precision: u8,
-        scale: i8,
-    },
-
-    #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
-    due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
-    set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
-    CastOverFlow {
-        value: String,
-        from_type: String,
-        to_type: String,
-    },
-
-    #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
-    ArithmeticOverflow { from_type: String },
+    /// CometError::Spark is typically used in native code to emulate the same errors
+    /// that Spark would return
+    #[error(transparent)]
+    Spark(SparkError),
 
     #[error(transparent)]
     Arrow {
@@ -239,11 +214,7 @@ impl jni::errors::ToException for CometError {
                 class: "java/lang/NullPointerException".to_string(),
                 msg: self.to_string(),
             },
-            CometError::CastInvalidValue { .. } => Exception {
-                class: "org/apache/spark/SparkException".to_string(),
-                msg: self.to_string(),
-            },
-            CometError::NumericValueOutOfRange { .. } => Exception {
+            CometError::Spark { .. } => Exception {
                 class: "org/apache/spark/SparkException".to_string(),
                 msg: self.to_string(),
             },
diff --git a/native/core/src/execution/datafusion/expressions/cast.rs b/native/core/src/execution/datafusion/expressions/cast.rs
index 154ff28b5..0b513e776 100644
--- a/native/core/src/execution/datafusion/expressions/cast.rs
+++ b/native/core/src/execution/datafusion/expressions/cast.rs
@@ -40,16 +40,14 @@ use arrow_array::{
 use arrow_schema::{DataType, Schema};
 use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
 use datafusion::logical_expr::ColumnarValue;
+use datafusion_comet_spark_expr::{SparkError, SparkResult};
 use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
 use datafusion_physical_expr::PhysicalExpr;
 use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive};
 use regex::Regex;
 
-use crate::{
-    errors::{CometError, CometResult},
-    execution::datafusion::expressions::utils::{
-        array_with_timezone, down_cast_any_ref, spark_cast,
-    },
+use crate::execution::datafusion::expressions::utils::{
+    array_with_timezone, down_cast_any_ref, spark_cast,
 };
 
 use super::EvalMode;
@@ -87,7 +85,7 @@ macro_rules! cast_utf8_to_int {
                 cast_array.append_null()
             }
         }
-        let result: CometResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
+        let result: SparkResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
         result
     }};
 }
@@ -116,7 +114,7 @@ macro_rules! cast_float_to_string {
         fn cast<OffsetSize>(
             from: &dyn Array,
             _eval_mode: EvalMode,
-        ) -> CometResult<ArrayRef>
+        ) -> SparkResult<ArrayRef>
         where
             OffsetSize: OffsetSizeTrait, {
             let array = from.as_any().downcast_ref::<$output_type>().unwrap();
@@ -169,7 +167,7 @@
                     Some(value) => Ok(Some(value.to_string())),
                     _ => Ok(None),
                 })
-                .collect::<Result<GenericStringArray<OffsetSize>, CometError>>()?;
+                .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;
 
             Ok(Arc::new(output_array))
         }
@@ -205,7 +203,7 @@ macro_rules! cast_int_to_int_macro {
             .iter()
             .map(|value| match value {
                 Some(value) => {
-                    Ok::<Option<$to_native_type>, CometError>(Some(value as $to_native_type))
+                    Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
                 }
                 _ => Ok(None),
             })
@@ -222,14 +220,14 @@ macro_rules! cast_int_to_int_macro {
                             $spark_to_data_type_name,
                         ))
                     } else {
-                        Ok::<Option<$to_native_type>, CometError>(Some(res.unwrap()))
+                        Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
                     }
                 }
                 _ => Ok(None),
             })
             .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
         }?;
-        let result: CometResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
+        let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
         result
     }};
 }
@@ -286,7 +284,7 @@ macro_rules! cast_float_to_int16_down {
                 .map(|value| match value {
                     Some(value) => {
                         let i32_value = value as i32;
-                        Ok::<Option<$rust_dest_type>, CometError>(Some(
+                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                             i32_value as $rust_dest_type,
                         ))
                     }
@@ -339,7 +337,7 @@ macro_rules! cast_float_to_int32_up {
            .iter()
            .map(|value| match value {
                Some(value) => {
-                    Ok::<Option<$rust_dest_type>, CometError>(Some(value as $rust_dest_type))
+                    Ok::<Option<$rust_dest_type>, SparkError>(Some(value as $rust_dest_type))
                }
                None => Ok(None),
            })
@@ -402,7 +400,7 @@ macro_rules! cast_decimal_to_int16_down {
                     Some(value) => {
                         let divisor = 10_i128.pow($scale as u32);
                         let i32_value = (value / divisor) as i32;
-                        Ok::<Option<$rust_dest_type>, CometError>(Some(
+                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                             i32_value as $rust_dest_type,
                         ))
                     }
@@ -456,7 +454,7 @@ macro_rules! cast_decimal_to_int32_up {
                     Some(value) => {
                         let divisor = 10_i128.pow($scale as u32);
                         let truncated = value / divisor;
-                        Ok::<Option<$rust_dest_type>, CometError>(Some(
+                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
                             truncated as $rust_dest_type,
                         ))
                     }
@@ -596,7 +594,7 @@ impl Cast {
                 // we should never reach this code because the Scala code should be checking
                 // for supported cast operations and falling back to Spark for anything that
                 // is not yet supported
-                Err(CometError::Internal(format!(
+                Err(SparkError::Internal(format!(
                     "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}"
                 )))
             }
@@ -680,7 +678,7 @@ impl Cast {
         to_type: &DataType,
         array: &ArrayRef,
         eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef> {
+    ) -> SparkResult<ArrayRef> {
         let string_array = array
             .as_any()
             .downcast_ref::<GenericStringArray<OffsetSize>>()
@@ -711,7 +709,7 @@ impl Cast {
         array: &ArrayRef,
         to_type: &DataType,
         eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef> {
+    ) -> SparkResult<ArrayRef> {
         let string_array = array
             .as_any()
             .downcast_ref::<GenericStringArray<i32>>()
@@ -743,7 +741,7 @@ impl Cast {
         array: &ArrayRef,
         to_type: &DataType,
         eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef> {
+    ) -> SparkResult<ArrayRef> {
         let string_array = array
             .as_any()
             .downcast_ref::<GenericStringArray<i32>>()
@@ -768,7 +766,7 @@ impl Cast {
         precision: u8,
         scale: i8,
         eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef> {
+    ) -> SparkResult<ArrayRef> {
         Self::cast_floating_point_to_decimal128::<Float64Type>(array, precision, scale, eval_mode)
     }
 
@@ -777,7 +775,7 @@ impl Cast {
         precision: u8,
         scale: i8,
         eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef> {
+    ) -> SparkResult<ArrayRef> {
         Self::cast_floating_point_to_decimal128::<Float32Type>(array, precision, scale, eval_mode)
     }
 
@@ -786,7 +784,7 @@ impl Cast {
         precision: u8,
         scale: i8,
         eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef>
+    ) -> SparkResult<ArrayRef>
     where
         <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>,
     {
@@ -806,7 +804,7 @@
                 Some(v) => {
                     if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
                         if eval_mode == EvalMode::Ansi {
-                            return Err(CometError::NumericValueOutOfRange {
+                            return Err(SparkError::NumericValueOutOfRange {
                                 value: input_value.to_string(),
                                 precision,
                                 scale,
@@ -819,7 +817,7 @@
                 }
                 None => {
                     if eval_mode == EvalMode::Ansi {
-                        return Err(CometError::NumericValueOutOfRange {
+                        return Err(SparkError::NumericValueOutOfRange {
                             value: input_value.to_string(),
                             precision,
                             scale,
@@ -843,7 +841,7 @@ impl Cast {
     fn spark_cast_float64_to_utf8<OffsetSize>(
         from: &dyn Array,
         _eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef>
+    ) -> SparkResult<ArrayRef>
     where
         OffsetSize: OffsetSizeTrait,
     {
@@ -853,7 +851,7 @@ impl Cast {
     fn spark_cast_float32_to_utf8<OffsetSize>(
         from: &dyn Array,
         _eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef>
+    ) -> SparkResult<ArrayRef>
     where
         OffsetSize: OffsetSizeTrait,
     {
@@ -865,7 +863,7 @@ impl Cast {
         eval_mode: EvalMode,
         from_type: &DataType,
         to_type: &DataType,
-    ) -> CometResult<ArrayRef> {
+    ) -> SparkResult<ArrayRef> {
         match (from_type, to_type) {
             (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!(
                 array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT"
@@ -895,7 +893,7 @@ impl Cast {
     fn spark_cast_utf8_to_boolean<OffsetSize>(
         from: &dyn Array,
         eval_mode: EvalMode,
-    ) -> CometResult<ArrayRef>
+    ) -> SparkResult<ArrayRef>
     where
         OffsetSize: OffsetSizeTrait,
     {
@@ -910,7 +908,7 @@
             Some(value) => match value.to_ascii_lowercase().trim() {
                 "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)),
                 "f" | "false" | "n" | "no" | "0" => Ok(Some(false)),
-                _ if eval_mode == EvalMode::Ansi => Err(CometError::CastInvalidValue {
+                _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue {
                     value: value.to_string(),
                     from_type: "STRING".to_string(),
                     to_type: "BOOLEAN".to_string(),
@@ -929,7 +927,7 @@ impl Cast {
         eval_mode: EvalMode,
         from_type: &DataType,
         to_type: &DataType,
-    ) -> CometResult<ArrayRef> {
+    ) -> SparkResult<ArrayRef> {
         match (from_type, to_type) {
             (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!(
                 array,
@@ -1066,7 +1064,7 @@ impl Cast {
 }
 
 /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
-fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> {
+fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
     Ok(cast_string_to_int_with_range_check(
         str,
         eval_mode,
@@ -1078,7 +1076,7 @@ fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>>
 }
 
 /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
-fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>> {
+fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
     Ok(cast_string_to_int_with_range_check(
         str,
         eval_mode,
@@ -1090,12 +1088,12 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>
 }
 
 /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
-fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
     do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
 }
 
 /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper)
-fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
+fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
     do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
 }
 
@@ -1105,7 +1103,7 @@ fn cast_string_to_int_with_range_check(
     type_name: &str,
     min: i32,
     max: i32,
-) -> CometResult<Option<i32>> {
+) -> SparkResult<Option<i32>> {
     match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
         None => Ok(None),
         Some(v) if v >= min && v <= max => Ok(Some(v)),
@@ -1124,7 +1122,7 @@ fn do_cast_string_to_int<
     eval_mode: EvalMode,
     type_name: &str,
     min_value: T,
-) -> CometResult<Option<T>> {
+) -> SparkResult<Option<T>> {
     let trimmed_str = str.trim();
     if trimmed_str.is_empty() {
         return none_or_err(eval_mode, type_name, str);
@@ -1208,9 +1206,9 @@ fn do_cast_string_to_int<
     Ok(Some(result))
 }
 
-/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode
+/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
 #[inline]
-fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<Option<T>> {
+fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
     match eval_mode {
         EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
         _ => Ok(None),
@@ -1218,8 +1216,8 @@ fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResul
 }
 
 #[inline]
-fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
-    CometError::CastInvalidValue {
+fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError {
+    SparkError::CastInvalidValue {
         value: value.to_string(),
         from_type: from_type.to_string(),
         to_type: to_type.to_string(),
@@ -1227,8 +1225,8 @@ fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
 }
 
 #[inline]
-fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> CometError {
-    CometError::CastOverFlow {
+fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError {
+    SparkError::CastOverFlow {
         value: value.to_string(),
         from_type: from_type.to_string(),
         to_type: to_type.to_string(),
@@ -1316,7 +1314,7 @@ impl PhysicalExpr for Cast {
     }
 }
 
-fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
+fn timestamp_parser(value: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
     let value = value.trim();
     if value.is_empty() {
         return Ok(None);
@@ -1325,7 +1323,7 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>
     let patterns = &[
         (
             Regex::new(r"^\d{4}$").unwrap(),
-            parse_str_to_year_timestamp as fn(&str) -> CometResult<Option<i64>>,
+            parse_str_to_year_timestamp as fn(&str) -> SparkResult<Option<i64>>,
         ),
         (
             Regex::new(r"^\d{4}-\d{2}$").unwrap(),
@@ -1369,7 +1367,7 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>
 
     if timestamp.is_none() {
         return if eval_mode == EvalMode::Ansi {
-            Err(CometError::CastInvalidValue {
+            Err(SparkError::CastInvalidValue {
                 value: value.to_string(),
                 from_type: "STRING".to_string(),
                 to_type: "TIMESTAMP".to_string(),
@@ -1381,20 +1379,20 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>
 
     match timestamp {
         Some(ts) => Ok(Some(ts)),
-        None => Err(CometError::Internal(
+        None => Err(SparkError::Internal(
             "Failed to parse timestamp".to_string(),
         )),
     }
 }
 
-fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> CometResult<Option<i64>> {
+fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> SparkResult<Option<i64>> {
     let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0);
 
     // Check if datetime is not None
     let utc_datetime = match datetime.single() {
         Some(dt) => dt.with_timezone(&chrono::Utc),
         None => {
-            return Err(CometError::Internal(
+            return Err(SparkError::Internal(
                 "Failed to parse timestamp".to_string(),
             ));
         }
@@ -1411,7 +1409,7 @@ fn parse_hms_timestamp(
     minute: u32,
     second: u32,
     microsecond: u32,
-) -> CometResult<Option<i64>> {
+) -> SparkResult<Option<i64>> {
     let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second);
 
     // Check if datetime is not None
@@ -1420,7 +1418,7 @@ fn parse_hms_timestamp(
             .with_timezone(&chrono::Utc)
             .with_nanosecond(microsecond * 1000),
         None => {
-            return Err(CometError::Internal(
+            return Err(SparkError::Internal(
                 "Failed to parse timestamp".to_string(),
             ));
         }
@@ -1429,7 +1427,7 @@ fn parse_hms_timestamp(
     let result = match utc_datetime {
         Some(dt) => dt.timestamp_micros(),
         None => {
-            return Err(CometError::Internal(
+            return Err(SparkError::Internal(
                 "Failed to parse timestamp".to_string(),
             ));
         }
@@ -1438,7 +1436,7 @@ fn parse_hms_timestamp(
     Ok(Some(result))
 }
 
-fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option<i64>> {
+fn get_timestamp_values(value: &str, timestamp_type: &str) -> SparkResult<Option<i64>> {
     let values: Vec<_> = value
         .split(|c| c == 'T' || c == '-' || c == ':' || c == '.')
         .collect();
@@ -1458,7 +1456,7 @@ fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option
         "minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0),
         "second" => parse_hms_timestamp(year, month, day, hour, minute, second, 0),
         "microsecond" => parse_hms_timestamp(year, month, day, hour, minute, second, microsecond),
-        _ => Err(CometError::CastInvalidValue {
+        _ => Err(SparkError::CastInvalidValue {
             value: value.to_string(),
             from_type: "STRING".to_string(),
             to_type: "TIMESTAMP".to_string(),
@@ -1466,35 +1464,35 @@ fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option
     }
 }
 
-fn parse_str_to_year_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_year_timestamp(value: &str) -> SparkResult<Option<i64>> {
     get_timestamp_values(value, "year")
 }
 
-fn parse_str_to_month_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_month_timestamp(value: &str) -> SparkResult<Option<i64>> {
     get_timestamp_values(value, "month")
 }
 
-fn parse_str_to_day_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_day_timestamp(value: &str) -> SparkResult<Option<i64>> {
     get_timestamp_values(value, "day")
 }
 
-fn parse_str_to_hour_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_hour_timestamp(value: &str) -> SparkResult<Option<i64>> {
     get_timestamp_values(value, "hour")
 }
 
-fn parse_str_to_minute_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_minute_timestamp(value: &str) -> SparkResult<Option<i64>> {
     get_timestamp_values(value, "minute")
 }
 
-fn parse_str_to_second_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_second_timestamp(value: &str) -> SparkResult<Option<i64>> {
     get_timestamp_values(value, "second")
 }
 
-fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_microsecond_timestamp(value: &str) -> SparkResult<Option<i64>> {
     get_timestamp_values(value, "microsecond")
 }
 
-fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
+fn parse_str_to_time_only_timestamp(value: &str) -> SparkResult<Option<i64>> {
     let values: Vec<&str> = value.split('T').collect();
     let time_values: Vec<u32> = values[1]
         .split(':')
@@ -1514,7 +1512,7 @@ fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
 }
 
 //a string to date parser - port of spark's SparkDateTimeUtils#stringToDate.
-fn date_parser(date_str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
     // local functions
     fn get_trimmed_start(bytes: &[u8]) -> usize {
         let mut start = 0;
@@ -1545,9 +1543,9 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>>
             || (segment != 0 && digits > 0 && digits <= 2)
     }
 
-    fn return_result(date_str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+    fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
         if eval_mode == EvalMode::Ansi {
-            Err(CometError::CastInvalidValue {
+            Err(SparkError::CastInvalidValue {
                 value: date_str.to_string(),
                 from_type: "STRING".to_string(),
                 to_type: "DATE".to_string(),
diff --git a/native/core/src/execution/datafusion/expressions/mod.rs b/native/core/src/execution/datafusion/expressions/mod.rs
index d573c2377..c61266cea 100644
--- a/native/core/src/execution/datafusion/expressions/mod.rs
+++ b/native/core/src/execution/datafusion/expressions/mod.rs
@@ -43,10 +43,10 @@ mod utils;
 pub mod variance;
 pub mod xxhash64;
 
-pub use datafusion_comet_spark_expr::EvalMode;
+pub use datafusion_comet_spark_expr::{EvalMode, SparkError};
 
 fn arithmetic_overflow_error(from_type: &str) -> CometError {
-    CometError::ArithmeticOverflow {
+    CometError::Spark(SparkError::ArithmeticOverflow {
         from_type: from_type.to_string(),
-    }
+    })
 }
diff --git a/native/core/src/execution/datafusion/expressions/negative.rs b/native/core/src/execution/datafusion/expressions/negative.rs
index cd0e9bccf..9e82812be 100644
--- a/native/core/src/execution/datafusion/expressions/negative.rs
+++ b/native/core/src/execution/datafusion/expressions/negative.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use super::arithmetic_overflow_error;
 use crate::errors::CometError;
 use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType};
 use arrow_array::RecordBatch;
@@ -24,6 +25,7 @@ use datafusion::{
     logical_expr::{interval_arithmetic::Interval, ColumnarValue},
     physical_expr::PhysicalExpr,
 };
+use datafusion_comet_spark_expr::SparkError;
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::sort_properties::ExprProperties;
 use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
@@ -33,8 +35,6 @@ use std::{
     sync::Arc,
 };
 
-use super::arithmetic_overflow_error;
-
 pub fn create_negate_expr(
     expr: Arc<dyn PhysicalExpr>,
     fail_on_error: bool,
@@ -234,7 +234,7 @@ impl PhysicalExpr for NegativeExpr {
             || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
             || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
         {
-            return Err(CometError::ArithmeticOverflow {
+            return Err(SparkError::ArithmeticOverflow {
                 from_type: "long".to_string(),
             }
             .into());
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index 8bf76dff6..4a9b94087 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -34,6 +34,7 @@ datafusion-common = { workspace = true }
 datafusion-functions = { workspace = true }
 datafusion-physical-expr = { workspace = true }
 datafusion-comet-utils = { workspace = true }
+thiserror = { workspace = true }
 
 [lib]
 name = "datafusion_comet_spark_expr"
diff --git a/native/spark-expr/src/abs.rs b/native/spark-expr/src/abs.rs
index 198a96e57..fa25a7775 100644
--- a/native/spark-expr/src/abs.rs
+++ b/native/spark-expr/src/abs.rs
@@ -77,9 +77,10 @@ impl ScalarUDFImpl for Abs {
                 if self.eval_mode == EvalMode::Legacy {
                     Ok(args[0].clone())
                 } else {
-                    Err(DataFusionError::External(Box::new(
-                        SparkError::ArithmeticOverflow(self.data_type_name.clone()),
-                    )))
+                    Err(SparkError::ArithmeticOverflow {
+                        from_type: self.data_type_name.clone(),
+                    }
+                    .into())
                 }
             }
             other => other,
diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs
new file mode 100644
index 000000000..728a35a9d
--- /dev/null
+++ b/native/spark-expr/src/error.rs
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::ArrowError;
+use datafusion_common::DataFusionError;
+
+#[derive(thiserror::Error, Debug)]
+pub enum SparkError {
+    // Note that this message format is based on Spark 3.4 and is more detailed than the message
+    // returned by Spark 3.3
+    #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
+    because it is malformed. Correct the value as per the syntax, or change its target type. \
+    Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
+    set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    CastInvalidValue {
+        value: String,
+        from_type: String,
+        to_type: String,
+    },
+
+    #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
+    NumericValueOutOfRange {
+        value: String,
+        precision: u8,
+        scale: i8,
+    },
+
+    #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
+    due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
+    set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    CastOverFlow {
+        value: String,
+        from_type: String,
+        to_type: String,
+    },
+
+    #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    ArithmeticOverflow { from_type: String },
+
+    #[error("ArrowError: {0}.")]
+    Arrow(ArrowError),
+
+    #[error("InternalError: {0}.")]
+    Internal(String),
+}
+
+pub type SparkResult<T> = Result<T, SparkError>;
+
+impl From<ArrowError> for SparkError {
+    fn from(value: ArrowError) -> Self {
+        SparkError::Arrow(value)
+    }
+}
+
+impl From<SparkError> for DataFusionError {
+    fn from(value: SparkError) -> Self {
+        DataFusionError::External(Box::new(value))
+    }
+}
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c36e8855e..57da56f9a 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -15,13 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::error::Error;
-use std::fmt::{Display, Formatter};
-
 mod abs;
+mod error;
 mod if_expr;
 
 pub use abs::Abs;
+pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
 
 /// Spark supports three evaluation modes when evaluating expressions, which affect
@@ -42,19 +41,3 @@ pub enum EvalMode {
     /// failing the entire query.
     Try,
 }
-
-#[derive(Debug)]
-pub enum SparkError {
-    ArithmeticOverflow(String),
-}
-
-impl Error for SparkError {}
-
-impl Display for SparkError {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Self::ArithmeticOverflow(data_type) =>
-                write!(f, "[ARITHMETIC_OVERFLOW] {} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.", data_type)
-        }
-    }
-}
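
For readers who want the error-handling pattern from this patch in one place, here is a small, self-contained sketch. It is illustrative only: `EngineError` and `negate_checked` are hypothetical stand-ins rather than Comet APIs, and it assumes `thiserror` and `datafusion-common` as dependencies. It mirrors the three pieces the diff wires together: a `thiserror`-derived `SparkError` whose `#[error(...)]` strings carry the Spark-compatible messages, an outer error that wraps it with `#[error(transparent)]` the way `CometError::Spark` does, and a `From<SparkError> for DataFusionError` impl so expression code can propagate Spark errors inside DataFusion `Result`s.

```rust
use datafusion_common::DataFusionError;

// Trimmed-down stand-in for the SparkError enum added in
// native/spark-expr/src/error.rs; only one variant is shown here.
#[derive(thiserror::Error, Debug)]
pub enum SparkError {
    #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
    ArithmeticOverflow { from_type: String },
}

// Hypothetical engine-side error, mirroring CometError::Spark in errors.rs:
// #[error(transparent)] forwards Display and source() to the wrapped
// SparkError, so the Spark-style message is surfaced unchanged.
#[derive(thiserror::Error, Debug)]
pub enum EngineError {
    #[error(transparent)]
    Spark(#[from] SparkError),
}

// Same shape as the conversion added in native/spark-expr/src/error.rs.
impl From<SparkError> for DataFusionError {
    fn from(value: SparkError) -> Self {
        DataFusionError::External(Box::new(value))
    }
}

// Hypothetical ANSI-mode negation helper showing how a call site can
// surface a Spark error through DataFusion's error type with `.into()`.
fn negate_checked(value: i64) -> Result<i64, DataFusionError> {
    value.checked_neg().ok_or_else(|| {
        SparkError::ArithmeticOverflow {
            from_type: "long".to_string(),
        }
        .into()
    })
}

fn main() {
    // Prints the Spark-compatible [ARITHMETIC_OVERFLOW] message verbatim.
    let wrapped = EngineError::from(SparkError::ArithmeticOverflow {
        from_type: "long".to_string(),
    });
    println!("{wrapped}");

    // Prints the same message, wrapped by DataFusion as an external error.
    println!("{}", negate_checked(i64::MIN).unwrap_err());
}
```

Keeping the message formatting on `SparkError` and wrapping it transparently is what lets the JNI layer in errors.rs continue mapping `CometError::Spark { .. }` to `org.apache.spark.SparkException` while the Spark 3.4-style error strings now live in the shared `spark-expr` crate.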