feat: Support Ansi mode in abs function (#500)
* change proto msg

* QueryPlanSerde with eval mode

* Move eval mode

* Add abs in planner

* CometAbsFunc wrapper

* Add error management

* Add tests

* Add license

* spotless apply

* format

* Fix clippy

* error msg for all spark versions

* Fix benches

* Use enum to ansi mode

* Fix format

* Add more tests

* Format

* Refactor

* refactor

* fix merge

* fix merge
planga82 authored Jun 11, 2024
1 parent fa95f1b commit e07f24c
Showing 10 changed files with 195 additions and 30 deletions.
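This commit makes Comet's `abs` honor Spark's ANSI eval mode: legacy mode keeps the wrap-around result on integer overflow, while ANSI mode raises an `ARITHMETIC_OVERFLOW` error. The overflow only arises for signed minimum values, where the absolute value is not representable in the same type. A minimal Rust sketch of the two semantics (illustrative only, not part of the diff):

```rust
fn main() {
    let x = i32::MIN;
    // Legacy-style semantics: wrap on overflow, so abs(i32::MIN) stays i32::MIN.
    assert_eq!(x.wrapping_abs(), i32::MIN);
    // ANSI-style semantics: report the overflow instead of silently wrapping.
    assert_eq!(x.checked_abs(), None);
    println!("ok");
}
```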
2 changes: 1 addition & 1 deletion core/benches/cast_from_string.rs
@@ -17,7 +17,7 @@

use arrow_array::{builder::StringBuilder, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
-use comet::execution::datafusion::expressions::cast::{Cast, EvalMode};
+use comet::execution::datafusion::expressions::{cast::Cast, EvalMode};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
use std::sync::Arc;
2 changes: 1 addition & 1 deletion core/benches/cast_numeric.rs
@@ -17,7 +17,7 @@

use arrow_array::{builder::Int32Builder, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
-use comet::execution::datafusion::expressions::cast::{Cast, EvalMode};
+use comet::execution::datafusion::expressions::{cast::Cast, EvalMode};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
use std::sync::Arc;
87 changes: 87 additions & 0 deletions core/src/execution/datafusion/expressions/abs.rs
@@ -0,0 +1,87 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;
use arrow_schema::ArrowError;
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature};
use datafusion_common::DataFusionError;
use datafusion_functions::math;
use std::{any::Any, sync::Arc};

use crate::execution::operators::ExecutionError;

use super::{arithmetic_overflow_error, EvalMode};

#[derive(Debug)]
pub struct CometAbsFunc {
    inner_abs_func: Arc<dyn ScalarUDFImpl>,
    eval_mode: EvalMode,
    data_type_name: String,
}

impl CometAbsFunc {
    pub fn new(eval_mode: EvalMode, data_type_name: String) -> Result<Self, ExecutionError> {
        if let EvalMode::Legacy | EvalMode::Ansi = eval_mode {
            Ok(Self {
                inner_abs_func: math::abs().inner(),
                eval_mode,
                data_type_name,
            })
        } else {
            Err(ExecutionError::GeneralError(format!(
                "Invalid EvalMode: \"{:?}\"",
                eval_mode
            )))
        }
    }
}

impl ScalarUDFImpl for CometAbsFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "abs"
    }

    fn signature(&self) -> &Signature {
        self.inner_abs_func.signature()
    }

    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, DataFusionError> {
        self.inner_abs_func.return_type(arg_types)
    }

    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
        match self.inner_abs_func.invoke(args) {
            Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace))
                if msg.contains("overflow") =>
            {
                if self.eval_mode == EvalMode::Legacy {
                    Ok(args[0].clone())
                } else {
                    let msg = arithmetic_overflow_error(&self.data_type_name).to_string();
                    Err(DataFusionError::ArrowError(
                        ArrowError::ComputeError(msg),
                        trace,
                    ))
                }
            }
            other => other,
        }
    }
}
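The wrapper above delegates `signature`, `return_type`, and `invoke` to DataFusion's built-in `abs`, intercepting only compute errors whose message mentions overflow: legacy mode returns the input unchanged (Spark's wrap-around `abs` of a signed minimum is the value itself), while ANSI mode rewrites the error into Comet's arithmetic-overflow message. A hedged sketch of turning the wrapper into a DataFusion UDF, mirroring the planner change further down; the `"INT"` type name is illustrative:

```rust
use std::sync::Arc;

use comet::execution::datafusion::expressions::{abs::CometAbsFunc, EvalMode};
use datafusion_expr::ScalarUDF;

fn ansi_abs_udf() -> Arc<ScalarUDF> {
    // new() rejects EvalMode::Try, so only Legacy and Ansi can reach invoke().
    let func = CometAbsFunc::new(EvalMode::Ansi, "INT".to_string())
        .expect("Legacy and Ansi are accepted eval modes");
    Arc::new(ScalarUDF::new_from_impl(func))
}
```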
9 changes: 2 additions & 7 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -52,6 +52,8 @@ use crate::{
},
};

+use super::EvalMode;

static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

static CAST_OPTIONS: CastOptions = CastOptions {
@@ -61,13 +63,6 @@ static CAST_OPTIONS: CastOptions = CastOptions {
.with_timestamp_format(TIMESTAMP_FORMAT),
};

-#[derive(Debug, Hash, PartialEq, Clone, Copy)]
-pub enum EvalMode {
-    Legacy,
-    Ansi,
-    Try,
-}

#[derive(Debug, Hash)]
pub struct Cast {
pub child: Arc<dyn PhysicalExpr>,
29 changes: 29 additions & 0 deletions core/src/execution/datafusion/expressions/mod.rs
@@ -24,6 +24,10 @@ pub mod if_expr;
mod normalize_nan;
pub mod scalar_funcs;
pub use normalize_nan::NormalizeNaNAndZero;
+use prost::DecodeError;
+
+use crate::{errors::CometError, execution::spark_expression};
+pub mod abs;
pub mod avg;
pub mod avg_decimal;
pub mod bloom_filter_might_contain;
@@ -39,3 +43,28 @@ pub mod temporal;
pub mod unbound;
mod utils;
pub mod variance;

+#[derive(Debug, Hash, PartialEq, Clone, Copy)]
+pub enum EvalMode {
+    Legacy,
+    Ansi,
+    Try,
+}
+
+impl TryFrom<i32> for EvalMode {
+    type Error = DecodeError;
+
+    fn try_from(value: i32) -> Result<Self, Self::Error> {
+        match spark_expression::EvalMode::try_from(value)? {
+            spark_expression::EvalMode::Legacy => Ok(EvalMode::Legacy),
+            spark_expression::EvalMode::Try => Ok(EvalMode::Try),
+            spark_expression::EvalMode::Ansi => Ok(EvalMode::Ansi),
+        }
+    }
+}
+
+fn arithmetic_overflow_error(from_type: &str) -> CometError {
+    CometError::ArithmeticOverflow {
+        from_type: from_type.to_string(),
+    }
+}
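Hoisting `EvalMode` out of `cast.rs` into the shared `expressions` module lets `Cast` and the new `Abs` path reuse one enum, while `NegativeExpr` and `CometAbsFunc` share the `arithmetic_overflow_error` helper; the `TryFrom<i32>` impl turns an unrecognized protobuf tag into a `DecodeError` rather than a silent default. A sketch of the call-site shape the planner uses below (`expr.eval_mode.try_into()?`), assuming the crate paths from this diff:

```rust
use comet::execution::datafusion::expressions::EvalMode;

fn decode_eval_mode(tag: i32) {
    // try_from fails with a prost DecodeError for tags outside the proto enum.
    match EvalMode::try_from(tag) {
        Ok(EvalMode::Ansi) => println!("abs errors on overflow"),
        Ok(mode) => println!("non-ANSI mode: {:?}", mode),
        Err(e) => eprintln!("unknown eval mode tag: {}", e),
    }
}
```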
8 changes: 2 additions & 6 deletions core/src/execution/datafusion/expressions/negative.rs
@@ -33,6 +33,8 @@ use std::{
sync::Arc,
};

+use super::arithmetic_overflow_error;

pub fn create_negate_expr(
expr: Arc<dyn PhysicalExpr>,
fail_on_error: bool,
@@ -48,12 +50,6 @@ pub struct NegativeExpr {
fail_on_error: bool,
}

-fn arithmetic_overflow_error(from_type: &str) -> CometError {
-    CometError::ArithmeticOverflow {
-        from_type: from_type.to_string(),
-    }
-}

macro_rules! check_overflow {
($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
let typed_array = $array
19 changes: 11 additions & 8 deletions core/src/execution/datafusion/planner.rs
@@ -24,7 +24,6 @@ use datafusion::{
arrow::{compute::SortOptions, datatypes::SchemaRef},
common::DataFusionError,
execution::FunctionRegistry,
-functions::math,
logical_expr::Operator as DataFusionOperator,
physical_expr::{
execution_props::ExecutionProps,
@@ -51,6 +50,7 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
JoinType as DFJoinType, ScalarValue,
};
+use datafusion_expr::ScalarUDF;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use itertools::Itertools;
use jni::objects::GlobalRef;
@@ -65,7 +65,7 @@ use crate::{
avg_decimal::AvgDecimal,
bitwise_not::BitwiseNotExpr,
bloom_filter_might_contain::BloomFilterMightContain,
-cast::{Cast, EvalMode},
+cast::Cast,
checkoverflow::CheckOverflow,
correlation::Correlation,
covariance::Covariance,
@@ -97,6 +97,8 @@
},
};

+use super::expressions::{abs::CometAbsFunc, EvalMode};

// For clippy error on type_complexity.
type ExecResult<T> = Result<T, ExecutionError>;
type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>, ExecutionError>;
@@ -356,11 +358,7 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let timezone = expr.timezone.clone();
-let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? {
-    spark_expression::EvalMode::Legacy => EvalMode::Legacy,
-    spark_expression::EvalMode::Try => EvalMode::Try,
-    spark_expression::EvalMode::Ansi => EvalMode::Ansi,
-};
+let eval_mode = expr.eval_mode.try_into()?;

Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone)))
}
@@ -499,7 +497,12 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?;
let return_type = child.data_type(&input_schema)?;
let args = vec![child];
-let expr = ScalarFunctionExpr::new("abs", math::abs(), args, return_type);
+let eval_mode = expr.eval_mode.try_into()?;
+let comet_abs = Arc::new(ScalarUDF::new_from_impl(CometAbsFunc::new(
+    eval_mode,
+    return_type.to_string(),
+)?));
+let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type);
Ok(Arc::new(expr))
}
ExprStruct::CaseWhen(case_when) => {
1 change: 1 addition & 0 deletions core/src/execution/proto/expr.proto
@@ -480,6 +480,7 @@ message BitwiseNot {

message Abs {
Expr child = 1;
+EvalMode eval_mode = 2;
}

message Subquery {
14 changes: 7 additions & 7 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1476,15 +1476,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}

-case Abs(child, _) =>
+case Abs(child, failOnErr) =>
val childExpr = exprToProtoInternal(child, inputs)
if (childExpr.isDefined) {
-val abs =
-  ExprOuterClass.Abs
-    .newBuilder()
-    .setChild(childExpr.get)
-    .build()
-Some(Expr.newBuilder().setAbs(abs).build())
+val evalModeStr =
+  if (failOnErr) ExprOuterClass.EvalMode.ANSI else ExprOuterClass.EvalMode.LEGACY
+val absBuilder = ExprOuterClass.Abs.newBuilder()
+absBuilder.setChild(childExpr.get)
+absBuilder.setEvalMode(evalModeStr)
+Some(Expr.newBuilder().setAbs(absBuilder).build())
} else {
withInfo(expr, child)
None
54 changes: 54 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -19,6 +19,9 @@

package org.apache.comet

+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -850,6 +853,57 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("abs Overflow ansi mode") {

def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
withParquetTable(data, "tbl") {
checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match {
case (Some(sparkExc), Some(cometExc)) =>
val cometErrorPattern =
""".+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r
assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined)
assert(sparkExc.getMessage.contains("overflow"))
case _ => fail("Exception should be thrown")
}
}
}

def testAbsAnsi[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
withParquetTable(data, "tbl") {
checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl")
}
}

withSQLConf(
SQLConf.ANSI_ENABLED.key -> "true",
CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") {
testAbsAnsiOverflow(Seq((Byte.MaxValue, Byte.MinValue)))
testAbsAnsiOverflow(Seq((Short.MaxValue, Short.MinValue)))
testAbsAnsiOverflow(Seq((Int.MaxValue, Int.MinValue)))
testAbsAnsiOverflow(Seq((Long.MaxValue, Long.MinValue)))
testAbsAnsi(Seq((Float.MaxValue, Float.MinValue)))
testAbsAnsi(Seq((Double.MaxValue, Double.MinValue)))
}
}

test("abs Overflow legacy mode") {

def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
withParquetTable(data, "tbl") {
checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl")
}
}
}

testAbsLegacyOverflow(Seq((Byte.MaxValue, Byte.MinValue)))
testAbsLegacyOverflow(Seq((Short.MaxValue, Short.MinValue)))
testAbsLegacyOverflow(Seq((Int.MaxValue, Int.MinValue)))
testAbsLegacyOverflow(Seq((Long.MaxValue, Long.MinValue)))
testAbsLegacyOverflow(Seq((Float.MaxValue, Float.MinValue)))
testAbsLegacyOverflow(Seq((Double.MaxValue, Double.MinValue)))
}

test("ceil and floor") {
Seq("true", "false").foreach { dictionary =>
withSQLConf(
