diff --git a/core/src/errors.rs b/core/src/errors.rs index 04a1629d53..af4fd26973 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -88,6 +88,9 @@ pub enum CometError { to_type: String, }, + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + ArithmeticOverflow { from_type: String }, + #[error(transparent)] Arrow { #[from] diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 9db4b65b3f..084fef2df3 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -29,6 +29,7 @@ pub mod avg_decimal; pub mod bloom_filter_might_contain; pub mod correlation; pub mod covariance; +pub mod negative; pub mod stats; pub mod stddev; pub mod strings; diff --git a/core/src/execution/datafusion/expressions/negative.rs b/core/src/execution/datafusion/expressions/negative.rs new file mode 100644 index 0000000000..e7aa2ac646 --- /dev/null +++ b/core/src/execution/datafusion/expressions/negative.rs @@ -0,0 +1,270 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::errors::CometError; +use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType}; +use arrow_array::RecordBatch; +use arrow_schema::{DataType, Schema}; +use datafusion::{ + logical_expr::{interval_arithmetic::Interval, ColumnarValue}, + physical_expr::PhysicalExpr, +}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_physical_expr::{ + aggregate::utils::down_cast_any_ref, sort_properties::SortProperties, +}; +use std::{ + any::Any, + hash::{Hash, Hasher}, + sync::Arc, +}; + +pub fn create_negate_expr( + expr: Arc, + fail_on_error: bool, +) -> Result, CometError> { + Ok(Arc::new(NegativeExpr::new(expr, fail_on_error))) +} + +/// Negative expression +#[derive(Debug, Hash)] +pub struct NegativeExpr { + /// Input expression + arg: Arc, + fail_on_error: bool, +} + +fn arithmetic_overflow_error(from_type: &str) -> CometError { + CometError::ArithmeticOverflow { + from_type: from_type.to_string(), + } +} + +macro_rules! check_overflow { + ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{ + let typed_array = $array + .as_any() + .downcast_ref::<$array_type>() + .expect(concat!(stringify!($array_type), " expected")); + for i in 0..typed_array.len() { + if typed_array.value(i) == $min_val { + if $type_name == "byte" || $type_name == "short" { + let value = typed_array.value(i).to_string() + " caused"; + return Err(arithmetic_overflow_error(value.as_str()).into()); + } + return Err(arithmetic_overflow_error($type_name).into()); + } + } + }}; +} + +impl NegativeExpr { + /// Create new not expression + pub fn new(arg: Arc, fail_on_error: bool) -> Self { + Self { arg, fail_on_error } + } + + /// Get the input expression + pub fn arg(&self) -> &Arc { + &self.arg + } +} + +impl std::fmt::Display for NegativeExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "(- {})", self.arg) + } +} + +impl PhysicalExpr for NegativeExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + self.arg.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.arg.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg = self.arg.evaluate(batch)?; + + // overflow checks only apply in ANSI mode + // datatypes supported are byte, short, integer, long, float, interval + match arg { + ColumnarValue::Array(array) => { + if self.fail_on_error { + match array.data_type() { + DataType::Int8 => { + check_overflow!(array, arrow::array::Int8Array, i8::MIN, "byte") + } + DataType::Int16 => { + check_overflow!(array, arrow::array::Int16Array, i16::MIN, "short") + } + DataType::Int32 => { + check_overflow!(array, arrow::array::Int32Array, i32::MIN, "integer") + } + DataType::Int64 => { + check_overflow!(array, arrow::array::Int64Array, i64::MIN, "long") + } + DataType::Interval(value) => match value { + arrow::datatypes::IntervalUnit::YearMonth => check_overflow!( + array, + arrow::array::IntervalYearMonthArray, + i32::MIN, + "interval" + ), + arrow::datatypes::IntervalUnit::DayTime => check_overflow!( + array, + arrow::array::IntervalDayTimeArray, + i64::MIN, + "interval" + ), + arrow::datatypes::IntervalUnit::MonthDayNano => { + // Overflow checks are not supported + } + }, + _ => { + // Overflow checks are not supported for other datatypes + } + } + } + let result = neg_wrapping(array.as_ref())?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar) => { + if self.fail_on_error { + match scalar { + ScalarValue::Int8(value) => { + if value == Some(i8::MIN) { + return Err(arithmetic_overflow_error(" caused").into()); + } + } + ScalarValue::Int16(value) => { + if value == Some(i16::MIN) { + return Err(arithmetic_overflow_error(" caused").into()); + } + } + ScalarValue::Int32(value) => { + if value == Some(i32::MIN) { + return Err(arithmetic_overflow_error("integer").into()); + } + } + ScalarValue::Int64(value) => { + if value == Some(i64::MIN) { + return Err(arithmetic_overflow_error("long").into()); + } + } + ScalarValue::IntervalDayTime(value) => { + let (days, ms) = + IntervalDayTimeType::to_parts(value.unwrap_or_default()); + if days == i32::MIN || ms == i32::MIN { + return Err(arithmetic_overflow_error("interval").into()); + } + } + ScalarValue::IntervalYearMonth(value) => { + if value == Some(i32::MIN) { + return Err(arithmetic_overflow_error("interval").into()); + } + } + _ => { + // Overflow checks are not supported for other datatypes + } + } + } + Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?)) + } + } + } + + fn children(&self) -> Vec> { + vec![self.arg.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(NegativeExpr::new( + children[0].clone(), + self.fail_on_error, + ))) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } + + /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval. + /// It replaces the upper and lower bounds after multiplying them with -1. + /// Ex: `(a, b]` => `[-b, -a)` + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + Interval::try_new( + children[0].upper().arithmetic_negate()?, + children[0].lower().arithmetic_negate()?, + ) + } + + /// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that + /// given the input interval is known to be `children`. + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + let child_interval = children[0]; + + if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN)) + || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN)) + || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN)) + || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN)) + { + return Err(CometError::ArithmeticOverflow { + from_type: "long".to_string(), + } + .into()); + } + + let negated_interval = Interval::try_new( + interval.upper().arithmetic_negate()?, + interval.lower().arithmetic_negate()?, + )?; + + Ok(child_interval + .intersect(negated_interval)? + .map(|result| vec![result])) + } + + /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. + fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { + -children[0] + } +} + +impl PartialEq for NegativeExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.arg.eq(&x.arg)) + .unwrap_or(false) + } +} diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 20119b13a5..3a8548f776 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -33,7 +33,7 @@ use datafusion::{ expressions::{ in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column, Count, FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue, - Literal as DataFusionLiteral, Max, Min, NegativeExpr, NotExpr, Sum, UnKnownColumn, + Literal as DataFusionLiteral, Max, Min, NotExpr, Sum, UnKnownColumn, }, AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr, }, @@ -70,6 +70,7 @@ use crate::{ correlation::Correlation, covariance::Covariance, if_expr::IfExpr, + negative, scalar_funcs::create_comet_physical_fun, stats::StatsType, stddev::Stddev, @@ -563,8 +564,10 @@ impl PhysicalPlanner { Ok(Arc::new(NotExpr::new(child))) } ExprStruct::Negative(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; - Ok(Arc::new(NegativeExpr::new(child))) + let child: Arc = + self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?; + let result = negative::create_negate_expr(child, expr.fail_on_error); + result.map_err(|e| ExecutionError::GeneralError(e.to_string())) } ExprStruct::NormalizeNanAndZero(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index bcd9838753..9c60490135 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -454,6 +454,7 @@ message Not { message Negative { Expr child = 1; + bool fail_on_error = 2; } message IfExpr { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 79a75102d7..5fe290cf65 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1984,11 +1984,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } - case UnaryMinus(child, _) => + case UnaryMinus(child, failOnError) => val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val builder = ExprOuterClass.Negative.newBuilder() builder.setChild(childExpr.get) + builder.setFailOnError(failOnError) Some( ExprOuterClass.Expr .newBuilder() diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index e15b09ca2b..e69054dbde 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1548,5 +1548,103 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + test("unary negative integer overflow test") { + def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = { + withSQLConf( + SQLConf.ANSI_ENABLED.key -> enabled.toString, + CometConf.COMET_ANSI_MODE_ENABLED.key -> enabled.toString, + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true")(f) + } + + def checkOverflow(query: String, dtype: String): Unit = { + checkSparkMaybeThrows(sql(query)) match { + case (Some(sparkException), Some(cometException)) => + assert(sparkException.getMessage.contains(dtype + " overflow")) + assert(cometException.getMessage.contains(dtype + " overflow")) + case (None, None) => assert(true) // got same outputs + case (None, Some(ex)) => + fail("Comet threw an exception but Spark did not " + ex.getMessage) + case (Some(_), None) => + fail("Spark threw an exception but Comet did not") + } + } + + def runArrayTest(query: String, dtype: String, path: String): Unit = { + withParquetTable(path, "t") { + withAnsiMode(enabled = false) { + checkSparkAnswerAndOperator(sql(query)) + } + withAnsiMode(enabled = true) { + checkOverflow(query, dtype) + } + } + } + withTempDir { dir => + // Array values test + val arrayPath = new Path(dir.toURI.toString, "array_test.parquet").toString + Seq(Int.MaxValue, Int.MinValue).toDF("a").write.mode("overwrite").parquet(arrayPath) + val arrayQuery = "select a, -a from t" + runArrayTest(arrayQuery, "integer", arrayPath) + + // long values test + val longArrayPath = new Path(dir.toURI.toString, "long_array_test.parquet").toString + Seq(Long.MaxValue, Long.MinValue) + .toDF("a") + .write + .mode("overwrite") + .parquet(longArrayPath) + val longArrayQuery = "select a, -a from t" + runArrayTest(longArrayQuery, "long", longArrayPath) + + // short values test + val shortArrayPath = new Path(dir.toURI.toString, "short_array_test.parquet").toString + Seq(Short.MaxValue, Short.MinValue) + .toDF("a") + .write + .mode("overwrite") + .parquet(shortArrayPath) + val shortArrayQuery = "select a, -a from t" + runArrayTest(shortArrayQuery, " caused", shortArrayPath) + + // byte values test + val byteArrayPath = new Path(dir.toURI.toString, "byte_array_test.parquet").toString + Seq(Byte.MaxValue, Byte.MinValue) + .toDF("a") + .write + .mode("overwrite") + .parquet(byteArrayPath) + val byteArrayQuery = "select a, -a from t" + runArrayTest(byteArrayQuery, " caused", byteArrayPath) + + // interval values test + withTable("t_interval") { + spark.sql("CREATE TABLE t_interval(a STRING) USING PARQUET") + spark.sql("INSERT INTO t_interval VALUES ('INTERVAL 10000000000 YEAR')") + withAnsiMode(enabled = true) { + spark + .sql("SELECT CAST(a AS INTERVAL) AS a FROM t_interval") + .createOrReplaceTempView("t_interval_casted") + checkOverflow("SELECT a, -a FROM t_interval_casted", "interval") + } + } + + withTable("t") { + sql("create table t(a int) using parquet") + sql("insert into t values (-2147483648)") + withAnsiMode(enabled = true) { + checkOverflow("select a, -a from t", "integer") + } + } + + withTable("t_float") { + sql("create table t_float(a float) using parquet") + sql("insert into t_float values (3.4128235E38)") + withAnsiMode(enabled = true) { + checkOverflow("select a, -a from t_float", "float") + } + } + } + } }