diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index a0b75de3f..20119b13a 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -348,16 +348,12 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let timezone = expr.timezone.clone(); - let eval_mode = match expr.eval_mode.as_str() { - "ANSI" => EvalMode::Ansi, - "TRY" => EvalMode::Try, - "LEGACY" => EvalMode::Legacy, - other => { - return Err(ExecutionError::GeneralError(format!( - "Invalid Cast EvalMode: \"{other}\"" - ))) - } + let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? { + spark_expression::EvalMode::Legacy => EvalMode::Legacy, + spark_expression::EvalMode::Try => EvalMode::Try, + spark_expression::EvalMode::Ansi => EvalMode::Ansi, }; + Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone))) } ExprStruct::Hour(expr) => { diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index be85e8a92..bcd983875 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -249,12 +249,18 @@ message Remainder { DataType return_type = 4; } +enum EvalMode { + LEGACY = 0; + TRY = 1; + ANSI = 2; +} + message Cast { Expr child = 1; DataType datatype = 2; string timezone = 3; - // LEGACY, ANSI, or TRY - string eval_mode = 4; + EvalMode eval_mode = 4; + } message Equal { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c1188e193..79a75102d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -19,6 +19,8 @@ package org.apache.comet.serde +import java.util.Locale + import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging @@ -588,6 +590,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim * @return * The protobuf representation of the expression, or None if the expression is not supported */ + + def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode = + evalModeStr.toUpperCase(Locale.ROOT) match { + case "LEGACY" => ExprOuterClass.EvalMode.LEGACY + case "TRY" => ExprOuterClass.EvalMode.TRY + case "ANSI" => ExprOuterClass.EvalMode.ANSI + case invalid => + throw new IllegalArgumentException( + s"Invalid eval mode '$invalid' " + ) // Assuming we want to catch errors strictly + } + def exprToProto( expr: Expression, input: Seq[Attribute], @@ -598,12 +612,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim childExpr: Option[Expr], evalMode: String): Option[Expr] = { val dataType = serializeDataType(dt) + val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum if (childExpr.isDefined && dataType.isDefined) { val castBuilder = ExprOuterClass.Cast.newBuilder() castBuilder.setChild(childExpr.get) castBuilder.setDatatype(dataType.get) - castBuilder.setEvalMode(evalMode) + castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) @@ -1305,7 +1320,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim .newBuilder() .setChild(e) .setDatatype(serializeDataType(IntegerType).get) - .setEvalMode("LEGACY") // year is not affected by ANSI mode + .setEvalMode(ExprOuterClass.EvalMode.LEGACY) .build()) .build() })