diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index c2a88a7eb..605e08aad 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -233,16 +233,20 @@ message Remainder { DataType return_type = 4; } +enum EvalMode { + LEGACY = 0; + TRY = 1; + ANSI = 2; +} + message Cast { Expr child = 1; DataType datatype = 2; string timezone = 3; - // LEGACY, ANSI, or TRY - enum EvalMode { - LEGACY = 0; - TRY = 1; - ANSI = 2; - } + // Depreciateid: LEGACY, ANSI, or TRY - preserved for backward compatibity + string eval_mode_string= 4; // for backward compatibility + EvalMode eval_mode = 5; // New enum field + } message Equal { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 6eda0547f..170504c55 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -525,6 +525,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { * @return * The protobuf representation of the expression, or None if the expression is not supported */ + + def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode = + evalModeStr.toUpperCase match { + case "LEGACY" => ExprOuterClass.EvalMode.LEGACY + case "TRY" => ExprOuterClass.EvalMode.TRY + case "ANSI" => ExprOuterClass.EvalMode.ANSI + case _ => + throw new IllegalArgumentException( + "Invalid eval mode" + ) // Assuming we want to catch errors strictly + } + def exprToProto( expr: Expression, input: Seq[Attribute], @@ -535,12 +547,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { childExpr: Option[Expr], evalMode: String): Option[Expr] = { val dataType = serializeDataType(dt) + val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum if (childExpr.isDefined && dataType.isDefined) { val castBuilder = ExprOuterClass.Cast.newBuilder() castBuilder.setChild(childExpr.get) castBuilder.setDatatype(dataType.get) - castBuilder.setEvalMode(evalMode) + // castBuilder.setEvalMode(evalMode) + castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) @@ -1207,7 +1221,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { .newBuilder() .setChild(e) .setDatatype(serializeDataType(IntegerType).get) - .setEvalMode("LEGACY") // year is not affected by ANSI mode + .setEvalMode(stringToEvalMode("LEGACY")) // year is not affected by ANSI mode .build()) .build() })