Skip to content

Commit

Permalink
feat: Use enum to represent CAST eval_mode in expr.proto (apache#415)
Browse files Browse the repository at this point in the history
* Fixes Issue apache#361: Use enum to represent CAST eval_mode in expr.proto

* Update expr.proto and QueryPlanSerde.scala for handling enum EvalMode for cast message

* issue 361 fixed type issue for eval_mode in planner.rs

* issue 361 refactored QueryPlanSerde.scala for enhanced type safety and localization compliance, including a new string-to-enum conversion function and improved import organization.

* Updated planner.rs, expr.proto, QueryPlanSerde.scala for enum support

* Update spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

---------

Co-authored-by: Prashant K. Sharma <[email protected]>
Co-authored-by: Andy Grove <[email protected]>
  • Loading branch information
3 people authored and kazuyukitanimura committed Jul 1, 2024
1 parent c2986db commit 6ae6433
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
14 changes: 5 additions & 9 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,16 +348,12 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let timezone = expr.timezone.clone();
let eval_mode = match expr.eval_mode.as_str() {
"ANSI" => EvalMode::Ansi,
"TRY" => EvalMode::Try,
"LEGACY" => EvalMode::Legacy,
other => {
return Err(ExecutionError::GeneralError(format!(
"Invalid Cast EvalMode: \"{other}\""
)))
}
let eval_mode = match spark_expression::EvalMode::try_from(expr.eval_mode)? {
spark_expression::EvalMode::Legacy => EvalMode::Legacy,
spark_expression::EvalMode::Try => EvalMode::Try,
spark_expression::EvalMode::Ansi => EvalMode::Ansi,
};

Ok(Arc::new(Cast::new(child, datatype, eval_mode, timezone)))
}
ExprStruct::Hour(expr) => {
Expand Down
10 changes: 8 additions & 2 deletions core/src/execution/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,18 @@ message Remainder {
DataType return_type = 4;
}

enum EvalMode {
LEGACY = 0;
TRY = 1;
ANSI = 2;
}

message Cast {
Expr child = 1;
DataType datatype = 2;
string timezone = 3;
// LEGACY, ANSI, or TRY
string eval_mode = 4;
EvalMode eval_mode = 4;

}

message Equal {
Expand Down
19 changes: 17 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

package org.apache.comet.serde

import java.util.Locale

import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -588,6 +590,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
* @return
* The protobuf representation of the expression, or None if the expression is not supported
*/

def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode =
evalModeStr.toUpperCase(Locale.ROOT) match {
case "LEGACY" => ExprOuterClass.EvalMode.LEGACY
case "TRY" => ExprOuterClass.EvalMode.TRY
case "ANSI" => ExprOuterClass.EvalMode.ANSI
case invalid =>
throw new IllegalArgumentException(
s"Invalid eval mode '$invalid' "
) // Assuming we want to catch errors strictly
}

def exprToProto(
expr: Expression,
input: Seq[Attribute],
Expand All @@ -598,12 +612,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
childExpr: Option[Expr],
evalMode: String): Option[Expr] = {
val dataType = serializeDataType(dt)
val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum

if (childExpr.isDefined && dataType.isDefined) {
val castBuilder = ExprOuterClass.Cast.newBuilder()
castBuilder.setChild(childExpr.get)
castBuilder.setDatatype(dataType.get)
castBuilder.setEvalMode(evalMode)
castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf

val timeZone = timeZoneId.getOrElse("UTC")
castBuilder.setTimezone(timeZone)
Expand Down Expand Up @@ -1305,7 +1320,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.newBuilder()
.setChild(e)
.setDatatype(serializeDataType(IntegerType).get)
.setEvalMode("LEGACY") // year is not affected by ANSI mode
.setEvalMode(ExprOuterClass.EvalMode.LEGACY)
.build())
.build()
})
Expand Down

0 comments on commit 6ae6433

Please sign in to comment.