Skip to content

Commit

Permalink
feat: Implement ANSI support for UnaryMinus (apache#471)
Browse files Browse the repository at this point in the history
* checking for invalid inputs for unary minus

* adding eval mode to expressions and proto message

* extending evaluate function for negative expression

* remove print statements

* fix format errors

* removing units

* fix clippy errors

* expect instead of unwrap, map_err instead of match and removing Float16

* adding test case for unary negative integer overflow

* added a function to make the code more readable

* adding comet sql ansi config

* using withTempDir and checkSparkAnswerAndOperator

* adding macros to improve code readability

* using withParquetTable

* adding scalar tests

* adding more test cases and bug fix

* using failonerror and removing eval_mode

* bug fix

* removing checks for float64 and monthdaynano

* removing checks of float and monthday nano

* adding checks while evalute bounds

* IntervalDayTime splitting i64 and then checking

* Adding interval test

* fix ci errors
  • Loading branch information
vaibhawvipul authored and kazuyukitanimura committed Jul 1, 2024
1 parent 6ae6433 commit 9108f2a
Show file tree
Hide file tree
Showing 7 changed files with 381 additions and 4 deletions.
3 changes: 3 additions & 0 deletions core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ pub enum CometError {
to_type: String,
},

#[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
ArithmeticOverflow { from_type: String },

#[error(transparent)]
Arrow {
#[from]
Expand Down
1 change: 1 addition & 0 deletions core/src/execution/datafusion/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub mod avg_decimal;
pub mod bloom_filter_might_contain;
pub mod correlation;
pub mod covariance;
pub mod negative;
pub mod stats;
pub mod stddev;
pub mod strings;
Expand Down
270 changes: 270 additions & 0 deletions core/src/execution/datafusion/expressions/negative.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::errors::CometError;
use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType};
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Schema};
use datafusion::{
logical_expr::{interval_arithmetic::Interval, ColumnarValue},
physical_expr::PhysicalExpr,
};
use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::{
aggregate::utils::down_cast_any_ref, sort_properties::SortProperties,
};
use std::{
any::Any,
hash::{Hash, Hasher},
sync::Arc,
};

pub fn create_negate_expr(
expr: Arc<dyn PhysicalExpr>,
fail_on_error: bool,
) -> Result<Arc<dyn PhysicalExpr>, CometError> {
Ok(Arc::new(NegativeExpr::new(expr, fail_on_error)))
}

/// Negative expression
#[derive(Debug, Hash)]
pub struct NegativeExpr {
/// Input expression
arg: Arc<dyn PhysicalExpr>,
fail_on_error: bool,
}

fn arithmetic_overflow_error(from_type: &str) -> CometError {
CometError::ArithmeticOverflow {
from_type: from_type.to_string(),
}
}

macro_rules! check_overflow {
($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
let typed_array = $array
.as_any()
.downcast_ref::<$array_type>()
.expect(concat!(stringify!($array_type), " expected"));
for i in 0..typed_array.len() {
if typed_array.value(i) == $min_val {
if $type_name == "byte" || $type_name == "short" {
let value = typed_array.value(i).to_string() + " caused";
return Err(arithmetic_overflow_error(value.as_str()).into());
}
return Err(arithmetic_overflow_error($type_name).into());
}
}
}};
}

impl NegativeExpr {
/// Create new not expression
pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self {
Self { arg, fail_on_error }
}

/// Get the input expression
pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
&self.arg
}
}

impl std::fmt::Display for NegativeExpr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "(- {})", self.arg)
}
}

impl PhysicalExpr for NegativeExpr {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}

fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
self.arg.data_type(input_schema)
}

