diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index fd33ba8bb..a3cd44b61 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -39,7 +39,7 @@ trait CometExprShim { } object CometEvalModeUtil { - def fromSparkEvalMode(evalMode: EvalMode.Value) = evalMode match { + def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match { case EvalMode.LEGACY => CometEvalMode.LEGACY case EvalMode.TRY => CometEvalMode.TRY case EvalMode.ANSI => CometEvalMode.ANSI diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 01f923206..e33ac6fc6 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -30,4 +31,18 @@ trait CometExprShim { protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } + + def evalMode(c: Cast): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(c.evalMode) + + def evalMode(r: Round): CometEvalMode.Value = CometEvalMode.fromBoolean(r.ansiEnabled) + +} + +object CometEvalModeUtil { + def fromSparkEvalMode(evalMode: EvalMode.Value) = evalMode match { + case EvalMode.LEGACY => CometEvalMode.LEGACY + case EvalMode.TRY => CometEvalMode.TRY + case EvalMode.ANSI => CometEvalMode.ANSI + } }