From 901a622d2b8633f13faa7bc68b2b75dcd3fc872b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 6 Jun 2024 06:26:00 -0600 Subject: [PATCH] fall back for more expressions in ANSI mode --- .../org/apache/comet/serde/QueryPlanSerde.scala | 4 ++++ .../org/apache/comet/shims/CometExprShim.scala | 12 ++++++++++++ .../org/apache/comet/shims/CometExprShim.scala | 12 ++++++++++++ .../org/apache/comet/shims/CometExprShim.scala | 12 ++++++++++++ .../org/apache/comet/shims/CometExprShim.scala | 13 +++++++++++++ 5 files changed, 53 insertions(+) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 89ce7f9cd..c53f3248d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2207,9 +2207,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case expr: Subtract => evalMode(expr) == CometEvalMode.ANSI case expr: Multiply => evalMode(expr) == CometEvalMode.ANSI case expr: Divide => evalMode(expr) == CometEvalMode.ANSI + case expr: IntegralDivide => evalMode(expr) == CometEvalMode.ANSI case expr: Remainder => evalMode(expr) == CometEvalMode.ANSI case expr: Pmod => evalMode(expr) == CometEvalMode.ANSI case expr: Round => evalMode(expr) == CometEvalMode.ANSI + case expr: BRound => evalMode(expr) == CometEvalMode.ANSI + case expr: Sum => evalMode(expr) == CometEvalMode.ANSI + case expr: Average => evalMode(expr) == CometEvalMode.ANSI case _ => false } diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala index 9b2d4ea6f..a9316593c 100644 --- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. @@ -44,13 +45,24 @@ trait CometExprShim { def evalMode(expr: Divide): CometEvalMode.Value = CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(expr: IntegralDivide): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(expr: Remainder): CometEvalMode.Value = CometEvalMode.fromBoolean(expr.failOnError) def evalMode(expr: Pmod): CometEvalMode.Value = CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(expr: Sum): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Average): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) def evalMode(r: Round): CometEvalMode.Value = CometEvalMode.LEGACY + + def evalMode(r: BRound): CometEvalMode.Value = CometEvalMode.LEGACY } diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala index 9b2d4ea6f..a9316593c 100644 --- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. @@ -44,13 +45,24 @@ trait CometExprShim { def evalMode(expr: Divide): CometEvalMode.Value = CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(expr: IntegralDivide): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(expr: Remainder): CometEvalMode.Value = CometEvalMode.fromBoolean(expr.failOnError) def evalMode(expr: Pmod): CometEvalMode.Value = CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(expr: Sum): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Average): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) def evalMode(r: Round): CometEvalMode.Value = CometEvalMode.LEGACY + + def evalMode(r: BRound): CometEvalMode.Value = CometEvalMode.LEGACY } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 2d620ce5b..5a42f41f8 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. @@ -44,16 +45,27 @@ trait CometExprShim { def evalMode(expr: Divide): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(expr: IntegralDivide): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(expr: Remainder): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) def evalMode(expr: Pmod): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(expr: Sum): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Average): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) def evalMode(r: Round): CometEvalMode.Value = CometEvalMode.fromBoolean(r.ansiEnabled) + + def evalMode(r: BRound): CometEvalMode.Value = CometEvalMode.fromBoolean(r.ansiEnabled) } object CometEvalModeUtil { diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 40429c2e8..e463e89f8 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. @@ -44,17 +45,29 @@ trait CometExprShim { def evalMode(expr: Divide): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(expr: IntegralDivide): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(expr: Remainder): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) def evalMode(expr: Pmod): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(expr: Sum): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Average): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) def evalMode(r: Round): CometEvalMode.Value = CometEvalMode.fromBoolean(r.ansiEnabled) + def evalMode(r: BRound): CometEvalMode.Value = CometEvalMode.fromBoolean(r.ansiEnabled) + } object CometEvalModeUtil {