From fce78fad726d5c79394f8bc48c7f02a7a805a863 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Thu, 30 May 2024 16:58:15 +0200 Subject: [PATCH 01/21] change proto msg --- core/src/execution/proto/expr.proto | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index be85e8a92..cbb6939bf 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -473,6 +473,7 @@ message BitwiseNot { message Abs { Expr child = 1; + string eval_mode = 2; } message Subquery { From 1071eee936316f197697b584078920da0c50d4c9 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Thu, 30 May 2024 16:59:31 +0200 Subject: [PATCH 02/21] QueryPlanSerde with eval mode --- .../apache/comet/serde/QueryPlanSerde.scala | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 06b9bc7f4..a0b37e542 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -54,6 +54,13 @@ import org.apache.comet.shims.ShimQueryPlanSerde * An utility object for query plan and expression serialization. */ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim { + + object ExecutionMode { + val ANSI = "ANSI" + val LEGACY = "LEGACY" + val TRY = "TRY" + } + def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } @@ -691,7 +698,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case Cast(child, dt, timeZoneId, evalMode) => val evalModeStr = if (evalMode.isInstanceOf[Boolean]) { // Spark 3.2 & 3.3 has ansiEnabled boolean - if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY" + if (evalMode.asInstanceOf[Boolean]) ExecutionMode.ANSI else ExecutionMode.LEGACY } else { // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY evalMode.toString @@ -1474,15 +1481,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } - case Abs(child, _) => + case Abs(child, failOnErr) => val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { - val abs = - ExprOuterClass.Abs - .newBuilder() - .setChild(childExpr.get) - .build() - Some(Expr.newBuilder().setAbs(abs).build()) + val evalModeStr = if (failOnErr) ExecutionMode.ANSI else ExecutionMode.LEGACY + val absBuilder = ExprOuterClass.Abs.newBuilder() + absBuilder.setChild(childExpr.get) + absBuilder.setEvalMode(evalModeStr) + Some(Expr.newBuilder().setAbs(absBuilder).build()) } else { withInfo(expr, child) None From 750331d6780dd6de796a5190f726544477da842b Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Thu, 30 May 2024 17:05:47 +0200 Subject: [PATCH 03/21] Move eval mode --- core/src/execution/datafusion/expressions/cast.rs | 8 ++------ core/src/execution/datafusion/expressions/mod.rs | 8 ++++++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 7e8a96f28..76aeb60ae 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -51,6 +51,8 @@ use crate::{ }, }; +use super::EvalMode; + static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); static CAST_OPTIONS: CastOptions = CastOptions { @@ -60,12 +62,6 @@ static CAST_OPTIONS: CastOptions = CastOptions { .with_timestamp_format(TIMESTAMP_FORMAT), }; -#[derive(Debug, Hash, PartialEq, Clone, Copy)] -pub enum EvalMode { - Legacy, - Ansi, - Try, -} #[derive(Debug, Hash)] pub struct Cast { diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 9db4b65b3..92cf5b99e 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -24,6 +24,7 @@ pub mod if_expr; mod normalize_nan; pub mod scalar_funcs; pub use normalize_nan::NormalizeNaNAndZero; +pub mod abs; pub mod avg; pub mod avg_decimal; pub mod bloom_filter_might_contain; @@ -37,3 +38,10 @@ pub mod sum_decimal; pub mod temporal; mod utils; pub mod variance; + +#[derive(Debug, Hash, PartialEq, Clone, Copy)] +pub enum EvalMode { + Legacy, + Ansi, + Try, +} From 9f89b57fd0573ee8298058d900b8f7fa63cf2364 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Thu, 30 May 2024 17:25:57 +0200 Subject: [PATCH 04/21] Add abs in planner --- core/src/execution/datafusion/planner.rs | 34 ++++++++++++++---------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 7a37e3aae..561d43099 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -24,7 +24,6 @@ use datafusion::{ arrow::{compute::SortOptions, datatypes::SchemaRef}, common::DataFusionError, execution::FunctionRegistry, - functions::math, logical_expr::{ BuiltinScalarFunction, Operator as DataFusionOperator, ScalarFunctionDefinition, }, @@ -48,6 +47,7 @@ use datafusion::{ }, prelude::SessionContext, }; +use datafusion_physical_expr::udf::ScalarUDF; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, JoinType as DFJoinType, ScalarValue, @@ -65,7 +65,7 @@ use crate::{ avg_decimal::AvgDecimal, bitwise_not::BitwiseNotExpr, bloom_filter_might_contain::BloomFilterMightContain, - cast::{Cast, EvalMode}, + cast::Cast, checkoverflow::CheckOverflow, correlation::Correlation, covariance::Covariance, @@ -95,6 +95,8 @@ use crate::{ }, }; +use super::expressions::{abs::CometAbsFunc, EvalMode}; + // For clippy error on type_complexity. type ExecResult = Result; type PhyAggResult = Result>, ExecutionError>; @@ -149,6 +151,17 @@ impl PhysicalPlanner { } } + fn eval_mode_from_str(eval_mode_str: &str, allow_try: bool) -> Result { + match eval_mode_str { + "ANSI" => Ok(EvalMode::Ansi), + "LEGACY" => Ok(EvalMode::Legacy), + "TRY" if allow_try => Ok(EvalMode::Try), + other => Err(ExecutionError::GeneralError(format!( + "Invalid EvalMode: \"{other}\"" + ))), + } + } + /// Create a DataFusion physical expression from Spark physical expression fn create_expr( &self, @@ -348,16 +361,7 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let timezone = expr.timezone.clone(); - let eval_mode = match expr.eval_mode.as_str() { - "ANSI" => EvalMode::Ansi, - "TRY" => EvalMode::Try, - "LEGACY" => EvalMode::Legacy, - other => { - return Err(ExecutionError::GeneralError(format!( - "Invalid Cast EvalMode: \"{other}\"" - ))) - } - }; + let eval_mode = Self::eval_mode_from_str(expr.eval_mode.as_str(), true)?; Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone))) } ExprStruct::Hour(expr) => { @@ -495,8 +499,10 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?; let return_type = child.data_type(&input_schema)?; let args = vec![child]; - let scalar_def = ScalarFunctionDefinition::UDF(math::abs()); - + let eval_mode = Self::eval_mode_from_str(expr.eval_mode.as_str(), false)?; + let comet_abs = ScalarUDF::new_from_impl(CometAbsFunc::new(eval_mode)); + let scalar_def = ScalarFunctionDefinition::UDF(Arc::new(comet_abs)); + let expr = ScalarFunctionExpr::new("abs", scalar_def, args, return_type, None, false); Ok(Arc::new(expr)) From e6eda863b683b96dbaf658aa7a25bab36e0a2220 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Thu, 30 May 2024 18:06:07 +0200 Subject: [PATCH 05/21] CometAbsFunc wrapper --- .../execution/datafusion/expressions/abs.rs | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 core/src/execution/datafusion/expressions/abs.rs diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs new file mode 100644 index 000000000..abc204ec4 --- /dev/null +++ b/core/src/execution/datafusion/expressions/abs.rs @@ -0,0 +1,55 @@ +use arrow::datatypes::DataType; +use std::{any::Any, sync::Arc}; +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature}; +use datafusion_common::DataFusionError; +use datafusion_functions::math; + +use super::EvalMode; + + +#[derive(Debug)] +pub struct CometAbsFunc { + inner_abs_func: Arc, + eval_mode: EvalMode, +} + +impl CometAbsFunc { + pub fn new(eval_mode: EvalMode) -> Self { + Self { + inner_abs_func: math::abs().inner(), + eval_mode + } + } +} + +impl ScalarUDFImpl for CometAbsFunc { + + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "abs" + } + + fn signature(&self) -> &Signature { + &self.inner_abs_func.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner_abs_func.return_type(arg_types) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match self.inner_abs_func.invoke(args) { + Ok(result) => Ok(result), + Err(err) => { + if self.eval_mode == EvalMode::Legacy { + Ok(args[0].clone()) + } else { + Err(err) + } + } + } + } +} + From 0b37f8ede81482b57b6c73a01f7119e02e322edf Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Thu, 30 May 2024 20:52:53 +0200 Subject: [PATCH 06/21] Add error management --- core/src/errors.rs | 3 +++ .../execution/datafusion/expressions/abs.rs | 21 +++++++++++++++---- core/src/execution/datafusion/planner.rs | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/core/src/errors.rs b/core/src/errors.rs index 04a1629d5..af4fd2697 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -88,6 +88,9 @@ pub enum CometError { to_type: String, }, + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + ArithmeticOverflow { from_type: String }, + #[error(transparent)] Arrow { #[from] diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs index abc204ec4..6473be063 100644 --- a/core/src/execution/datafusion/expressions/abs.rs +++ b/core/src/execution/datafusion/expressions/abs.rs @@ -1,23 +1,33 @@ use arrow::datatypes::DataType; +use arrow_schema::ArrowError; use std::{any::Any, sync::Arc}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature}; use datafusion_common::DataFusionError; use datafusion_functions::math; +use crate::errors::CometError; + use super::EvalMode; +fn arithmetic_overflow_error(from_type: &str) -> CometError { + CometError::ArithmeticOverflow { + from_type: from_type.to_string(), + } +} #[derive(Debug)] pub struct CometAbsFunc { inner_abs_func: Arc, eval_mode: EvalMode, + data_type_name: String } impl CometAbsFunc { - pub fn new(eval_mode: EvalMode) -> Self { + pub fn new(eval_mode: EvalMode, data_type_name: String) -> Self { Self { inner_abs_func: math::abs().inner(), - eval_mode + eval_mode, + data_type_name } } } @@ -42,13 +52,16 @@ impl ScalarUDFImpl for CometAbsFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match self.inner_abs_func.invoke(args) { Ok(result) => Ok(result), - Err(err) => { + Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace)) + if msg.contains("overflow") => { if self.eval_mode == EvalMode::Legacy { Ok(args[0].clone()) } else { - Err(err) + let msg = arithmetic_overflow_error(&self.data_type_name).to_string(); + Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace)) } } + other => other, } } } diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 561d43099..d637feddb 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -500,7 +500,7 @@ impl PhysicalPlanner { let return_type = child.data_type(&input_schema)?; let args = vec![child]; let eval_mode = Self::eval_mode_from_str(expr.eval_mode.as_str(), false)?; - let comet_abs = ScalarUDF::new_from_impl(CometAbsFunc::new(eval_mode)); + let comet_abs = ScalarUDF::new_from_impl(CometAbsFunc::new(eval_mode, return_type.to_string())); let scalar_def = ScalarFunctionDefinition::UDF(Arc::new(comet_abs)); let expr = From d1e2099fdc7ff804638c84f7ec008b6829e2da1d Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Fri, 31 May 2024 17:49:25 +0200 Subject: [PATCH 07/21] Add tests --- .../apache/comet/CometExpressionSuite.scala | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 34c794eb1..75ce8edcc 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -850,6 +850,32 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("abs Overflow ansi mode") { + val data: Seq[(Int,Int)] = Seq((Int.MaxValue, Int.MinValue)) + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { + withParquetTable(data, "tbl") { + checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match { + case (Some(e1), Some(e2)) => + val errorPattern = s""".+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r + assert(errorPattern.findFirstIn(e1.getMessage).isDefined) + assert(errorPattern.findFirstIn(e2.getMessage).isDefined) + case _ => fail("Exception should be thrown") + } + } + } + } + + test("abs Overflow legacy mode") { + val data: Seq[(Int,Int)] = Seq((Int.MaxValue, Int.MinValue), (1, -1)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl") + } + } + } + test("ceil and floor") { Seq("true", "false").foreach { dictionary => withSQLConf( From 73e55136ed443659e0a36d1dba374e0a97ed50ea Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Fri, 31 May 2024 17:51:11 +0200 Subject: [PATCH 08/21] Add license --- .../src/execution/datafusion/expressions/abs.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs index 6473be063..3de3054af 100644 --- a/core/src/execution/datafusion/expressions/abs.rs +++ b/core/src/execution/datafusion/expressions/abs.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use arrow::datatypes::DataType; use arrow_schema::ArrowError; use std::{any::Any, sync::Arc}; From cff5f29f360cefd1ec6579d1591b7d4500a5ce63 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Fri, 31 May 2024 17:59:56 +0200 Subject: [PATCH 09/21] spotless apply --- .../scala/org/apache/comet/CometExpressionSuite.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 75ce8edcc..c1a98c80b 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -851,14 +851,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("abs Overflow ansi mode") { - val data: Seq[(Int,Int)] = Seq((Int.MaxValue, Int.MinValue)) + val data: Seq[(Int, Int)] = Seq((Int.MaxValue, Int.MinValue)) withSQLConf( - SQLConf.ANSI_ENABLED.key -> "true", - CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { withParquetTable(data, "tbl") { checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match { case (Some(e1), Some(e2)) => - val errorPattern = s""".+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r + val errorPattern = + s""".+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r assert(errorPattern.findFirstIn(e1.getMessage).isDefined) assert(errorPattern.findFirstIn(e2.getMessage).isDefined) case _ => fail("Exception should be thrown") @@ -868,7 +869,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("abs Overflow legacy mode") { - val data: Seq[(Int,Int)] = Seq((Int.MaxValue, Int.MinValue), (1, -1)) + val data: Seq[(Int, Int)] = Seq((Int.MaxValue, Int.MinValue), (1, -1)) withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { withParquetTable(data, "tbl") { checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl") From 9b3b4c859859f50e627205475b1935e9c49239c4 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Fri, 31 May 2024 21:50:30 +0200 Subject: [PATCH 10/21] format --- .../execution/datafusion/expressions/abs.rs | 18 ++++++++++-------- .../execution/datafusion/expressions/cast.rs | 1 - core/src/execution/datafusion/planner.rs | 12 ++++++++---- .../apache/comet/CometExpressionSuite.scala | 2 +- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs index 3de3054af..b045d0861 100644 --- a/core/src/execution/datafusion/expressions/abs.rs +++ b/core/src/execution/datafusion/expressions/abs.rs @@ -17,10 +17,10 @@ use arrow::datatypes::DataType; use arrow_schema::ArrowError; -use std::{any::Any, sync::Arc}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature}; use datafusion_common::DataFusionError; use datafusion_functions::math; +use std::{any::Any, sync::Arc}; use crate::errors::CometError; @@ -36,7 +36,7 @@ fn arithmetic_overflow_error(from_type: &str) -> CometError { pub struct CometAbsFunc { inner_abs_func: Arc, eval_mode: EvalMode, - data_type_name: String + data_type_name: String, } impl CometAbsFunc { @@ -44,13 +44,12 @@ impl CometAbsFunc { Self { inner_abs_func: math::abs().inner(), eval_mode, - data_type_name + data_type_name, } } } impl ScalarUDFImpl for CometAbsFunc { - fn as_any(&self) -> &dyn Any { self } @@ -69,17 +68,20 @@ impl ScalarUDFImpl for CometAbsFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match self.inner_abs_func.invoke(args) { Ok(result) => Ok(result), - Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace)) - if msg.contains("overflow") => { + Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace)) + if msg.contains("overflow") => + { if self.eval_mode == EvalMode::Legacy { Ok(args[0].clone()) } else { let msg = arithmetic_overflow_error(&self.data_type_name).to_string(); - Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace)) + Err(DataFusionError::ArrowError( + ArrowError::ComputeError(msg), + trace, + )) } } other => other, } } } - diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 76aeb60ae..1940d6f71 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -62,7 +62,6 @@ static CAST_OPTIONS: CastOptions = CastOptions { .with_timestamp_format(TIMESTAMP_FORMAT), }; - #[derive(Debug, Hash)] pub struct Cast { pub child: Arc, diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index d637feddb..259c43549 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -47,11 +47,11 @@ use datafusion::{ }, prelude::SessionContext, }; -use datafusion_physical_expr::udf::ScalarUDF; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, JoinType as DFJoinType, ScalarValue, }; +use datafusion_physical_expr::udf::ScalarUDF; use itertools::Itertools; use jni::objects::GlobalRef; use num::{BigInt, ToPrimitive}; @@ -151,7 +151,10 @@ impl PhysicalPlanner { } } - fn eval_mode_from_str(eval_mode_str: &str, allow_try: bool) -> Result { + fn eval_mode_from_str( + eval_mode_str: &str, + allow_try: bool, + ) -> Result { match eval_mode_str { "ANSI" => Ok(EvalMode::Ansi), "LEGACY" => Ok(EvalMode::Legacy), @@ -500,9 +503,10 @@ impl PhysicalPlanner { let return_type = child.data_type(&input_schema)?; let args = vec![child]; let eval_mode = Self::eval_mode_from_str(expr.eval_mode.as_str(), false)?; - let comet_abs = ScalarUDF::new_from_impl(CometAbsFunc::new(eval_mode, return_type.to_string())); + let comet_abs = + ScalarUDF::new_from_impl(CometAbsFunc::new(eval_mode, return_type.to_string())); let scalar_def = ScalarFunctionDefinition::UDF(Arc::new(comet_abs)); - + let expr = ScalarFunctionExpr::new("abs", scalar_def, args, return_type, None, false); Ok(Arc::new(expr)) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index c1a98c80b..63cda6361 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -859,7 +859,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match { case (Some(e1), Some(e2)) => val errorPattern = - s""".+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r + """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r assert(errorPattern.findFirstIn(e1.getMessage).isDefined) assert(errorPattern.findFirstIn(e2.getMessage).isDefined) case _ => fail("Exception should be thrown") From f7df3577bcd0df87971a488bbddf4b245285833c Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Sat, 1 Jun 2024 08:25:43 +0200 Subject: [PATCH 11/21] Fix clippy --- core/src/execution/datafusion/expressions/abs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs index b045d0861..9923d2652 100644 --- a/core/src/execution/datafusion/expressions/abs.rs +++ b/core/src/execution/datafusion/expressions/abs.rs @@ -58,7 +58,7 @@ impl ScalarUDFImpl for CometAbsFunc { } fn signature(&self) -> &Signature { - &self.inner_abs_func.signature() + self.inner_abs_func.signature() } fn return_type(&self, arg_types: &[DataType]) -> Result { From 76914b00240f2c62eb6c11788b3973d8b809d617 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Sat, 1 Jun 2024 09:13:18 +0200 Subject: [PATCH 12/21] error msg for all spark versions --- .../scala/org/apache/comet/CometExpressionSuite.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 63cda6361..1e5ed8f36 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -857,11 +857,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { withParquetTable(data, "tbl") { checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match { - case (Some(e1), Some(e2)) => - val errorPattern = + case (Some(sparkExc), Some(cometExc)) => + val cometErrorPattern = """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r - assert(errorPattern.findFirstIn(e1.getMessage).isDefined) - assert(errorPattern.findFirstIn(e2.getMessage).isDefined) + val sparkErrorPattern = ".*integer overflow.*".r + assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined) + assert(sparkErrorPattern.findFirstIn(sparkExc.getMessage).isDefined) case _ => fail("Exception should be thrown") } } From 3b55ca25db282d7efc429d82f730c27ed846dfe7 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Sat, 1 Jun 2024 11:28:48 +0200 Subject: [PATCH 13/21] Fix benches --- core/benches/cast_from_string.rs | 2 +- core/benches/cast_numeric.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/benches/cast_from_string.rs b/core/benches/cast_from_string.rs index 5bfaebf34..9a9ab18cc 100644 --- a/core/benches/cast_from_string.rs +++ b/core/benches/cast_from_string.rs @@ -17,7 +17,7 @@ use arrow_array::{builder::StringBuilder, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; -use comet::execution::datafusion::expressions::cast::{Cast, EvalMode}; +use comet::execution::datafusion::expressions::{cast::Cast, EvalMode}; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use std::sync::Arc; diff --git a/core/benches/cast_numeric.rs b/core/benches/cast_numeric.rs index 398be6946..35f24ce53 100644 --- a/core/benches/cast_numeric.rs +++ b/core/benches/cast_numeric.rs @@ -17,7 +17,7 @@ use arrow_array::{builder::Int32Builder, RecordBatch}; use arrow_schema::{DataType, Field, Schema}; -use comet::execution::datafusion::expressions::cast::{Cast, EvalMode}; +use comet::execution::datafusion::expressions::{cast::Cast, EvalMode}; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use std::sync::Arc; From ab28bf6ddb733f00a1f54acaef16d744c58338ab Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Mon, 3 Jun 2024 21:13:13 +0200 Subject: [PATCH 14/21] Use enum to ansi mode --- core/src/execution/datafusion/planner.rs | 22 ++++++------------- core/src/execution/proto/expr.proto | 2 +- .../apache/comet/serde/QueryPlanSerde.scala | 10 ++------- 3 files changed, 10 insertions(+), 24 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index cc7daa8a3..a4ba82a85 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -151,20 +151,6 @@ impl PhysicalPlanner { } } - fn eval_mode_from_str( - eval_mode_str: &str, - allow_try: bool, - ) -> Result { - match eval_mode_str { - "ANSI" => Ok(EvalMode::Ansi), - "LEGACY" => Ok(EvalMode::Legacy), - "TRY" if allow_try => Ok(EvalMode::Try), - other => Err(ExecutionError::GeneralError(format!( - "Invalid EvalMode: \"{other}\"" - ))), - } - } - /// Create a DataFusion physical expression from Spark physical expression fn create_expr( &self, @@ -507,7 +493,13 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?; let return_type = child.data_type(&input_schema)?; let args = vec![child]; - let eval_mode = Self::eval_mode_from_str(expr.eval_mode.as_str(), false)?; + let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? { + spark_expression::EvalMode::Legacy => EvalMode::Legacy, + spark_expression::EvalMode::Ansi => EvalMode::Ansi, + spark_expression::EvalMode::Try => return Err(ExecutionError::GeneralError(format!( + "Invalid EvalMode: \"TRY\"" + ))), + }; let comet_abs = ScalarUDF::new_from_impl(CometAbsFunc::new(eval_mode, return_type.to_string())); let scalar_def = ScalarFunctionDefinition::UDF(Arc::new(comet_abs)); diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 05389c092..8fdec995f 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -479,7 +479,7 @@ message BitwiseNot { message Abs { Expr child = 1; - string eval_mode = 2; + EvalMode eval_mode = 2; } message Subquery { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index a99872701..f74a535c4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -57,12 +57,6 @@ import org.apache.comet.shims.ShimQueryPlanSerde */ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim { - object ExecutionMode { - val ANSI = "ANSI" - val LEGACY = "LEGACY" - val TRY = "TRY" - } - def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } @@ -713,7 +707,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case Cast(child, dt, timeZoneId, evalMode) => val evalModeStr = if (evalMode.isInstanceOf[Boolean]) { // Spark 3.2 & 3.3 has ansiEnabled boolean - if (evalMode.asInstanceOf[Boolean]) ExecutionMode.ANSI else ExecutionMode.LEGACY + if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY" } else { // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY evalMode.toString @@ -1499,7 +1493,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case Abs(child, failOnErr) => val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { - val evalModeStr = if (failOnErr) ExecutionMode.ANSI else ExecutionMode.LEGACY + val evalModeStr = if (failOnErr) ExprOuterClass.EvalMode.ANSI else ExprOuterClass.EvalMode.LEGACY val absBuilder = ExprOuterClass.Abs.newBuilder() absBuilder.setChild(childExpr.get) absBuilder.setEvalMode(evalModeStr) From 0dda0b244109ddfd891ece303c5419149e2cc40f Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Mon, 3 Jun 2024 22:39:08 +0200 Subject: [PATCH 15/21] Fix format --- core/src/execution/datafusion/planner.rs | 10 ++++++---- .../scala/org/apache/comet/serde/QueryPlanSerde.scala | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index a4ba82a85..ff672e1c3 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -496,9 +496,11 @@ impl PhysicalPlanner { let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? { spark_expression::EvalMode::Legacy => EvalMode::Legacy, spark_expression::EvalMode::Ansi => EvalMode::Ansi, - spark_expression::EvalMode::Try => return Err(ExecutionError::GeneralError(format!( - "Invalid EvalMode: \"TRY\"" - ))), + spark_expression::EvalMode::Try => { + return Err(ExecutionError::GeneralError( + "Invalid EvalMode: \"TRY\"".to_string(), + )) + } }; let comet_abs = ScalarUDF::new_from_impl(CometAbsFunc::new(eval_mode, return_type.to_string())); @@ -1816,4 +1818,4 @@ mod tests { })); op } -} \ No newline at end of file +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index f74a535c4..b3fffd7e1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -56,7 +56,6 @@ import org.apache.comet.shims.ShimQueryPlanSerde * An utility object for query plan and expression serialization. */ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim { - def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } @@ -1493,7 +1492,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case Abs(child, failOnErr) => val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { - val evalModeStr = if (failOnErr) ExprOuterClass.EvalMode.ANSI else ExprOuterClass.EvalMode.LEGACY + val evalModeStr = + if (failOnErr) ExprOuterClass.EvalMode.ANSI else ExprOuterClass.EvalMode.LEGACY val absBuilder = ExprOuterClass.Abs.newBuilder() absBuilder.setChild(childExpr.get) absBuilder.setEvalMode(evalModeStr) From 1fc4f48b6611f7fd3a16fae2445364c58f2e773e Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Tue, 4 Jun 2024 07:53:11 +0200 Subject: [PATCH 16/21] Add more tests --- .../apache/comet/CometExpressionSuite.scala | 47 ++++++++++++++----- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 0d29d943d..65d95717b 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -26,9 +26,11 @@ import org.apache.spark.sql.functions.expr import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE import org.apache.spark.sql.types.{Decimal, DecimalType} - import org.apache.comet.CometSparkSessionExtensions.{isSpark32, isSpark33Plus, isSpark34Plus} +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -851,31 +853,54 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("abs Overflow ansi mode") { - val data: Seq[(Int, Int)] = Seq((Int.MaxValue, Int.MinValue)) - withSQLConf( - SQLConf.ANSI_ENABLED.key -> "true", - CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { + + def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { withParquetTable(data, "tbl") { checkSparkMaybeThrows(sql("select abs(_1), abs(_2) from tbl")) match { case (Some(sparkExc), Some(cometExc)) => val cometErrorPattern = """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.""".r - val sparkErrorPattern = ".*integer overflow.*".r assert(cometErrorPattern.findFirstIn(cometExc.getMessage).isDefined) - assert(sparkErrorPattern.findFirstIn(sparkExc.getMessage).isDefined) + assert(sparkExc.getMessage.contains("overflow")) case _ => fail("Exception should be thrown") } } } - } - test("abs Overflow legacy mode") { - val data: Seq[(Int, Int)] = Seq((Int.MaxValue, Int.MinValue), (1, -1)) - withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + def testAbsAnsi[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { withParquetTable(data, "tbl") { checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl") } } + + withSQLConf( + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { + testAbsAnsiOverflow(Seq((Byte.MaxValue, Byte.MinValue))) + testAbsAnsiOverflow(Seq((Short.MaxValue, Short.MinValue))) + testAbsAnsiOverflow(Seq((Int.MaxValue, Int.MinValue))) + testAbsAnsiOverflow(Seq((Long.MaxValue, Long.MinValue))) + testAbsAnsi(Seq((Float.MaxValue, Float.MinValue))) + testAbsAnsi(Seq((Double.MaxValue, Double.MinValue))) + } + } + + test("abs Overflow legacy mode") { + + def testAbsLegacyOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("select abs(_1), abs(_2) from tbl") + } + } + } + + testAbsLegacyOverflow(Seq((Byte.MaxValue, Byte.MinValue))) + testAbsLegacyOverflow(Seq((Short.MaxValue, Short.MinValue))) + testAbsLegacyOverflow(Seq((Int.MaxValue, Int.MinValue))) + testAbsLegacyOverflow(Seq((Long.MaxValue, Long.MinValue))) + testAbsLegacyOverflow(Seq((Float.MaxValue, Float.MinValue))) + testAbsLegacyOverflow(Seq((Double.MaxValue, Double.MinValue))) } test("ceil and floor") { From fe2a0033c71b67d65dd54ebfe7c1919089d4fcbf Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Tue, 4 Jun 2024 08:12:27 +0200 Subject: [PATCH 17/21] Format --- .../scala/org/apache/comet/CometExpressionSuite.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index cc7051d17..74edc7d20 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -19,6 +19,9 @@ package org.apache.comet +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -26,10 +29,8 @@ import org.apache.spark.sql.functions.expr import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE import org.apache.spark.sql.types.{Decimal, DecimalType} -import org.apache.comet.CometSparkSessionExtensions.{isSpark32, isSpark33Plus, isSpark34Plus} -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag +import org.apache.comet.CometSparkSessionExtensions.{isSpark32, isSpark33Plus, isSpark34Plus} class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -874,8 +875,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } withSQLConf( - SQLConf.ANSI_ENABLED.key -> "true", - CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { + SQLConf.ANSI_ENABLED.key -> "true", + CometConf.COMET_ANSI_MODE_ENABLED.key -> "true") { testAbsAnsiOverflow(Seq((Byte.MaxValue, Byte.MinValue))) testAbsAnsiOverflow(Seq((Short.MaxValue, Short.MinValue))) testAbsAnsiOverflow(Seq((Int.MaxValue, Int.MinValue))) From 3dff4bb73b80b831bbb6d74b56b23422f22d51d5 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Wed, 5 Jun 2024 21:14:19 +0200 Subject: [PATCH 18/21] Refactor --- .../execution/datafusion/expressions/abs.rs | 28 +++++++++++-------- .../execution/datafusion/expressions/mod.rs | 21 ++++++++++++++ .../datafusion/expressions/negative.rs | 8 ++---- core/src/execution/datafusion/planner.rs | 22 ++++----------- 4 files changed, 45 insertions(+), 34 deletions(-) diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs index 9923d2652..e22108a02 100644 --- a/core/src/execution/datafusion/expressions/abs.rs +++ b/core/src/execution/datafusion/expressions/abs.rs @@ -22,15 +22,9 @@ use datafusion_common::DataFusionError; use datafusion_functions::math; use std::{any::Any, sync::Arc}; -use crate::errors::CometError; +use crate::execution::operators::ExecutionError; -use super::EvalMode; - -fn arithmetic_overflow_error(from_type: &str) -> CometError { - CometError::ArithmeticOverflow { - from_type: from_type.to_string(), - } -} +use super::{arithmetic_overflow_error, EvalMode}; #[derive(Debug)] pub struct CometAbsFunc { @@ -40,12 +34,23 @@ pub struct CometAbsFunc { } impl CometAbsFunc { - pub fn new(eval_mode: EvalMode, data_type_name: String) -> Self { - Self { + pub fn new(eval_mode: EvalMode, data_type_name: String) -> Result { + match eval_mode { + EvalMode::Legacy => (), + EvalMode::Ansi => (), + other => { + return Err(ExecutionError::GeneralError(format!( + "Invalid EvalMode: \"{:?}\"", + other + ))) + } + } + + Ok(Self { inner_abs_func: math::abs().inner(), eval_mode, data_type_name, - } + }) } } @@ -67,7 +72,6 @@ impl ScalarUDFImpl for CometAbsFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match self.inner_abs_func.invoke(args) { - Ok(result) => Ok(result), Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), trace)) if msg.contains("overflow") => { diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 852b5c149..c4916e7d8 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -24,6 +24,9 @@ pub mod if_expr; mod normalize_nan; pub mod scalar_funcs; pub use normalize_nan::NormalizeNaNAndZero; +use prost::DecodeError; + +use crate::{errors::CometError, execution::spark_expression}; pub mod abs; pub mod avg; pub mod avg_decimal; @@ -46,3 +49,21 @@ pub enum EvalMode { Ansi, Try, } + +impl TryFrom for EvalMode { + type Error = DecodeError; + + fn try_from(value: i32) -> Result { + match spark_expression::EvalMode::try_from(value)? { + spark_expression::EvalMode::Legacy => Ok(EvalMode::Legacy), + spark_expression::EvalMode::Try => Ok(EvalMode::Try), + spark_expression::EvalMode::Ansi => Ok(EvalMode::Ansi), + } + } +} + +fn arithmetic_overflow_error(from_type: &str) -> CometError { + CometError::ArithmeticOverflow { + from_type: from_type.to_string(), + } +} diff --git a/core/src/execution/datafusion/expressions/negative.rs b/core/src/execution/datafusion/expressions/negative.rs index e7aa2ac64..c9a6c193d 100644 --- a/core/src/execution/datafusion/expressions/negative.rs +++ b/core/src/execution/datafusion/expressions/negative.rs @@ -33,6 +33,8 @@ use std::{ sync::Arc, }; +use super::arithmetic_overflow_error; + pub fn create_negate_expr( expr: Arc, fail_on_error: bool, @@ -48,12 +50,6 @@ pub struct NegativeExpr { fail_on_error: bool, } -fn arithmetic_overflow_error(from_type: &str) -> CometError { - CometError::ArithmeticOverflow { - from_type: from_type.to_string(), - } -} - macro_rules! check_overflow { ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{ let typed_array = $array diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 981b2311c..c90cf3740 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -351,11 +351,7 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let timezone = expr.timezone.clone(); - let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? { - spark_expression::EvalMode::Legacy => EvalMode::Legacy, - spark_expression::EvalMode::Try => EvalMode::Try, - spark_expression::EvalMode::Ansi => EvalMode::Ansi, - }; + let eval_mode = EvalMode::try_from(expr.eval_mode)?; Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone))) } @@ -494,17 +490,11 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?; let return_type = child.data_type(&input_schema)?; let args = vec![child]; - let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? { - spark_expression::EvalMode::Legacy => EvalMode::Legacy, - spark_expression::EvalMode::Ansi => EvalMode::Ansi, - spark_expression::EvalMode::Try => { - return Err(ExecutionError::GeneralError( - "Invalid EvalMode: \"TRY\"".to_string(), - )) - } - }; - let comet_abs = - ScalarUDF::new_from_impl(CometAbsFunc::new(eval_mode, return_type.to_string())); + let eval_mode = EvalMode::try_from(expr.eval_mode)?; + let comet_abs = ScalarUDF::new_from_impl(CometAbsFunc::new( + eval_mode, + return_type.to_string(), + )?); let scalar_def = ScalarFunctionDefinition::UDF(Arc::new(comet_abs)); let expr = From 19969d645364f0f7688958d57b91cdbc686e90e5 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Wed, 5 Jun 2024 21:21:52 +0200 Subject: [PATCH 19/21] refactor --- .../execution/datafusion/expressions/abs.rs | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/core/src/execution/datafusion/expressions/abs.rs b/core/src/execution/datafusion/expressions/abs.rs index e22108a02..4eb8c7c1e 100644 --- a/core/src/execution/datafusion/expressions/abs.rs +++ b/core/src/execution/datafusion/expressions/abs.rs @@ -35,22 +35,18 @@ pub struct CometAbsFunc { impl CometAbsFunc { pub fn new(eval_mode: EvalMode, data_type_name: String) -> Result { - match eval_mode { - EvalMode::Legacy => (), - EvalMode::Ansi => (), - other => { - return Err(ExecutionError::GeneralError(format!( - "Invalid EvalMode: \"{:?}\"", - other - ))) - } + if let EvalMode::Legacy | EvalMode::Ansi = eval_mode { + Ok(Self { + inner_abs_func: math::abs().inner(), + eval_mode, + data_type_name, + }) + } else { + Err(ExecutionError::GeneralError(format!( + "Invalid EvalMode: \"{:?}\"", + eval_mode + ))) } - - Ok(Self { - inner_abs_func: math::abs().inner(), - eval_mode, - data_type_name, - }) } } From a72db13ce1f14f645d63badb5ed24c46d09677c2 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Sat, 8 Jun 2024 06:56:38 +0200 Subject: [PATCH 20/21] fix merge --- core/src/execution/datafusion/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index a9a360e2c..f0cfcebf5 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -503,7 +503,7 @@ impl PhysicalPlanner { eval_mode, return_type.to_string(), )?); - let expr = ScalarFunctionExpr::new("abs", math::abs(), args, return_type); + let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type); Ok(Arc::new(expr)) } ExprStruct::CaseWhen(case_when) => { From 809052d42a2d4dc96ec59a159b3b70609c68a7f6 Mon Sep 17 00:00:00 2001 From: Pablo Langa Date: Sat, 8 Jun 2024 08:00:04 +0200 Subject: [PATCH 21/21] fix merge --- core/src/execution/datafusion/planner.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index f0cfcebf5..f168b958f 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -24,9 +24,7 @@ use datafusion::{ arrow::{compute::SortOptions, datatypes::SchemaRef}, common::DataFusionError, execution::FunctionRegistry, - logical_expr::{ - BuiltinScalarFunction, Operator as DataFusionOperator, ScalarFunctionDefinition, - }, + logical_expr::Operator as DataFusionOperator, physical_expr::{ execution_props::ExecutionProps, expressions::{ @@ -51,8 +49,8 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, JoinType as DFJoinType, ScalarValue, }; +use datafusion_expr::ScalarUDF; use datafusion_physical_expr_common::aggregate::create_aggregate_expr; -use datafusion_physical_expr::udf::ScalarUDF; use itertools::Itertools; use jni::objects::GlobalRef; use num::{BigInt, ToPrimitive}; @@ -359,7 +357,7 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let timezone = expr.timezone.clone(); - let eval_mode = EvalMode::try_from(expr.eval_mode)?; + let eval_mode = expr.eval_mode.try_into()?; Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone))) } @@ -498,11 +496,11 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?; let return_type = child.data_type(&input_schema)?; let args = vec![child]; - let eval_mode = EvalMode::try_from(expr.eval_mode)?; - let comet_abs = ScalarUDF::new_from_impl(CometAbsFunc::new( + let eval_mode = expr.eval_mode.try_into()?; + let comet_abs = Arc::new(ScalarUDF::new_from_impl(CometAbsFunc::new( eval_mode, return_type.to_string(), - )?); + )?)); let expr = ScalarFunctionExpr::new("abs", comet_abs, args, return_type); Ok(Arc::new(expr)) } @@ -1793,4 +1791,4 @@ mod tests { })); op } -} \ No newline at end of file +}