From a5bbbf0490daf6075c1f94579de8f6887bdfdce5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 12 Jul 2024 05:52:16 -0600 Subject: [PATCH] chore: Refactoring of CometError/SparkError (#655) --- native/Cargo.lock | 1 + native/Cargo.toml | 1 + native/core/Cargo.toml | 2 +- native/core/src/errors.rs | 41 +----- .../execution/datafusion/expressions/cast.rs | 126 +++++++++--------- .../execution/datafusion/expressions/mod.rs | 6 +- .../datafusion/expressions/negative.rs | 6 +- native/spark-expr/Cargo.toml | 1 + native/spark-expr/src/abs.rs | 7 +- native/spark-expr/src/error.rs | 73 ++++++++++ native/spark-expr/src/lib.rs | 21 +-- 11 files changed, 157 insertions(+), 128 deletions(-) create mode 100644 native/spark-expr/src/error.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index 9bf8247d0c..605af92ee9 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -914,6 +914,7 @@ dependencies = [ "datafusion-common", "datafusion-functions", "datafusion-physical-expr", + "thiserror", ] [[package]] diff --git a/native/Cargo.toml b/native/Cargo.toml index 53afed85a4..944b6e28a4 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -48,6 +48,7 @@ datafusion-physical-expr-common = { git = "https://github.com/apache/datafusion. datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false } datafusion-comet-spark-expr = { path = "spark-expr", version = "0.1.0" } datafusion-comet-utils = { path = "utils", version = "0.1.0" } +thiserror = "1" [profile.release] debug = true diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index be135d4e9a..50c1ce2b36 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -48,7 +48,7 @@ tokio = { version = "1", features = ["rt-multi-thread"] } async-trait = "0.1" log = "0.4" log4rs = "1.2.0" -thiserror = "1" +thiserror = { workspace = true } serde = { version = "1", features = ["derive"] } lazy_static = "1.4.0" prost = "0.12.1" diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs index 8c02a72d19..ff89e77d21 100644 --- a/native/core/src/errors.rs +++ b/native/core/src/errors.rs @@ -38,6 +38,7 @@ use std::{ use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort}; use crate::execution::operators::ExecutionError; +use datafusion_comet_spark_expr::SparkError; use jni::objects::{GlobalRef, JThrowable}; use jni::JNIEnv; use lazy_static::lazy_static; @@ -62,36 +63,10 @@ pub enum CometError { #[error("Comet Internal Error: {0}")] Internal(String), - // Note that this message format is based on Spark 3.4 and is more detailed than the message - // returned by Spark 3.3 - #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - because it is malformed. Correct the value as per the syntax, or change its target type. \ - Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastInvalidValue { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] - NumericValueOutOfRange { - value: String, - precision: u8, - scale: i8, - }, - - #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastOverFlow { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - ArithmeticOverflow { from_type: String }, + /// CometError::Spark is typically used in native code to emulate the same errors + /// that Spark would return + #[error(transparent)] + Spark(SparkError), #[error(transparent)] Arrow { @@ -239,11 +214,7 @@ impl jni::errors::ToException for CometError { class: "java/lang/NullPointerException".to_string(), msg: self.to_string(), }, - CometError::CastInvalidValue { .. } => Exception { - class: "org/apache/spark/SparkException".to_string(), - msg: self.to_string(), - }, - CometError::NumericValueOutOfRange { .. } => Exception { + CometError::Spark { .. } => Exception { class: "org/apache/spark/SparkException".to_string(), msg: self.to_string(), }, diff --git a/native/core/src/execution/datafusion/expressions/cast.rs b/native/core/src/execution/datafusion/expressions/cast.rs index 154ff28b5a..0b513e7763 100644 --- a/native/core/src/execution/datafusion/expressions/cast.rs +++ b/native/core/src/execution/datafusion/expressions/cast.rs @@ -40,16 +40,14 @@ use arrow_array::{ use arrow_schema::{DataType, Schema}; use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; use datafusion::logical_expr::ColumnarValue; +use datafusion_comet_spark_expr::{SparkError, SparkResult}; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive}; use regex::Regex; -use crate::{ - errors::{CometError, CometResult}, - execution::datafusion::expressions::utils::{ - array_with_timezone, down_cast_any_ref, spark_cast, - }, +use crate::execution::datafusion::expressions::utils::{ + array_with_timezone, down_cast_any_ref, spark_cast, }; use super::EvalMode; @@ -87,7 +85,7 @@ macro_rules! cast_utf8_to_int { cast_array.append_null() } } - let result: CometResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); + let result: SparkResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); result }}; } @@ -116,7 +114,7 @@ macro_rules! cast_float_to_string { fn cast( from: &dyn Array, _eval_mode: EvalMode, - ) -> CometResult + ) -> SparkResult where OffsetSize: OffsetSizeTrait, { let array = from.as_any().downcast_ref::<$output_type>().unwrap(); @@ -169,7 +167,7 @@ macro_rules! cast_float_to_string { Some(value) => Ok(Some(value.to_string())), _ => Ok(None), }) - .collect::, CometError>>()?; + .collect::, SparkError>>()?; Ok(Arc::new(output_array)) } @@ -205,7 +203,7 @@ macro_rules! cast_int_to_int_macro { .iter() .map(|value| match value { Some(value) => { - Ok::, CometError>(Some(value as $to_native_type)) + Ok::, SparkError>(Some(value as $to_native_type)) } _ => Ok(None), }) @@ -222,14 +220,14 @@ macro_rules! cast_int_to_int_macro { $spark_to_data_type_name, )) } else { - Ok::, CometError>(Some(res.unwrap())) + Ok::, SparkError>(Some(res.unwrap())) } } _ => Ok(None), }) .collect::, _>>(), }?; - let result: CometResult = Ok(Arc::new(output_array) as ArrayRef); + let result: SparkResult = Ok(Arc::new(output_array) as ArrayRef); result }}; } @@ -286,7 +284,7 @@ macro_rules! cast_float_to_int16_down { .map(|value| match value { Some(value) => { let i32_value = value as i32; - Ok::, CometError>(Some( + Ok::, SparkError>(Some( i32_value as $rust_dest_type, )) } @@ -339,7 +337,7 @@ macro_rules! cast_float_to_int32_up { .iter() .map(|value| match value { Some(value) => { - Ok::, CometError>(Some(value as $rust_dest_type)) + Ok::, SparkError>(Some(value as $rust_dest_type)) } None => Ok(None), }) @@ -402,7 +400,7 @@ macro_rules! cast_decimal_to_int16_down { Some(value) => { let divisor = 10_i128.pow($scale as u32); let i32_value = (value / divisor) as i32; - Ok::, CometError>(Some( + Ok::, SparkError>(Some( i32_value as $rust_dest_type, )) } @@ -456,7 +454,7 @@ macro_rules! cast_decimal_to_int32_up { Some(value) => { let divisor = 10_i128.pow($scale as u32); let truncated = value / divisor; - Ok::, CometError>(Some( + Ok::, SparkError>(Some( truncated as $rust_dest_type, )) } @@ -596,7 +594,7 @@ impl Cast { // we should never reach this code because the Scala code should be checking // for supported cast operations and falling back to Spark for anything that // is not yet supported - Err(CometError::Internal(format!( + Err(SparkError::Internal(format!( "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}" ))) } @@ -680,7 +678,7 @@ impl Cast { to_type: &DataType, array: &ArrayRef, eval_mode: EvalMode, - ) -> CometResult { + ) -> SparkResult { let string_array = array .as_any() .downcast_ref::>() @@ -711,7 +709,7 @@ impl Cast { array: &ArrayRef, to_type: &DataType, eval_mode: EvalMode, - ) -> CometResult { + ) -> SparkResult { let string_array = array .as_any() .downcast_ref::>() @@ -743,7 +741,7 @@ impl Cast { array: &ArrayRef, to_type: &DataType, eval_mode: EvalMode, - ) -> CometResult { + ) -> SparkResult { let string_array = array .as_any() .downcast_ref::>() @@ -768,7 +766,7 @@ impl Cast { precision: u8, scale: i8, eval_mode: EvalMode, - ) -> CometResult { + ) -> SparkResult { Self::cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) } @@ -777,7 +775,7 @@ impl Cast { precision: u8, scale: i8, eval_mode: EvalMode, - ) -> CometResult { + ) -> SparkResult { Self::cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) } @@ -786,7 +784,7 @@ impl Cast { precision: u8, scale: i8, eval_mode: EvalMode, - ) -> CometResult + ) -> SparkResult where ::Native: AsPrimitive, { @@ -806,7 +804,7 @@ impl Cast { Some(v) => { if Decimal128Type::validate_decimal_precision(v, precision).is_err() { if eval_mode == EvalMode::Ansi { - return Err(CometError::NumericValueOutOfRange { + return Err(SparkError::NumericValueOutOfRange { value: input_value.to_string(), precision, scale, @@ -819,7 +817,7 @@ impl Cast { } None => { if eval_mode == EvalMode::Ansi { - return Err(CometError::NumericValueOutOfRange { + return Err(SparkError::NumericValueOutOfRange { value: input_value.to_string(), precision, scale, @@ -843,7 +841,7 @@ impl Cast { fn spark_cast_float64_to_utf8( from: &dyn Array, _eval_mode: EvalMode, - ) -> CometResult + ) -> SparkResult where OffsetSize: OffsetSizeTrait, { @@ -853,7 +851,7 @@ impl Cast { fn spark_cast_float32_to_utf8( from: &dyn Array, _eval_mode: EvalMode, - ) -> CometResult + ) -> SparkResult where OffsetSize: OffsetSizeTrait, { @@ -865,7 +863,7 @@ impl Cast { eval_mode: EvalMode, from_type: &DataType, to_type: &DataType, - ) -> CometResult { + ) -> SparkResult { match (from_type, to_type) { (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!( array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT" @@ -895,7 +893,7 @@ impl Cast { fn spark_cast_utf8_to_boolean( from: &dyn Array, eval_mode: EvalMode, - ) -> CometResult + ) -> SparkResult where OffsetSize: OffsetSizeTrait, { @@ -910,7 +908,7 @@ impl Cast { Some(value) => match value.to_ascii_lowercase().trim() { "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), - _ if eval_mode == EvalMode::Ansi => Err(CometError::CastInvalidValue { + _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue { value: value.to_string(), from_type: "STRING".to_string(), to_type: "BOOLEAN".to_string(), @@ -929,7 +927,7 @@ impl Cast { eval_mode: EvalMode, from_type: &DataType, to_type: &DataType, - ) -> CometResult { + ) -> SparkResult { match (from_type, to_type) { (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( array, @@ -1066,7 +1064,7 @@ impl Cast { } /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte -fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> { +fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult> { Ok(cast_string_to_int_with_range_check( str, eval_mode, @@ -1078,7 +1076,7 @@ fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> } /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort -fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult> { +fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult> { Ok(cast_string_to_int_with_range_check( str, eval_mode, @@ -1090,12 +1088,12 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult } /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) -fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { +fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult> { do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) } /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper) -fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { +fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult> { do_cast_string_to_int::(str, eval_mode, "BIGINT", i64::MIN) } @@ -1105,7 +1103,7 @@ fn cast_string_to_int_with_range_check( type_name: &str, min: i32, max: i32, -) -> CometResult> { +) -> SparkResult> { match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? { None => Ok(None), Some(v) if v >= min && v <= max => Ok(Some(v)), @@ -1124,7 +1122,7 @@ fn do_cast_string_to_int< eval_mode: EvalMode, type_name: &str, min_value: T, -) -> CometResult> { +) -> SparkResult> { let trimmed_str = str.trim(); if trimmed_str.is_empty() { return none_or_err(eval_mode, type_name, str); @@ -1208,9 +1206,9 @@ fn do_cast_string_to_int< Ok(Some(result)) } -/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode +/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode #[inline] -fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { +fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult> { match eval_mode { EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), _ => Ok(None), @@ -1218,8 +1216,8 @@ fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResul } #[inline] -fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { - CometError::CastInvalidValue { +fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError { + SparkError::CastInvalidValue { value: value.to_string(), from_type: from_type.to_string(), to_type: to_type.to_string(), @@ -1227,8 +1225,8 @@ fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { } #[inline] -fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> CometError { - CometError::CastOverFlow { +fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError { + SparkError::CastOverFlow { value: value.to_string(), from_type: from_type.to_string(), to_type: to_type.to_string(), @@ -1316,7 +1314,7 @@ impl PhysicalExpr for Cast { } } -fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult> { +fn timestamp_parser(value: &str, eval_mode: EvalMode) -> SparkResult> { let value = value.trim(); if value.is_empty() { return Ok(None); @@ -1325,7 +1323,7 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult let patterns = &[ ( Regex::new(r"^\d{4}$").unwrap(), - parse_str_to_year_timestamp as fn(&str) -> CometResult>, + parse_str_to_year_timestamp as fn(&str) -> SparkResult>, ), ( Regex::new(r"^\d{4}-\d{2}$").unwrap(), @@ -1369,7 +1367,7 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult if timestamp.is_none() { return if eval_mode == EvalMode::Ansi { - Err(CometError::CastInvalidValue { + Err(SparkError::CastInvalidValue { value: value.to_string(), from_type: "STRING".to_string(), to_type: "TIMESTAMP".to_string(), @@ -1381,20 +1379,20 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult match timestamp { Some(ts) => Ok(Some(ts)), - None => Err(CometError::Internal( + None => Err(SparkError::Internal( "Failed to parse timestamp".to_string(), )), } } -fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> CometResult> { +fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> SparkResult> { let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0); // Check if datetime is not None let utc_datetime = match datetime.single() { Some(dt) => dt.with_timezone(&chrono::Utc), None => { - return Err(CometError::Internal( + return Err(SparkError::Internal( "Failed to parse timestamp".to_string(), )); } @@ -1411,7 +1409,7 @@ fn parse_hms_timestamp( minute: u32, second: u32, microsecond: u32, -) -> CometResult> { +) -> SparkResult> { let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second); // Check if datetime is not None @@ -1420,7 +1418,7 @@ fn parse_hms_timestamp( .with_timezone(&chrono::Utc) .with_nanosecond(microsecond * 1000), None => { - return Err(CometError::Internal( + return Err(SparkError::Internal( "Failed to parse timestamp".to_string(), )); } @@ -1429,7 +1427,7 @@ fn parse_hms_timestamp( let result = match utc_datetime { Some(dt) => dt.timestamp_micros(), None => { - return Err(CometError::Internal( + return Err(SparkError::Internal( "Failed to parse timestamp".to_string(), )); } @@ -1438,7 +1436,7 @@ fn parse_hms_timestamp( Ok(Some(result)) } -fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult> { +fn get_timestamp_values(value: &str, timestamp_type: &str) -> SparkResult> { let values: Vec<_> = value .split(|c| c == 'T' || c == '-' || c == ':' || c == '.') .collect(); @@ -1458,7 +1456,7 @@ fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult