diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d4eb974fb..b1359916f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -699,6 +699,30 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case c @ Cast(child, dt, timeZoneId, _) => handleCast(child, inputs, dt, timeZoneId, evalMode(c)) + case expr: Add if evalMode(expr) == CometEvalMode.ANSI => + withInfo(expr, s"ANSI mode not supported") + None + + case expr: Subtract if evalMode(expr) == CometEvalMode.ANSI => + withInfo(expr, s"ANSI mode not supported") + None + + case expr: Multiply if evalMode(expr) == CometEvalMode.ANSI => + withInfo(expr, s"ANSI mode not supported") + None + + case expr: Divide if evalMode(expr) == CometEvalMode.ANSI => + withInfo(expr, s"ANSI mode not supported") + None + + case expr: Remainder if evalMode(expr) == CometEvalMode.ANSI => + withInfo(expr, s"ANSI mode not supported") + None + + case expr: Pmod if evalMode(expr) == CometEvalMode.ANSI => + withInfo(expr, s"ANSI mode not supported") + None + case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) val rightExpr = exprToProtoInternal(right, inputs) diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala index 5d2d777c2..9b2d4ea6f 100644 --- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -32,6 +32,24 @@ trait CometExprShim { (unhex.child, Literal(false)) } + def evalMode(expr: Add): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Subtract): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Multiply): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Divide): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Remainder): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Pmod): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) def evalMode(r: Round): CometEvalMode.Value = CometEvalMode.LEGACY diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala index 5d2d777c2..9b2d4ea6f 100644 --- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -32,6 +32,24 @@ trait CometExprShim { (unhex.child, Literal(false)) } + def evalMode(expr: Add): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Subtract): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Multiply): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Divide): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Remainder): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + + def evalMode(expr: Pmod): CometEvalMode.Value = + CometEvalMode.fromBoolean(expr.failOnError) + def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) def evalMode(r: Round): CometEvalMode.Value = CometEvalMode.LEGACY diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index a3cd44b61..2d620ce5b 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -32,6 +32,24 @@ trait CometExprShim { (unhex.child, Literal(unhex.failOnError)) } + def evalMode(expr: Add): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Subtract): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Multiply): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Divide): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Remainder): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Pmod): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index a4ac90f09..40429c2e8 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -32,6 +32,24 @@ trait CometExprShim { (unhex.child, Literal(unhex.failOnError)) } + def evalMode(expr: Add): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Subtract): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Multiply): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Divide): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Remainder): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + + def evalMode(expr: Pmod): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(expr.evalMode) + def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode)