diff --git a/common/src/main/scala/org/apache/spark/sql/comet/CometArithmeticException.scala b/common/src/main/scala/org/apache/spark/sql/comet/CometCastOverflowException.scala similarity index 67% rename from common/src/main/scala/org/apache/spark/sql/comet/CometArithmeticException.scala rename to common/src/main/scala/org/apache/spark/sql/comet/CometCastOverflowException.scala index 41a671538..93b633307 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/CometArithmeticException.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/CometCastOverflowException.scala @@ -20,6 +20,16 @@ package org.apache.spark.sql.comet import org.apache.spark.SparkArithmeticException +import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLConf +import org.apache.spark.sql.internal.SQLConf -class CometArithmeticException(message: String) - extends SparkArithmeticException("CAST_OVERFLOW", Map(), Array(), message) {} +class CometCastOverflowException(t: String, from: String, to: String) + extends SparkArithmeticException( + "CAST_OVERFLOW", + Map( + "value" -> t, + "sourceType" -> from, + "targetType" -> to, + "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + Array.empty, + "") {} diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs index 3f7ba0a21..f0e910f75 100644 --- a/native/core/src/errors.rs +++ b/native/core/src/errors.rs @@ -39,7 +39,7 @@ use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, js use crate::execution::operators::ExecutionError; use datafusion_comet_spark_expr::SparkError; -use jni::objects::{GlobalRef, JThrowable}; +use jni::objects::{GlobalRef, JThrowable, JValue}; use jni::JNIEnv; use lazy_static::lazy_static; use parquet::errors::ParquetError; @@ -234,13 +234,6 @@ impl jni::errors::ToException for CometError { class: "org/apache/comet/ParquetRuntimeException".to_string(), msg: self.to_string(), }, - CometError::DataFusion { - msg: _, - source: DataFusionError::External(e), - } if matches!(e.downcast_ref(), Some(SparkError::CastOverFlow { .. })) => Exception { - class: "org/apache/spark/sql/comet/CometArithmeticException".to_string(), - msg: self.to_string(), - }, _other => Exception { class: "org/apache/comet/CometNativeException".to_string(), msg: self.to_string(), @@ -390,6 +383,33 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option env.throw(<&JThrowable>::from(throwable.as_obj())), + CometError::DataFusion { + msg: _, + source: DataFusionError::External(e), + } if matches!(e.downcast_ref(), Some(SparkError::CastOverFlow { .. })) => { + match e.downcast_ref() { + Some(SparkError::CastOverFlow { + value, + from_type, + to_type, + }) => { + let throwable: JThrowable = env + .new_object( + "org/apache/spark/sql/comet/CometCastOverflowException", + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V", + &[ + JValue::Object(&env.new_string(value).unwrap()), + JValue::Object(&env.new_string(from_type).unwrap()), + JValue::Object(&env.new_string(to_type).unwrap()), + ], + ) + .unwrap() + .into(); + env.throw(throwable) + } + _ => unreachable!(), + } + } _ => { let exception = error.to_exception(); match backtrace {