diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 44c59a1c0..746566a20 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -573,7 +573,7 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; Ok(Arc::new(NotExpr::new(child))) } - ExprStruct::Negative(expr) => { + ExprStruct::UnaryMinus(expr) => { let child: Arc = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?; let result = negative::create_negate_expr(child, expr.fail_on_error); diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 569f8e69d..8bd1dca8e 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -65,7 +65,7 @@ message Expr { CaseWhen caseWhen = 38; In in = 39; Not not = 40; - Negative negative = 41; + UnaryMinus unary_minus = 41; BitwiseShiftRight bitwiseShiftRight = 42; BitwiseShiftLeft bitwiseShiftLeft = 43; IfExpr if = 44; @@ -452,7 +452,7 @@ message Not { Expr child = 1; } -message Negative { +message UnaryMinus { Expr child = 1; bool fail_on_error = 2; } diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index a9a74d900..44e3129e8 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -19,99 +19,175 @@ # Supported Spark Expressions -The following Spark expressions are currently available: - -- Literals -- Arithmetic Operators - - UnaryMinus - - Add/Minus/Multiply/Divide/Remainder -- Conditional functions - - Case When - - If -- Cast -- Coalesce -- BloomFilterMightContain -- Boolean functions - - And - - Or - - Not - - EqualTo - - EqualNullSafe - - GreaterThan - - GreaterThanOrEqual - - LessThan - - LessThanOrEqual - - IsNull - - IsNotNull - - In -- String functions - - Substring - - Coalesce - - StringSpace - - Like/RLike - - Contains - - Startswith - - Endswith - - Ascii - - Bit_length - - Octet_length - - Upper - - Lower - - Chr - - Initcap - - Trim/Btrim/Ltrim/Rtrim - - Concat_ws - - Repeat - - Length - - Reverse - - Instr - - Replace - - Translate -- Bitwise functions - - Shiftright/Shiftleft -- Date/Time functions - - Year/Hour/Minute/Second -- Hash functions - - Md5 - - Sha2 - - Hash - - Xxhash64 -- Math functions - - Abs - - Acos - - Asin - - Atan - - Atan2 - - Cos - - Exp - - Ln - - Log10 - - Log2 - - Pow - - Round - - Signum - - Sin - - Sqrt - - Tan - - Ceil - - Floor -- Aggregate functions - - Count - - Sum - - Max - - Min - - Avg - - First - - Last - - BitAnd - - BitOr - - BitXor - - BoolAnd - - BoolOr - - CovPopulation - - CovSample - - VariancePop - - VarianceSamp - - StddevPop - - StddevSamp - - Corr +The following Spark expressions are currently available. Any known compatibility issues are noted in the following tables. + +## Literal Values + +| Expression | Notes | +| -------------------------------------- | ----- | +| Literal values of supported data types | | + +## Unary Arithmetic + +| Expression | Notes | +| ---------------- | ----- | +| UnaryMinus (`-`) | | + +## Binary Arithmeticx + +| Expression | Notes | +| --------------- | --------------------------------------------------- | +| Add (`+`) | | +| Subtract (`-`) | | +| Multiply (`*`) | | +| Divide (`/`) | | +| Remainder (`%`) | Comet produces `NaN` instead of `NULL` for `% -0.0` | + +## Conditional Expressions + +| Expression | Notes | +| ---------- | ----- | +| CaseWhen | | +| If | | + +## Comparison + +| Expression | Notes | +| ------------------------- | ----- | +| EqualTo (`=`) | | +| EqualNullSafe (`<=>`) | | +| GreaterThan (`>`) | | +| GreaterThanOrEqual (`>=`) | | +| LessThan (`<`) | | +| LessThanOrEqual (`<=`) | | +| IsNull (`IS NULL`) | | +| IsNotNull (`IS NOT NULL`) | | +| In (`IN`) | | + +## String Functions + +| Expression | Notes | +| --------------- | ----------------------------------------------------------------------------------------------------------- | +| Ascii | | +| BitLength | | +| Chr | | +| ConcatWs | | +| Contains | | +| EndsWith | | +| InitCap | | +| Instr | | +| Length | | +| Like | | +| Lower | | +| OctetLength | | +| Repeat | Negative argument for number of times to repeat causes exception | +| Replace | | +| Reverse | | +| StartsWith | | +| StringSpace | | +| StringTrim | | +| StringTrimBoth | | +| StringTrimLeft | | +| StringTrimRight | | +| Substring | | +| Translate | | +| Upper | | + +## Date/Time Functions + +| Expression | Notes | +| -------------- | ------------------------ | +| DatePart | Only `year` is supported | +| Extract | Only `year` is supported | +| Hour | | +| Minute | | +| Second | | +| TruncDate | | +| TruncTimestamp | | +| Year | | + +## Math Expressions + +| Expression | Notes | +| ---------- | ------------------------------------------------------------------- | +| Abs | | +| Acos | | +| Asin | | +| Atan | | +| Atan2 | | +| Ceil | | +| Cos | | +| Exp | | +| Floor | | +| Log | log(0) will produce `-Infinity` unlike Spark which returns `null` | +| Log2 | log2(0) will produce `-Infinity` unlike Spark which returns `null` | +| Log10 | log10(0) will produce `-Infinity` unlike Spark which returns `null` | +| Pow | | +| Round | | +| Signum | Signum does not differentiate between `0.0` and `-0.0` | +| Sin | | +| Sqrt | | +| Tan | | + +## Hashing Functions + +| Expression | Notes | +| ---------- | ----- | +| Md5 | | +| Hash | | +| Sha2 | | +| XxHash64 | | + +## Boolean Expressions + +| Expression | Notes | +| ---------- | ----- | +| And | | +| Or | | +| Not | | + +## Bitwise Expressions + +| Expression | Notes | +| -------------------- | ----- | +| ShiftLeft (`<<`) | | +| ShiftRight (`>>`) | | +| BitAnd (`&`) | | +| BitOr (`\|`) | | +| BitXor (`^`) | | +| BitwiseNot (`~`) | | +| BoolAnd (`bool_and`) | | +| BoolOr (`bool_or`) | | + +## Aggregate Expressions + +| Expression | Notes | +| ------------- | ----- | +| Avg | | +| BitAndAgg | | +| BitOrAgg | | +| BitXorAgg | | +| Corr | | +| Count | | +| CovPopulation | | +| CovSample | | +| First | | +| Last | | +| Max | | +| Min | | +| StddevPop | | +| StddevSamp | | +| Sum | | +| VariancePop | | +| VarianceSamp | | + +## Other + +| Expression | Notes | +| ----------------------- | ------------------------------------------------------------------------------- | +| Cast | See compatibility guide for list of supported cast expressions and known issues | +| BloomFilterMightContain | | +| ScalarSubquery | | +| Coalesce | | +| NormalizeNaNAndZero | | + diff --git a/docs/source/user-guide/operators.md b/docs/source/user-guide/operators.md index ec82e9f69..e3a3ac522 100644 --- a/docs/source/user-guide/operators.md +++ b/docs/source/user-guide/operators.md @@ -19,15 +19,20 @@ # Supported Spark Operators -The following Spark operators are currently available: +The following Spark operators are currently replaced with native versions. Query stages that contain any operators +not supported by Comet will fall back to regular Spark execution. -- FileSourceScanExec/BatchScanExec for Parquet -- Projection -- Filter -- Sort -- Hash Aggregate -- Limit -- Sort-merge Join -- Hash Join -- Shuffle -- Expand +| Operator | Notes | +| -------------------------------------------- | ----- | +| FileSourceScanExec/BatchScanExec for Parquet | | +| Projection | | +| Filter | | +| Sort | | +| Hash Aggregate | | +| Limit | | +| Sort-merge Join | | +| Hash Join | | +| BroadcastHashJoinExec | | +| Shuffle | | +| Expand | | +| Union | | diff --git a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala index a2d5e2515..fb86389fe 100644 --- a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala +++ b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala @@ -25,7 +25,7 @@ import scala.io.Source import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.comet.expressions.{CometCast, Compatible, Incompatible} +import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible} /** * Utility for generating markdown documentation from the configs. @@ -72,7 +72,7 @@ object GenerateDocs { if (Cast.canCast(fromType, toType) && fromType != toType) { val fromTypeName = fromType.typeName.replace("(10,2)", "") val toTypeName = toType.typeName.replace("(10,2)", "") - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Compatible(notes) => val notesStr = notes.getOrElse("").trim w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes) @@ -89,7 +89,7 @@ object GenerateDocs { if (Cast.canCast(fromType, toType) && fromType != toType) { val fromTypeName = fromType.typeName.replace("(10,2)", "") val toTypeName = toType.typeName.replace("(10,2)", "") - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Incompatible(notes) => val notesStr = notes.getOrElse("").trim w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes) diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 11c5a53cc..811c61d46 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -55,7 +55,7 @@ object CometCast { fromType: DataType, toType: DataType, timeZoneId: Option[String], - evalMode: String): SupportLevel = { + evalMode: CometEvalMode.Value): SupportLevel = { if (fromType == toType) { return Compatible() @@ -102,7 +102,7 @@ object CometCast { private def canCastFromString( toType: DataType, timeZoneId: Option[String], - evalMode: String): SupportLevel = { + evalMode: CometEvalMode.Value): SupportLevel = { toType match { case DataTypes.BooleanType => Compatible() diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala new file mode 100644 index 000000000..59e9c89a6 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.expressions + +/** + * We cannot reference Spark's EvalMode directly because the package is different between Spark + * versions, so we copy it here. + * + * Expression evaluation modes. + * - LEGACY: the default evaluation mode, which is compliant to Hive SQL. + * - ANSI: a evaluation mode which is compliant to ANSI SQL standard. + * - TRY: a evaluation mode for `try_*` functions. It is identical to ANSI evaluation mode + * except for returning null result on errors. + */ +object CometEvalMode extends Enumeration { + val LEGACY, ANSI, TRY = Value + + def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) { + ANSI + } else { + LEGACY + } + + def fromString(str: String): CometEvalMode.Value = CometEvalMode.withName(str) +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 6fc45187f..2e91dab53 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -19,8 +19,6 @@ package org.apache.comet.serde -import java.util.Locale - import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging @@ -45,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo} -import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported} +import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible, Unsupported} import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} @@ -578,6 +576,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } } + def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = { + evalMode match { + case CometEvalMode.LEGACY => ExprOuterClass.EvalMode.LEGACY + case CometEvalMode.TRY => ExprOuterClass.EvalMode.TRY + case CometEvalMode.ANSI => ExprOuterClass.EvalMode.ANSI + case _ => throw new IllegalStateException(s"Invalid evalMode $evalMode") + } + } + /** * Convert a Spark expression to protobuf. * @@ -590,18 +597,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim * @return * The protobuf representation of the expression, or None if the expression is not supported */ - - def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode = - evalModeStr.toUpperCase(Locale.ROOT) match { - case "LEGACY" => ExprOuterClass.EvalMode.LEGACY - case "TRY" => ExprOuterClass.EvalMode.TRY - case "ANSI" => ExprOuterClass.EvalMode.ANSI - case invalid => - throw new IllegalArgumentException( - s"Invalid eval mode '$invalid' " - ) // Assuming we want to catch errors strictly - } - def exprToProto( expr: Expression, input: Seq[Attribute], @@ -610,15 +605,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim timeZoneId: Option[String], dt: DataType, childExpr: Option[Expr], - evalMode: String): Option[Expr] = { + evalMode: CometEvalMode.Value): Option[Expr] = { val dataType = serializeDataType(dt) - val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum if (childExpr.isDefined && dataType.isDefined) { val castBuilder = ExprOuterClass.Cast.newBuilder() castBuilder.setChild(childExpr.get) castBuilder.setDatatype(dataType.get) - castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf + castBuilder.setEvalMode(evalModeToProto(evalMode)) val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) @@ -646,26 +640,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim inputs: Seq[Attribute], dt: DataType, timeZoneId: Option[String], - actualEvalModeStr: String): Option[Expr] = { + evalMode: CometEvalMode.Value): Option[Expr] = { val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val castSupport = - CometCast.isSupported(child.dataType, dt, timeZoneId, actualEvalModeStr) + CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode) def getIncompatMessage(reason: Option[String]): String = "Comet does not guarantee correct results for cast " + s"from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $actualEvalModeStr" + + s"with timezone $timeZoneId and evalMode $evalMode" + reason.map(str => s" ($str)").getOrElse("") castSupport match { case Compatible(_) => - castToProto(timeZoneId, dt, childExpr, actualEvalModeStr) + castToProto(timeZoneId, dt, childExpr, evalMode) case Incompatible(reason) => if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) { logWarning(getIncompatMessage(reason)) - castToProto(timeZoneId, dt, childExpr, actualEvalModeStr) + castToProto(timeZoneId, dt, childExpr, evalMode) } else { withInfo( expr, @@ -677,7 +671,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo( expr, s"Unsupported cast from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $actualEvalModeStr") + s"with timezone $timeZoneId and evalMode $evalMode") None } } else { @@ -701,17 +695,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case UnaryExpression(child) if expr.prettyName == "trycast" => val timeZoneId = SQLConf.get.sessionLocalTimeZone - handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY") + handleCast(child, inputs, expr.dataType, Some(timeZoneId), CometEvalMode.TRY) - case Cast(child, dt, timeZoneId, evalMode) => - val evalModeStr = if (evalMode.isInstanceOf[Boolean]) { - // Spark 3.2 & 3.3 has ansiEnabled boolean - if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY" - } else { - // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY - evalMode.toString - } - handleCast(child, inputs, dt, timeZoneId, evalModeStr) + case c @ Cast(child, dt, timeZoneId, _) => + handleCast(child, inputs, dt, timeZoneId, evalMode(c)) case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) @@ -2009,13 +1996,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case UnaryMinus(child, failOnError) => val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { - val builder = ExprOuterClass.Negative.newBuilder() + val builder = ExprOuterClass.UnaryMinus.newBuilder() builder.setChild(childExpr.get) builder.setFailOnError(failOnError) Some( ExprOuterClass.Expr .newBuilder() - .setNegative(builder) + .setUnaryMinus(builder) .build()) } else { withInfo(expr, child) @@ -2028,7 +2015,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // TODO: Remove this once we have new DataFusion release which includes // the fix: https://github.com/apache/arrow-datafusion/pull/9459 if (childExpr.isDefined) { - castToProto(None, a.dataType, childExpr, "LEGACY") + castToProto(None, a.dataType, childExpr, CometEvalMode.LEGACY) } else { withInfo(expr, a.children: _*) None diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala index f5a578f82..2c6f6ccf4 100644 --- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,10 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. */ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) } + diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala index f5a578f82..150656c23 100644 --- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,9 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. */ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 3f2301f0a..5f4e3fba2 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,19 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. */ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(c.evalMode) } + +object CometEvalModeUtil { + def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match { + case EvalMode.LEGACY => CometEvalMode.LEGACY + case EvalMode.TRY => CometEvalMode.TRY + case EvalMode.ANSI => CometEvalMode.ANSI + } +} + diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 01f923206..5f4e3fba2 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -30,4 +31,16 @@ trait CometExprShim { protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(c.evalMode) } + +object CometEvalModeUtil { + def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match { + case EvalMode.LEGACY => CometEvalMode.LEGACY + case EvalMode.TRY => CometEvalMode.TRY + case EvalMode.ANSI => CometEvalMode.ANSI + } +} + diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index fd2218965..25343f933 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType} -import org.apache.comet.expressions.{CometCast, Compatible} +import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible} class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { @@ -76,7 +76,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } else { val testIgnored = tags.get(expectedTestName).exists(s => s.contains("org.scalatest.Ignore")) - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Compatible(_) => if (testIgnored) { fail(