diff --git a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
index 06717aabe..8f58ba7ca 100644
--- a/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
+++ b/native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs
@@ -34,6 +34,16 @@ use std::fmt::Debug;
 use std::sync::Arc;
 
 macro_rules! make_comet_scalar_udf {
+    ($name:expr, $func:ident, $data_type:ident, $fail_on_error:ident) => {{
+        let scalar_func = CometScalarFunction::new(
+            $name.to_string(),
+            Signature::variadic_any(Volatility::Immutable),
+            $data_type.clone(),
+            Arc::new(move |args| $func(args, &$data_type)),
+        );
+        // TODO Check for overflow
+        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
+    }};
     ($name:expr, $func:ident, $data_type:ident) => {{
         let scalar_func = CometScalarFunction::new(
             $name.to_string(),
@@ -59,6 +69,7 @@ pub fn create_comet_physical_fun(
     fun_name: &str,
     data_type: DataType,
     registry: &dyn FunctionRegistry,
+    _fail_on_error: &bool,
 ) -> Result<Arc<ScalarUDF>, DataFusionError> {
     match fun_name {
         "ceil" => {
@@ -72,7 +83,7 @@ pub fn create_comet_physical_fun(
             make_comet_scalar_udf!("read_side_padding", func, without data_type)
         }
         "round" => {
-            make_comet_scalar_udf!("round", spark_round, data_type)
+            make_comet_scalar_udf!("round", spark_round, data_type, _fail_on_error)
         }
         "unscaled_value" => {
             let func = Arc::new(spark_unscaled_value);
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 15de7c9ad..822301674 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -775,10 +775,12 @@ impl PhysicalPlanner {
                 Ok(DataType::Decimal128(_p2, _s2)),
             ) => {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
+                let fail_on_error = false;
                 let fun_expr = create_comet_physical_fun(
                     "decimal_div",
                     data_type.clone(),
                     &self.session_ctx.state(),
+                    &fail_on_error,
                 )?;
                 Ok(Arc::new(ScalarFunctionExpr::new(
                     "decimal_div",
@@ -1872,6 +1874,7 @@ impl PhysicalPlanner {
             .collect::<Result<Vec<_>, _>>()?;
 
         let fun_name = &expr.func;
+        let fail_on_error = &expr.fail_on_error;
         let input_expr_types = args
             .iter()
             .map(|x| x.data_type(input_schema.as_ref()))
@@ -1897,8 +1900,12 @@ impl PhysicalPlanner {
             }
         };
 
-        let fun_expr =
-            create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;
+        let fun_expr = create_comet_physical_fun(
+            fun_name,
+            data_type.clone(),
+            &self.session_ctx.state(),
+            fail_on_error,
+        )?;
 
         let args = args
             .into_iter()
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 88940f386..61656c0cc 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -404,6 +404,7 @@ message ScalarFunc {
   string func = 1;
   repeated Expr args = 2;
   DataType return_type = 3;
+  bool fail_on_error = 4;
 }
 
 message BitwiseAnd {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 51b32b7df..b851a7e9a 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1924,7 +1924,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             // `scale` must be Int64 type in DataFusion
             val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs)
             val optExpr =
-              scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
+              scalarExprToProtoWithReturnTypeAnsi(
+                "round",
+                r.dataType,
+                r.ansiEnabled,
+                childExpr,
+                scaleExpr)
             optExprWithInfo(optExpr, expr, r.child)
         }
 
@@ -2610,6 +2615,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     }
   }
 
+  def scalarExprToProtoWithReturnTypeAnsi(
+      funcName: String,
+      returnType: DataType,
+      failOnError: Boolean,
+      args: Option[Expr]*): Option[Expr] = {
+    val builder = ExprOuterClass.ScalarFunc.newBuilder()
+    builder.setFunc(funcName)
+    builder.setFailOnError(failOnError)
+    serializeDataType(returnType).flatMap { t =>
+      builder.setReturnType(t)
+      scalarExprToProto0(builder, args: _*)
+    }
+  }
+
   def scalarExprToProto(funcName: String, args: Option[Expr]*): Option[Expr] = {
     val builder = ExprOuterClass.ScalarFunc.newBuilder()
     builder.setFunc(funcName)
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 16bc15b84..26fa23cfa 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1270,6 +1270,56 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("round overflow test") {
+    def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = {
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> enabled.toString,
+        CometConf.COMET_ANSI_MODE_ENABLED.key -> enabled.toString,
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true")(f)
+    }
+
+    def checkOverflow(query: String, dtype: String): Unit = {
+      checkSparkMaybeThrows(sql(query)) match {
+        case (Some(sparkException), Some(cometException)) =>
+          assert(sparkException.getMessage.contains(dtype + " overflow"))
+          assert(cometException.getMessage.contains(dtype + " overflow"))
+        case (None, None) => checkSparkAnswerAndOperator(sql(query))
+        case (None, Some(ex)) =>
+          fail("Comet threw an exception but Spark did not " + ex.getMessage)
+        case (Some(_), None) =>
+          fail("Spark threw an exception but Comet did not")
+      }
+    }
+
+    def runArrayTest(query: String, dtype: String, path: String): Unit = {
+      withParquetTable(path, "t") {
+        withAnsiMode(enabled = false) {
+          checkSparkAnswerAndOperator(sql(query))
+        }
+        withAnsiMode(enabled = true) {
+          checkOverflow(query, dtype)
+        }
+      }
+    }
+
+    withTempDir { dir =>
+      // Array values test
+      val dataTypes = Seq(
+        ("array_test.parquet", Seq(Int.MaxValue, Int.MinValue).toDF("a"), "integer"),
+        ("long_array_test.parquet", Seq(Long.MaxValue, Long.MinValue).toDF("a"), "long"),
+        ("short_array_test.parquet", Seq(Short.MaxValue, Short.MinValue).toDF("a"), ""),
+        ("byte_array_test.parquet", Seq(Byte.MaxValue, Byte.MinValue).toDF("a"), ""))
+
+      dataTypes.foreach { case (fileName, df, dtype) =>
+        val path = new Path(dir.toURI.toString, fileName).toString
+        df.write.mode("overwrite").parquet(path)
+        val query = "select a, round(a, -1) FROM t"
+        runArrayTest(query, dtype, path)
+      }
+    }
+  }
+
   test("Upper and Lower") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf(
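Note on the remaining `// TODO Check for overflow`: this diff wires `fail_on_error` from the new `ScalarFunc.fail_on_error` proto field through `QueryPlanSerde` and the planner into `create_comet_physical_fun`, but the native side does not consume it yet; the new `make_comet_scalar_udf!` arm accepts the flag and leaves the check as a TODO, and the new suite test only pins the expected behaviour (with ANSI mode on, `round(a, -1)` over `Int.MaxValue`/`Int.MinValue` and `Long.MaxValue`/`Long.MinValue` must fail with the same "integer overflow"/"long overflow" message Spark raises). The sketch below is only an illustration of the kind of check that TODO implies, not Comet's `spark_round` implementation; the helper name `round_i32_half_up` and the HALF_UP rounding assumption are mine.

```rust
// Standalone illustration only (hypothetical helper, not Comet's spark_round):
// round an i32 to a negative decimal scale with HALF_UP semantics and report
// overflow by returning None.
fn round_i32_half_up(value: i32, scale: i32) -> Option<i32> {
    if scale >= 0 {
        // A non-negative scale is a no-op for an integer value.
        return Some(value);
    }
    // 10^|scale|, widened to i64 so the intermediate sum cannot overflow.
    let pow = 10_i64.checked_pow(scale.unsigned_abs())?;
    let v = value as i64;
    // Half-up rounding away from zero.
    let half = if v >= 0 { pow / 2 } else { -(pow / 2) };
    let rounded = (v + half) / pow * pow;
    // Overflow detection: the rounded value may no longer fit in i32.
    i32::try_from(rounded).ok()
}

fn main() {
    assert_eq!(round_i32_half_up(12_345, -1), Some(12_350));
    // i32::MAX = 2_147_483_647 rounds to 2_147_483_650 at scale -1, which does
    // not fit in i32 -- the "integer overflow" case the ANSI test exercises.
    assert_eq!(round_i32_half_up(i32::MAX, -1), None);
    assert_eq!(round_i32_half_up(i32::MIN, -1), None);
}
```

In the real code path the `None` case would presumably surface as a `DataFusionError` when `fail_on_error` is set, and fall back to the existing non-ANSI behaviour otherwise.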