From 7f9576354c01eb9ff2eb095a6a037b1f494e0fde Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 7 Jun 2024 15:15:47 -0600 Subject: [PATCH] chore: Add CometEvalMode enum to replace string literals (#539) * Add CometEvalMode enum * address feedback --- .../scala/org/apache/comet/GenerateDocs.scala | 6 +- .../apache/comet/expressions/CometCast.scala | 4 +- .../comet/expressions/CometEvalMode.scala | 42 ++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 57 +++++++------------ .../apache/comet/shims/CometExprShim.scala | 6 +- .../apache/comet/shims/CometExprShim.scala | 5 +- .../apache/comet/shims/CometExprShim.scala | 15 ++++- .../apache/comet/shims/CometExprShim.scala | 13 +++++ .../org/apache/comet/CometCastSuite.scala | 4 +- 9 files changed, 107 insertions(+), 45 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala diff --git a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala index a2d5e25156..fb86389fee 100644 --- a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala +++ b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala @@ -25,7 +25,7 @@ import scala.io.Source import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.comet.expressions.{CometCast, Compatible, Incompatible} +import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible} /** * Utility for generating markdown documentation from the configs. 
@@ -72,7 +72,7 @@ object GenerateDocs { if (Cast.canCast(fromType, toType) && fromType != toType) { val fromTypeName = fromType.typeName.replace("(10,2)", "") val toTypeName = toType.typeName.replace("(10,2)", "") - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Compatible(notes) => val notesStr = notes.getOrElse("").trim w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes) @@ -89,7 +89,7 @@ object GenerateDocs { if (Cast.canCast(fromType, toType) && fromType != toType) { val fromTypeName = fromType.typeName.replace("(10,2)", "") val toTypeName = toType.typeName.replace("(10,2)", "") - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Incompatible(notes) => val notesStr = notes.getOrElse("").trim w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes) diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 11c5a53cc0..811c61d46f 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -55,7 +55,7 @@ object CometCast { fromType: DataType, toType: DataType, timeZoneId: Option[String], - evalMode: String): SupportLevel = { + evalMode: CometEvalMode.Value): SupportLevel = { if (fromType == toType) { return Compatible() @@ -102,7 +102,7 @@ object CometCast { private def canCastFromString( toType: DataType, timeZoneId: Option[String], - evalMode: String): SupportLevel = { + evalMode: CometEvalMode.Value): SupportLevel = { toType match { case DataTypes.BooleanType => Compatible() diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala new file mode 100644 index 
0000000000..59e9c89a6b --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.expressions + +/** + * We cannot reference Spark's EvalMode directly because the package is different between Spark + * versions, so we copy it here. + * + * Expression evaluation modes. + * - LEGACY: the default evaluation mode, which is compliant with Hive SQL. + * - ANSI: an evaluation mode which is compliant with the ANSI SQL standard. + * - TRY: an evaluation mode for `try_*` functions. It is identical to the ANSI evaluation mode + * except that it returns null on errors. 
+ */ +object CometEvalMode extends Enumeration { + val LEGACY, ANSI, TRY = Value + + def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) { + ANSI + } else { + LEGACY + } + + def fromString(str: String): CometEvalMode.Value = CometEvalMode.withName(str) +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 448c4ff0fd..ed3f2fae61 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -19,8 +19,6 @@ package org.apache.comet.serde -import java.util.Locale - import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging @@ -45,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo} -import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported} +import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible, Unsupported} import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} @@ -578,6 +576,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } } + def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = { + evalMode match { + case CometEvalMode.LEGACY => ExprOuterClass.EvalMode.LEGACY + case CometEvalMode.TRY => ExprOuterClass.EvalMode.TRY + case CometEvalMode.ANSI => ExprOuterClass.EvalMode.ANSI + case _ => throw new IllegalStateException(s"Invalid evalMode $evalMode") + } + } + /** * Convert a Spark expression to protobuf. 
* @@ -590,18 +597,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim * @return * The protobuf representation of the expression, or None if the expression is not supported */ - - def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode = - evalModeStr.toUpperCase(Locale.ROOT) match { - case "LEGACY" => ExprOuterClass.EvalMode.LEGACY - case "TRY" => ExprOuterClass.EvalMode.TRY - case "ANSI" => ExprOuterClass.EvalMode.ANSI - case invalid => - throw new IllegalArgumentException( - s"Invalid eval mode '$invalid' " - ) // Assuming we want to catch errors strictly - } - def exprToProto( expr: Expression, input: Seq[Attribute], @@ -610,15 +605,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim timeZoneId: Option[String], dt: DataType, childExpr: Option[Expr], - evalMode: String): Option[Expr] = { + evalMode: CometEvalMode.Value): Option[Expr] = { val dataType = serializeDataType(dt) - val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum if (childExpr.isDefined && dataType.isDefined) { val castBuilder = ExprOuterClass.Cast.newBuilder() castBuilder.setChild(childExpr.get) castBuilder.setDatatype(dataType.get) - castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf + castBuilder.setEvalMode(evalModeToProto(evalMode)) val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) @@ -646,26 +640,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim inputs: Seq[Attribute], dt: DataType, timeZoneId: Option[String], - actualEvalModeStr: String): Option[Expr] = { + evalMode: CometEvalMode.Value): Option[Expr] = { val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val castSupport = - CometCast.isSupported(child.dataType, dt, timeZoneId, actualEvalModeStr) + CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode) def getIncompatMessage(reason: Option[String]): String = "Comet does not 
guarantee correct results for cast " + s"from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $actualEvalModeStr" + + s"with timezone $timeZoneId and evalMode $evalMode" + reason.map(str => s" ($str)").getOrElse("") castSupport match { case Compatible(_) => - castToProto(timeZoneId, dt, childExpr, actualEvalModeStr) + castToProto(timeZoneId, dt, childExpr, evalMode) case Incompatible(reason) => if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) { logWarning(getIncompatMessage(reason)) - castToProto(timeZoneId, dt, childExpr, actualEvalModeStr) + castToProto(timeZoneId, dt, childExpr, evalMode) } else { withInfo( expr, @@ -677,7 +671,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo( expr, s"Unsupported cast from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $actualEvalModeStr") + s"with timezone $timeZoneId and evalMode $evalMode") None } } else { @@ -701,17 +695,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case UnaryExpression(child) if expr.prettyName == "trycast" => val timeZoneId = SQLConf.get.sessionLocalTimeZone - handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY") + handleCast(child, inputs, expr.dataType, Some(timeZoneId), CometEvalMode.TRY) - case Cast(child, dt, timeZoneId, evalMode) => - val evalModeStr = if (evalMode.isInstanceOf[Boolean]) { - // Spark 3.2 & 3.3 has ansiEnabled boolean - if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY" - } else { - // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY - evalMode.toString - } - handleCast(child, inputs, dt, timeZoneId, evalModeStr) + case c @ Cast(child, dt, timeZoneId, _) => + handleCast(child, inputs, dt, timeZoneId, evalMode(c)) case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) @@ -2006,7 +1993,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with 
CometExprShim // TODO: Remove this once we have new DataFusion release which includes // the fix: https://github.com/apache/arrow-datafusion/pull/9459 if (childExpr.isDefined) { - castToProto(None, a.dataType, childExpr, "LEGACY") + castToProto(None, a.dataType, childExpr, CometEvalMode.LEGACY) } else { withInfo(expr, a.children: _*) None diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala index f5a578f820..2c6f6ccf40 100644 --- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,10 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. */ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) } + diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala index f5a578f820..150656c233 100644 --- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,9 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. 
*/ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 3f2301f0a6..5f4e3fba2b 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,19 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. */ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(c.evalMode) } + +object CometEvalModeUtil { + def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match { + case EvalMode.LEGACY => CometEvalMode.LEGACY + case EvalMode.TRY => CometEvalMode.TRY + case EvalMode.ANSI => CometEvalMode.ANSI + } +} + diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 01f9232068..5f4e3fba2b 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -30,4 +31,16 @@ trait CometExprShim { protected def unhexSerde(unhex: Unhex): 
(Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(c.evalMode) } + +object CometEvalModeUtil { + def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match { + case EvalMode.LEGACY => CometEvalMode.LEGACY + case EvalMode.TRY => CometEvalMode.TRY + case EvalMode.ANSI => CometEvalMode.ANSI + } +} + diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index fd2218965e..25343f933b 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType} -import org.apache.comet.expressions.{CometCast, Compatible} +import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible} class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { @@ -76,7 +76,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } else { val testIgnored = tags.get(expectedTestName).exists(s => s.contains("org.scalatest.Ignore")) - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Compatible(_) => if (testIgnored) { fail(