forked from apache/datafusion-comet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Implement ANSI support for UnaryMinus (apache#471)
* checking for invalid inputs for unary minus * adding eval mode to expressions and proto message * extending evaluate function for negative expression * remove print statements * fix format errors * removing units * fix clippy errors * expect instead of unwrap, map_err instead of match and removing Float16 * adding test case for unary negative integer overflow * added a function to make the code more readable * adding comet sql ansi config * using withTempDir and checkSparkAnswerAndOperator * adding macros to improve code readability * using withParquetTable * adding scalar tests * adding more test cases and bug fix * using failonerror and removing eval_mode * bug fix * removing checks for float64 and monthdaynano * removing checks of float and monthday nano * adding checks while evalute bounds * IntervalDayTime splitting i64 and then checking * Adding interval test * fix ci errors
- Loading branch information
1 parent
6ae6433
commit 9108f2a
Showing
7 changed files
with
381 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use crate::errors::CometError; | ||
use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType}; | ||
use arrow_array::RecordBatch; | ||
use arrow_schema::{DataType, Schema}; | ||
use datafusion::{ | ||
logical_expr::{interval_arithmetic::Interval, ColumnarValue}, | ||
physical_expr::PhysicalExpr, | ||
}; | ||
use datafusion_common::{Result, ScalarValue}; | ||
use datafusion_physical_expr::{ | ||
aggregate::utils::down_cast_any_ref, sort_properties::SortProperties, | ||
}; | ||
use std::{ | ||
any::Any, | ||
hash::{Hash, Hasher}, | ||
sync::Arc, | ||
}; | ||
|
||
pub fn create_negate_expr( | ||
expr: Arc<dyn PhysicalExpr>, | ||
fail_on_error: bool, | ||
) -> Result<Arc<dyn PhysicalExpr>, CometError> { | ||
Ok(Arc::new(NegativeExpr::new(expr, fail_on_error))) | ||
} | ||
|
||
/// Negative expression | ||
#[derive(Debug, Hash)] | ||
pub struct NegativeExpr { | ||
/// Input expression | ||
arg: Arc<dyn PhysicalExpr>, | ||
fail_on_error: bool, | ||
} | ||
|
||
fn arithmetic_overflow_error(from_type: &str) -> CometError { | ||
CometError::ArithmeticOverflow { | ||
from_type: from_type.to_string(), | ||
} | ||
} | ||
|
||
macro_rules! check_overflow { | ||
($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{ | ||
let typed_array = $array | ||
.as_any() | ||
.downcast_ref::<$array_type>() | ||
.expect(concat!(stringify!($array_type), " expected")); | ||
for i in 0..typed_array.len() { | ||
if typed_array.value(i) == $min_val { | ||
if $type_name == "byte" || $type_name == "short" { | ||
let value = typed_array.value(i).to_string() + " caused"; | ||
return Err(arithmetic_overflow_error(value.as_str()).into()); | ||
} | ||
return Err(arithmetic_overflow_error($type_name).into()); | ||
} | ||
} | ||
}}; | ||
} | ||
|
||
impl NegativeExpr { | ||
/// Create new not expression | ||
pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self { | ||
Self { arg, fail_on_error } | ||
} | ||
|
||
/// Get the input expression | ||
pub fn arg(&self) -> &Arc<dyn PhysicalExpr> { | ||
&self.arg | ||
} | ||
} | ||
|
||
impl std::fmt::Display for NegativeExpr { | ||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | ||
write!(f, "(- {})", self.arg) | ||
} | ||
} | ||
|
||
impl PhysicalExpr for NegativeExpr { | ||
/// Return a reference to Any that can be used for downcasting | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
fn data_type(&self, input_schema: &Schema) -> Result<DataType> { | ||
self.arg.data_type(input_schema) | ||
} | ||
|
||
fn nullable(&self, input_schema: &Schema) -> Result<bool> { | ||
self.arg.nullable(input_schema) | ||
} | ||
|
||
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { | ||
let arg = self.arg.evaluate(batch)?; | ||
|
||
// overflow checks only apply in ANSI mode | ||
// datatypes supported are byte, short, integer, long, float, interval | ||
match arg { | ||
ColumnarValue::Array(array) => { | ||
if self.fail_on_error { | ||
match array.data_type() { | ||
DataType::Int8 => { | ||
check_overflow!(array, arrow::array::Int8Array, i8::MIN, "byte") | ||
} | ||
DataType::Int16 => { | ||
check_overflow!(array, arrow::array::Int16Array, i16::MIN, "short") | ||
} | ||
DataType::Int32 => { | ||
check_overflow!(array, arrow::array::Int32Array, i32::MIN, "integer") | ||
} | ||
DataType::Int64 => { | ||
check_overflow!(array, arrow::array::Int64Array, i64::MIN, "long") | ||
} | ||
DataType::Interval(value) => match value { | ||
arrow::datatypes::IntervalUnit::YearMonth => check_overflow!( | ||
array, | ||
arrow::array::IntervalYearMonthArray, | ||
i32::MIN, | ||
"interval" | ||
), | ||
arrow::datatypes::IntervalUnit::DayTime => check_overflow!( | ||
array, | ||
arrow::array::IntervalDayTimeArray, | ||
i64::MIN, | ||
"interval" | ||
), | ||
arrow::datatypes::IntervalUnit::MonthDayNano => { | ||
// Overflow checks are not supported | ||
} | ||
}, | ||
_ => { | ||
// Overflow checks are not supported for other datatypes | ||
} | ||
} | ||
} | ||
let result = neg_wrapping(array.as_ref())?; | ||
Ok(ColumnarValue::Array(result)) | ||
} | ||
ColumnarValue::Scalar(scalar) => { | ||
if self.fail_on_error { | ||
match scalar { | ||
ScalarValue::Int8(value) => { | ||
if value == Some(i8::MIN) { | ||
return Err(arithmetic_overflow_error(" caused").into()); | ||
} | ||
} | ||
ScalarValue::Int16(value) => { | ||
if value == Some(i16::MIN) { | ||
return Err(arithmetic_overflow_error(" caused").into()); | ||
} | ||
} | ||
ScalarValue::Int32(value) => { | ||
if value == Some(i32::MIN) { | ||
return Err(arithmetic_overflow_error("integer").into()); | ||
} | ||
} | ||
ScalarValue::Int64(value) => { | ||
if value == Some(i64::MIN) { | ||
return Err(arithmetic_overflow_error("long").into()); | ||
} | ||
} | ||
ScalarValue::IntervalDayTime(value) => { | ||
let (days, ms) = | ||
IntervalDayTimeType::to_parts(value.unwrap_or_default()); | ||
if days == i32::MIN || ms == i32::MIN { | ||
return Err(arithmetic_overflow_error("interval").into()); | ||
} | ||
} | ||
ScalarValue::IntervalYearMonth(value) => { | ||
if value == Some(i32::MIN) { | ||
return Err(arithmetic_overflow_error("interval").into()); | ||
} | ||
} | ||
_ => { | ||
// Overflow checks are not supported for other datatypes | ||
} | ||
} | ||
} | ||
Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?)) | ||
} | ||
} | ||
} | ||
|
||
fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> { | ||
vec![self.arg.clone()] | ||
} | ||
|
||
fn with_new_children( | ||
self: Arc<Self>, | ||
children: Vec<Arc<dyn PhysicalExpr>>, | ||
) -> Result<Arc<dyn PhysicalExpr>> { | ||
Ok(Arc::new(NegativeExpr::new( | ||
children[0].clone(), | ||
self.fail_on_error, | ||
))) | ||
} | ||
|
||
fn dyn_hash(&self, state: &mut dyn Hasher) { | ||
let mut s = state; | ||
self.hash(&mut s); | ||
} | ||
|
||
/// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval. | ||
/// It replaces the upper and lower bounds after multiplying them with -1. | ||
/// Ex: `(a, b]` => `[-b, -a)` | ||
fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> { | ||
Interval::try_new( | ||
children[0].upper().arithmetic_negate()?, | ||
children[0].lower().arithmetic_negate()?, | ||
) | ||
} | ||
|
||
/// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that | ||
/// given the input interval is known to be `children`. | ||
fn propagate_constraints( | ||
&self, | ||
interval: &Interval, | ||
children: &[&Interval], | ||
) -> Result<Option<Vec<Interval>>> { | ||
let child_interval = children[0]; | ||
|
||
if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN)) | ||
|| child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN)) | ||
|| child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN)) | ||
|| child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN)) | ||
{ | ||
return Err(CometError::ArithmeticOverflow { | ||
from_type: "long".to_string(), | ||
} | ||
.into()); | ||
} | ||
|
||
let negated_interval = Interval::try_new( | ||
interval.upper().arithmetic_negate()?, | ||
interval.lower().arithmetic_negate()?, | ||
)?; | ||
|
||
Ok(child_interval | ||
.intersect(negated_interval)? | ||
.map(|result| vec![result])) | ||
} | ||
|
||
/// The ordering of a [`NegativeExpr`] is simply the reverse of its child. | ||
fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { | ||
-children[0] | ||
} | ||
} | ||
|
||
impl PartialEq<dyn Any> for NegativeExpr { | ||
fn eq(&self, other: &dyn Any) -> bool { | ||
down_cast_any_ref(other) | ||
.downcast_ref::<Self>() | ||
.map(|x| self.arg.eq(&x.arg)) | ||
.unwrap_or(false) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -454,6 +454,7 @@ message Not { | |
|
||
message Negative { | ||
Expr child = 1; | ||
bool fail_on_error = 2; | ||
} | ||
|
||
message IfExpr { | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.