diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 83f86dbee..eadd8fc3b 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -185,6 +185,7 @@ impl PhysicalPlanner { ExprStruct::Add(expr) => self.create_binary_expr( expr.left.as_ref().unwrap(), expr.right.as_ref().unwrap(), + expr.fail_on_error, expr.return_type.as_ref(), DataFusionOperator::Plus, input_schema, @@ -192,6 +193,7 @@ impl PhysicalPlanner { ExprStruct::Subtract(expr) => self.create_binary_expr( expr.left.as_ref().unwrap(), expr.right.as_ref().unwrap(), + expr.fail_on_error, expr.return_type.as_ref(), DataFusionOperator::Minus, input_schema, @@ -199,6 +201,7 @@ impl PhysicalPlanner { ExprStruct::Multiply(expr) => self.create_binary_expr( expr.left.as_ref().unwrap(), expr.right.as_ref().unwrap(), + expr.fail_on_error, expr.return_type.as_ref(), DataFusionOperator::Multiply, input_schema, @@ -206,6 +209,7 @@ impl PhysicalPlanner { ExprStruct::Divide(expr) => self.create_binary_expr( expr.left.as_ref().unwrap(), expr.right.as_ref().unwrap(), + expr.fail_on_error, expr.return_type.as_ref(), DataFusionOperator::Divide, input_schema, @@ -213,6 +217,7 @@ impl PhysicalPlanner { ExprStruct::Remainder(expr) => self.create_binary_expr( expr.left.as_ref().unwrap(), expr.right.as_ref().unwrap(), + expr.fail_on_error, expr.return_type.as_ref(), DataFusionOperator::Modulo, input_schema, @@ -777,6 +782,7 @@ impl PhysicalPlanner { &self, left: &Expr, right: &Expr, + fail_on_error: bool, return_type: Option<&spark_expression::DataType>, op: DataFusionOperator, input_schema: SchemaRef, @@ -818,7 +824,8 @@ impl PhysicalPlanner { EvalMode::Legacy, false, )); - let child = Arc::new(BinaryExpr::new(left, op, right)); + let child = + Arc::new(BinaryExpr::new(left, op, right).with_fail_on_overflow(fail_on_error)); Ok(Arc::new(Cast::new_without_timezone( child, data_type, @@ -844,7 +851,9 @@ impl 
PhysicalPlanner { data_type, ))) } - _ => Ok(Arc::new(BinaryExpr::new(left, op, right))), + _ => Ok(Arc::new( + BinaryExpr::new(left, op, right).with_fail_on_overflow(fail_on_error), + )), } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 35f374bf0..efa8d7c53 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -942,6 +942,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("add overflow (ANSI enabled)") { + // With ANSI enabled, both Spark and Comet are expected to raise an + // arithmetic overflow error for this query. + withAnsiMode(enabled = true) { + withParquetTable(Seq((Int.MaxValue, 1)), "tbl") { + checkOverflow("SELECT _1 + _2 FROM tbl", "") + } + } + } + + test("subtract overflow (ANSI enabled)") { + // With ANSI enabled, both Spark and Comet are expected to raise an + // arithmetic overflow error for this query. + withAnsiMode(enabled = true) { + withParquetTable(Seq((Int.MinValue, 1)), "tbl") { + checkOverflow("SELECT _1 - _2 FROM tbl", "") + } + } + } + test("divide by zero (ANSI disable)") { // Enabling ANSI will cause native engine failure, but as we cannot catch // native error now, we cannot test it here. 
@@ -1922,38 +1942,41 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } - test("unary negative integer overflow test") { - def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = { - withSQLConf( - SQLConf.ANSI_ENABLED.key -> enabled.toString, - CometConf.COMET_ANSI_MODE_ENABLED.key -> enabled.toString, - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_EXEC_ENABLED.key -> "true")(f) - } - def checkOverflow(query: String, dtype: String): Unit = { - checkSparkMaybeThrows(sql(query)) match { - case (Some(sparkException), Some(cometException)) => - assert(sparkException.getMessage.contains(dtype + " overflow")) - assert(cometException.getMessage.contains(dtype + " overflow")) - case (None, None) => checkSparkAnswerAndOperator(sql(query)) - case (None, Some(ex)) => - fail("Comet threw an exception but Spark did not " + ex.getMessage) - case (Some(_), None) => - fail("Spark threw an exception but Comet did not") - } + def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> enabled.toString, + CometConf.COMET_ANSI_MODE_ENABLED.key -> enabled.toString, + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true")(f) + } + + def checkOverflow(query: String, dtype: String): Unit = { + checkSparkMaybeThrows(sql(query)) match { + case (Some(sparkException), Some(cometException)) => + println(cometException.getMessage) + assert(sparkException.getMessage.contains(dtype + " overflow")) + assert(cometException.getMessage.contains(dtype + " overflow")) + case (None, None) => checkSparkAnswerAndOperator(sql(query)) + case (None, Some(ex)) => + fail("Comet threw an exception but Spark did not " + ex.getMessage) + case (Some(_), None) => + fail("Spark threw an exception but Comet did not") } + } - def runArrayTest(query: String, dtype: String, path: String): Unit = { - withParquetTable(path, "t") { - withAnsiMode(enabled = false) { - 
checkSparkAnswerAndOperator(sql(query)) - } - withAnsiMode(enabled = true) { - checkOverflow(query, dtype) - } + def runArrayTest(query: String, dtype: String, path: String): Unit = { + withParquetTable(path, "t") { + withAnsiMode(enabled = false) { + checkSparkAnswerAndOperator(sql(query)) + } + withAnsiMode(enabled = true) { + checkOverflow(query, dtype) } } + } + + test("unary negative integer overflow test") { withTempDir { dir => // Array values test @@ -2012,6 +2035,17 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("add integer overflow test") { + withParquetTable(Seq((Int.MaxValue, 1)), "tbl") { + withAnsiMode(enabled = false) { + checkSparkAnswerAndOperator(sql("SELECT _1 + _2 FROM tbl")) + } + withAnsiMode(enabled = true) { + checkOverflow("SELECT _1 + _2 FROM tbl", "") + } + } + } + test("readSidePadding") { // https://stackoverflow.com/a/46290728 val table = "test"