fn nullable(&self, input_schema: &Schema) -> Result<bool> {
self.arg.nullable(input_schema)
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let arg = self.arg.evaluate(batch)?;

// overflow checks only apply in ANSI mode
// datatypes supported are byte, short, integer, long, float, interval
match arg {
ColumnarValue::Array(array) => {
if self.fail_on_error {
match array.data_type() {
DataType::Int8 => {
check_overflow!(array, arrow::array::Int8Array, i8::MIN, "byte")
}
DataType::Int16 => {
check_overflow!(array, arrow::array::Int16Array, i16::MIN, "short")
}
DataType::Int32 => {
check_overflow!(array, arrow::array::Int32Array, i32::MIN, "integer")
}
DataType::Int64 => {
check_overflow!(array, arrow::array::Int64Array, i64::MIN, "long")
}
DataType::Interval(value) => match value {
arrow::datatypes::IntervalUnit::YearMonth => check_overflow!(
array,
arrow::array::IntervalYearMonthArray,
i32::MIN,
"interval"
),
arrow::datatypes::IntervalUnit::DayTime => check_overflow!(
array,
arrow::array::IntervalDayTimeArray,
i64::MIN,
"interval"
),
arrow::datatypes::IntervalUnit::MonthDayNano => {
// Overflow checks are not supported
}
},
_ => {
// Overflow checks are not supported for other datatypes
}
}
}
let result = neg_wrapping(array.as_ref())?;
Ok(ColumnarValue::Array(result))
}
ColumnarValue::Scalar(scalar) => {
if self.fail_on_error {
match scalar {
ScalarValue::Int8(value) => {
if value == Some(i8::MIN) {
return Err(arithmetic_overflow_error(" caused").into());
}
}
ScalarValue::Int16(value) => {
if value == Some(i16::MIN) {
return Err(arithmetic_overflow_error(" caused").into());
}
}
ScalarValue::Int32(value) => {
if value == Some(i32::MIN) {
return Err(arithmetic_overflow_error("integer").into());
}
}
ScalarValue::Int64(value) => {
if value == Some(i64::MIN) {
return Err(arithmetic_overflow_error("long").into());
}
}
ScalarValue::IntervalDayTime(value) => {
let (days, ms) =
IntervalDayTimeType::to_parts(value.unwrap_or_default());
if days == i32::MIN || ms == i32::MIN {
return Err(arithmetic_overflow_error("interval").into());
}
}
ScalarValue::IntervalYearMonth(value) => {
if value == Some(i32::MIN) {
return Err(arithmetic_overflow_error("interval").into());
}
}
_ => {
// Overflow checks are not supported for other datatypes
}
}
}
Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?))
}
}
}

fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.arg.clone()]
}

fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(NegativeExpr::new(
children[0].clone(),
self.fail_on_error,
)))
}

fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
}

/// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval.
/// It replaces the upper and lower bounds after multiplying them with -1.
/// Ex: `(a, b]` => `[-b, -a)`
fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
Interval::try_new(
children[0].upper().arithmetic_negate()?,
children[0].lower().arithmetic_negate()?,
)
}

/// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that
/// given the input interval is known to be `children`.
fn propagate_constraints(
&self,
interval: &Interval,
children: &[&Interval],
) -> Result<Option<Vec<Interval>>> {
let child_interval = children[0];

if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN))
|| child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN))
|| child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
|| child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
{
return Err(CometError::ArithmeticOverflow {
from_type: "long".to_string(),
}
.into());
}

let negated_interval = Interval::try_new(
interval.upper().arithmetic_negate()?,
interval.lower().arithmetic_negate()?,
)?;

Ok(child_interval
.intersect(negated_interval)?
.map(|result| vec![result]))
}

/// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
fn get_ordering(&self, children: &[SortProperties]) -> SortProperties {
-children[0]
}
}

impl PartialEq<dyn Any> for NegativeExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| self.arg.eq(&x.arg))
.unwrap_or(false)
}
}
9 changes: 6 additions & 3 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion::{
expressions::{
in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column, Count,
FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue,
Literal as DataFusionLiteral, Max, Min, NegativeExpr, NotExpr, Sum, UnKnownColumn,
Literal as DataFusionLiteral, Max, Min, NotExpr, Sum, UnKnownColumn,
},
AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
},
Expand Down Expand Up @@ -70,6 +70,7 @@ use crate::{
correlation::Correlation,
covariance::Covariance,
if_expr::IfExpr,
negative,
scalar_funcs::create_comet_physical_fun,
stats::StatsType,
stddev::Stddev,
Expand Down Expand Up @@ -563,8 +564,10 @@ impl PhysicalPlanner {
Ok(Arc::new(NotExpr::new(child)))
}
ExprStruct::Negative(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
Ok(Arc::new(NegativeExpr::new(child)))
let child: Arc<dyn PhysicalExpr> =
self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?;
let result = negative::create_negate_expr(child, expr.fail_on_error);
result.map_err(|e| ExecutionError::GeneralError(e.to_string()))
}
ExprStruct::NormalizeNanAndZero(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
Expand Down
1 change: 1 addition & 0 deletions core/src/execution/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ message Not {

message Negative {
Expr child = 1;
bool fail_on_error = 2;
}

message IfExpr {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1984,11 +1984,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}

case UnaryMinus(child, _) =>
case UnaryMinus(child, failOnError) =>
val childExpr = exprToProtoInternal(child, inputs)
if (childExpr.isDefined) {
val builder = ExprOuterClass.Negative.newBuilder()
builder.setChild(childExpr.get)
builder.setFailOnError(failOnError)
Some(
ExprOuterClass.Expr
.newBuilder()
Expand Down
Loading

0 comments on commit 9108f2a

Please sign in to comment.