From 6143e7a9973521844fb7e898a0f22a9c185972bc Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 5 Jun 2024 20:52:50 -0700
Subject: [PATCH 1/8] chore: Add UnboundColumn to carry datatype for unbound reference (#518)

* chore: Add UnboundColumn to carry datatype for unbound reference

* Update core/src/execution/datafusion/expressions/unbound.rs
---
 .../execution/datafusion/expressions/mod.rs |   1 +
 .../datafusion/expressions/unbound.rs       | 110 ++++++++++++++++++
 core/src/execution/datafusion/planner.rs    |  11 +-
 3 files changed, 120 insertions(+), 2 deletions(-)
 create mode 100644 core/src/execution/datafusion/expressions/unbound.rs

diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 084fef2df..05230b4c2 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -36,5 +36,6 @@ pub mod strings;
 pub mod subquery;
 pub mod sum_decimal;
 pub mod temporal;
+pub mod unbound;
 mod utils;
 pub mod variance;
diff --git a/core/src/execution/datafusion/expressions/unbound.rs b/core/src/execution/datafusion/expressions/unbound.rs
new file mode 100644
index 000000000..5387b1012
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/unbound.rs
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::execution::datafusion::expressions::utils::down_cast_any_ref;
+use arrow_array::RecordBatch;
+use arrow_schema::{DataType, Schema};
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{internal_err, Result};
+use datafusion_physical_expr::PhysicalExpr;
+use std::{
+    any::Any,
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+/// This is similar to `UnKnownColumn` in DataFusion, but it carries a data type.
+/// It is only used when the column is not bound to a schema, for example, for the
+/// inputs to aggregation functions in a final aggregation. In that case, we cannot
+/// bind the aggregation functions to the input schema, which in Spark consists of
+/// grouping columns and aggregate buffer attributes (DataFusion has a different
+/// design). But when creating certain aggregation functions, we need to know their
+/// input data types. As `UnKnownColumn` doesn't carry a data type, we implement
+/// `UnboundColumn` to carry one.
+#[derive(Debug, Hash, PartialEq, Eq, Clone)]
+pub struct UnboundColumn {
+    name: String,
+    datatype: DataType,
+}
+
+impl UnboundColumn {
+    /// Create a new unbound column expression
+    pub fn new(name: &str, datatype: DataType) -> Self {
+        Self {
+            name: name.to_owned(),
+            datatype,
+        }
+    }
+
+    /// Get the column name
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl std::fmt::Display for UnboundColumn {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "{}, datatype: {}", self.name, self.datatype)
+    }
+}
+
+impl PhysicalExpr for UnboundColumn {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    /// Get the data type of this expression, given the schema of the input
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(self.datatype.clone())
+    }
+
+    /// Decide whether this expression is nullable, given the schema of the input
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    /// Evaluate the expression
+    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+        internal_err!("UnboundColumn::evaluate() should not be called")
+    }
+
+    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(self)
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.hash(&mut s);
+    }
+}
+
+impl PartialEq<dyn Any> for UnboundColumn {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| self == x)
+            .unwrap_or(false)
+    }
+}
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index a5bcf5654..7af5f6838 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -33,7 +33,7 @@ use datafusion::{
         expressions::{
             in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column,
             Count, FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue,
-            Literal as DataFusionLiteral, Max, Min, NotExpr, Sum, UnKnownColumn,
+            Literal as DataFusionLiteral, Max, Min, NotExpr, Sum,
         },
         AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
     },
@@ -78,6 +78,7 @@ use crate::{
             subquery::Subquery,
             sum_decimal::SumDecimal,
             temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec},
+            unbound::UnboundColumn,
             variance::Variance,
             NormalizeNaNAndZero,
         },
@@ -239,7 +240,13 @@ impl PhysicalPlanner {
                 let field = input_schema.field(idx);
                 Ok(Arc::new(Column::new(field.name().as_str(), idx)))
             }
-            ExprStruct::Unbound(unbound) => Ok(Arc::new(UnKnownColumn::new(unbound.name.as_str()))),
+            ExprStruct::Unbound(unbound) => {
+                let data_type = to_arrow_datatype(unbound.datatype.as_ref().unwrap());
+                Ok(Arc::new(UnboundColumn::new(
+                    unbound.name.as_str(),
+                    data_type,
+                )))
+            }
             ExprStruct::IsNotNull(is_notnull) => {
                 let child = self.create_expr(is_notnull.child.as_ref().unwrap(), input_schema)?;
                 Ok(Arc::new(IsNotNullExpr::new(child)))

From 0fd1476cdd9e3892137d55a869fcfe906d74d9de Mon Sep 17 00:00:00 2001
From: KAZUYUKI TANIMURA
Date: Wed, 5 Jun 2024 23:50:22 -0700
Subject: [PATCH 2/8] chore: Remove 3.4.2.diff (#528)

---
 dev/diffs/3.4.2.diff | 2555 ------------------------------------------
 1 file changed, 2555 deletions(-)
 delete mode 100644 dev/diffs/3.4.2.diff

diff --git a/dev/diffs/3.4.2.diff b/dev/diffs/3.4.2.diff
deleted file mode 100644
index cd02970da..000000000
--- a/dev/diffs/3.4.2.diff
+++ /dev/null
@@ -1,2555 +0,0 @@
-diff --git a/pom.xml b/pom.xml
-index fab98342498..f2156d790d1 100644 ---- a/pom.xml -+++ b/pom.xml -@@ -148,6 +148,8 @@ - 0.10.0 - 2.5.1 - 2.0.8 -+ 3.4 -+ 0.1.0-SNAPSHOT - org.apache.arrow - arrow-memory-netty + * @@ -258,7 +263,7 @@ under the License. org.apache.arrow - arrow-memory-netty + * @@ -284,7 +289,7 @@ under the License. org.apache.arrow - arrow-memory-netty + * @@ -324,6 +329,11 @@ under the License. commons-logging commons-logging + + + org.apache.arrow + * + @@ -340,6 +350,10 @@ under the License. commons-logging commons-logging + + org.apache.arrow + * + @@ -400,6 +414,11 @@ under the License. commons-logging commons-logging + + + org.apache.arrow + * + From f75aeefab58dc7e14cb70742b9d7bb656b727dbd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 7 Jun 2024 15:15:23 -0600 Subject: [PATCH 7/8] docs: Improve user documentation for supported operators and expressions (#520) * Improve documentation about supported operators and expressions * Improve documentation about supported operators and expressions * more notes * Add more supported expressions * rename protobuf Negative to UnaryMinus for consistency * format * remove duplicate ASF header * SMJ not disabled by default * Update docs/source/user-guide/operators.md Co-authored-by: Liang-Chi Hsieh * Update docs/source/user-guide/operators.md Co-authored-by: Liang-Chi Hsieh * remove RLike --------- Co-authored-by: Liang-Chi Hsieh --- core/src/execution/datafusion/planner.rs | 2 +- core/src/execution/proto/expr.proto | 4 +- docs/source/user-guide/expressions.md | 267 +++++++++++------- docs/source/user-guide/operators.md | 27 +- .../apache/comet/serde/QueryPlanSerde.scala | 4 +- 5 files changed, 192 insertions(+), 112 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 7d504f878..6c7ea0de4 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -566,7 +566,7 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; Ok(Arc::new(NotExpr::new(child))) } - ExprStruct::Negative(expr) => { + ExprStruct::UnaryMinus(expr) => { let child: Arc = self.create_expr(expr.child.as_ref().unwrap(), input_schema.clone())?; let result = negative::create_negate_expr(child, expr.fail_on_error); diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 9c6049013..5192bbd4c 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -65,7 +65,7 @@ message Expr { CaseWhen caseWhen = 38; In in = 39; Not not = 40; - Negative negative = 41; + UnaryMinus unary_minus = 41; BitwiseShiftRight bitwiseShiftRight = 42; BitwiseShiftLeft bitwiseShiftLeft = 43; IfExpr if = 44; @@ -452,7 +452,7 @@ message Not { Expr child = 1; } -message Negative { +message UnaryMinus { Expr child = 1; bool fail_on_error = 2; } diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 775ebedb6..14b6f18d0 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -19,99 +19,174 @@ # Supported Spark Expressions -The following Spark expressions are currently available: - -- Literals -- Arithmetic Operators - - UnaryMinus - - Add/Minus/Multiply/Divide/Remainder -- Conditional functions - - Case When - - If -- Cast -- Coalesce -- BloomFilterMightContain -- Boolean functions - - And - - Or - - Not - - EqualTo - - EqualNullSafe - - GreaterThan - - GreaterThanOrEqual - - LessThan - - LessThanOrEqual - - IsNull - - 
IsNotNull
-  - In
-- String functions
-  - Substring
-  - Coalesce
-  - StringSpace
-  - Like
-  - Contains
-  - Startswith
-  - Endswith
-  - Ascii
-  - Bit_length
-  - Octet_length
-  - Upper
-  - Lower
-  - Chr
-  - Initcap
-  - Trim/Btrim/Ltrim/Rtrim
-  - Concat_ws
-  - Repeat
-  - Length
-  - Reverse
-  - Instr
-  - Replace
-  - Translate
-- Bitwise functions
-  - Shiftright/Shiftleft
-- Date/Time functions
-  - Year/Hour/Minute/Second
-- Hash functions
-  - Md5
-  - Sha2
-  - Hash
-  - Xxhash64
-- Math functions
-  - Abs
-  - Acos
-  - Asin
-  - Atan
-  - Atan2
-  - Cos
-  - Exp
-  - Ln
-  - Log10
-  - Log2
-  - Pow
-  - Round
-  - Signum
-  - Sin
-  - Sqrt
-  - Tan
-  - Ceil
-  - Floor
-- Aggregate functions
-  - Count
-  - Sum
-  - Max
-  - Min
-  - Avg
-  - First
-  - Last
-  - BitAnd
-  - BitOr
-  - BitXor
-  - BoolAnd
-  - BoolOr
-  - CovPopulation
-  - CovSample
-  - VariancePop
-  - VarianceSamp
-  - StddevPop
-  - StddevSamp
-  - Corr
+The following Spark expressions are currently available. Any known compatibility issues are noted in the following tables.
+
+## Literal Values
+
+| Expression | Notes |
+| -------------------------------------- | ----- |
+| Literal values of supported data types | |
+
+## Unary Arithmetic
+
+| Expression | Notes |
+| ---------------- | ----- |
+| UnaryMinus (`-`) | |
+
+## Binary Arithmetic
+
+| Expression | Notes |
+| --------------- | --------------------------------------------------- |
+| Add (`+`) | |
+| Subtract (`-`) | |
+| Multiply (`*`) | |
+| Divide (`/`) | |
+| Remainder (`%`) | Comet produces `NaN` instead of `NULL` for `% -0.0` |
+
+## Conditional Expressions
+
+| Expression | Notes |
+| ---------- | ----- |
+| CaseWhen | |
+| If | |
+
+## Comparison
+
+| Expression | Notes |
+| ------------------------- | ----- |
+| EqualTo (`=`) | |
+| EqualNullSafe (`<=>`) | |
+| GreaterThan (`>`) | |
+| GreaterThanOrEqual (`>=`) | |
+| LessThan (`<`) | |
+| LessThanOrEqual (`<=`) | |
+| IsNull (`IS NULL`) | |
+| IsNotNull (`IS NOT NULL`) | |
+| In (`IN`) | |
+
+## String Functions
+
+| Expression | Notes |
+| --------------- | ------------------------------------------------------------------- |
+| Ascii | |
+| BitLength | |
+| Chr | |
+| ConcatWs | |
+| Contains | |
+| EndsWith | |
+| InitCap | |
+| Instr | |
+| Length | |
+| Like | |
+| Lower | |
+| OctetLength | |
+| Repeat | A negative argument for the number of repetitions causes an exception |
+| Replace | |
+| Reverse | |
+| StartsWith | |
+| StringSpace | |
+| StringTrim | |
+| StringTrimBoth | |
+| StringTrimLeft | |
+| StringTrimRight | |
+| Substring | |
+| Translate | |
+| Upper | |
+
+## Date/Time Functions
+
+| Expression | Notes |
+| -------------- | ------------------------ |
+| DatePart | Only `year` is supported |
+| Extract | Only `year` is supported |
+| Hour | |
+| Minute | |
+| Second | |
+| TruncDate | |
+| TruncTimestamp | |
+| Year | |
+
+## Math Expressions
+
+| Expression | Notes |
+| ---------- | ------------------------------------------------------------------- |
+| Abs | |
+| Acos | |
+| Asin | |
+| Atan | |
+| Atan2 | |
+| Ceil | |
+| Cos | |
+| Exp | |
+| Floor | |
+| Log | log(0) will produce `-Infinity` unlike Spark, which returns `null` |
+| Log2 | log2(0) will produce `-Infinity` unlike Spark, which returns `null` |
+| Log10 | log10(0) will produce `-Infinity` unlike Spark, which returns `null` |
+| Pow | |
+| Round | |
+| Signum | Signum does not differentiate between `0.0` and `-0.0` |
+| Sin | |
+| Sqrt | |
+| Tan | |
+
+## Hashing Functions
+
+| Expression | Notes |
+| ---------- | 
----- | +| Md5 | | +| Hash | | +| Sha2 | | +| XxHash64 | | + +## Boolean Expressions + +| Expression | Notes | +| ---------- | ----- | +| And | | +| Or | | +| Not | | + +## Bitwise Expressions + +| Expression | Notes | +| -------------------- | ----- | +| ShiftLeft (`<<`) | | +| ShiftRight (`>>`) | | +| BitAnd (`&`) | | +| BitOr (`\|`) | | +| BitXor (`^`) | | +| BitwiseNot (`~`) | | +| BoolAnd (`bool_and`) | | +| BoolOr (`bool_or`) | | + +## Aggregate Expressions + +| Expression | Notes | +| ------------- | ----- | +| Avg | | +| BitAndAgg | | +| BitOrAgg | | +| BitXorAgg | | +| Corr | | +| Count | | +| CovPopulation | | +| CovSample | | +| First | | +| Last | | +| Max | | +| Min | | +| StddevPop | | +| StddevSamp | | +| Sum | | +| VariancePop | | +| VarianceSamp | | + +## Other + +| Expression | Notes | +| ----------------------- | ------------------------------------------------------------------------------- | +| Cast | See compatibility guide for list of supported cast expressions and known issues | +| BloomFilterMightContain | | +| ScalarSubquery | | +| Coalesce | | +| NormalizeNaNAndZero | | diff --git a/docs/source/user-guide/operators.md b/docs/source/user-guide/operators.md index ec82e9f69..e3a3ac522 100644 --- a/docs/source/user-guide/operators.md +++ b/docs/source/user-guide/operators.md @@ -19,15 +19,20 @@ # Supported Spark Operators -The following Spark operators are currently available: +The following Spark operators are currently replaced with native versions. Query stages that contain any operators +not supported by Comet will fall back to regular Spark execution. -- FileSourceScanExec/BatchScanExec for Parquet -- Projection -- Filter -- Sort -- Hash Aggregate -- Limit -- Sort-merge Join -- Hash Join -- Shuffle -- Expand +| Operator | Notes | +| -------------------------------------------- | ----- | +| FileSourceScanExec/BatchScanExec for Parquet | | +| Projection | | +| Filter | | +| Sort | | +| Hash Aggregate | | +| Limit | | +| Sort-merge Join | | +| Hash Join | | +| BroadcastHashJoinExec | | +| Shuffle | | +| Expand | | +| Union | | diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8d81b57c4..448c4ff0f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1987,13 +1987,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case UnaryMinus(child, failOnError) => val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { - val builder = ExprOuterClass.Negative.newBuilder() + val builder = ExprOuterClass.UnaryMinus.newBuilder() builder.setChild(childExpr.get) builder.setFailOnError(failOnError) Some( ExprOuterClass.Expr .newBuilder() - .setNegative(builder) + .setUnaryMinus(builder) .build()) } else { withInfo(expr, child) From 311e13e3ec5fe40e10e0f2671e317e1142272af4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 7 Jun 2024 15:15:47 -0600 Subject: [PATCH 8/8] chore: Add CometEvalMode enum to replace string literals (#539) * Add CometEvalMode enum * address feedback --- .../scala/org/apache/comet/GenerateDocs.scala | 6 +- .../apache/comet/expressions/CometCast.scala | 4 +- .../comet/expressions/CometEvalMode.scala | 42 ++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 57 +++++++------------ .../apache/comet/shims/CometExprShim.scala | 6 +- .../apache/comet/shims/CometExprShim.scala | 5 +- 
.../apache/comet/shims/CometExprShim.scala | 15 ++++- .../apache/comet/shims/CometExprShim.scala | 13 +++++ .../org/apache/comet/CometCastSuite.scala | 4 +- 9 files changed, 107 insertions(+), 45 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala diff --git a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala index a2d5e2515..fb86389fe 100644 --- a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala +++ b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala @@ -25,7 +25,7 @@ import scala.io.Source import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.comet.expressions.{CometCast, Compatible, Incompatible} +import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible} /** * Utility for generating markdown documentation from the configs. @@ -72,7 +72,7 @@ object GenerateDocs { if (Cast.canCast(fromType, toType) && fromType != toType) { val fromTypeName = fromType.typeName.replace("(10,2)", "") val toTypeName = toType.typeName.replace("(10,2)", "") - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Compatible(notes) => val notesStr = notes.getOrElse("").trim w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes) @@ -89,7 +89,7 @@ object GenerateDocs { if (Cast.canCast(fromType, toType) && fromType != toType) { val fromTypeName = fromType.typeName.replace("(10,2)", "") val toTypeName = toType.typeName.replace("(10,2)", "") - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Incompatible(notes) => val notesStr = notes.getOrElse("").trim w.write(s"| $fromTypeName | $toTypeName | $notesStr |\n".getBytes) diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 11c5a53cc..811c61d46 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -55,7 +55,7 @@ object CometCast { fromType: DataType, toType: DataType, timeZoneId: Option[String], - evalMode: String): SupportLevel = { + evalMode: CometEvalMode.Value): SupportLevel = { if (fromType == toType) { return Compatible() @@ -102,7 +102,7 @@ object CometCast { private def canCastFromString( toType: DataType, timeZoneId: Option[String], - evalMode: String): SupportLevel = { + evalMode: CometEvalMode.Value): SupportLevel = { toType match { case DataTypes.BooleanType => Compatible() diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala new file mode 100644 index 000000000..59e9c89a6 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.expressions
+
+/**
+ * We cannot reference Spark's EvalMode directly because its package differs between Spark
+ * versions, so we copy it here.
+ *
+ * Expression evaluation modes.
+ *   - LEGACY: the default evaluation mode, which is compliant with Hive SQL.
+ *   - ANSI: an evaluation mode that is compliant with the ANSI SQL standard.
+ *   - TRY: an evaluation mode for `try_*` functions. It is identical to ANSI evaluation mode
+ *     except that it returns a null result on errors.
+ */
+object CometEvalMode extends Enumeration {
+  val LEGACY, ANSI, TRY = Value
+
+  def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
+    ANSI
+  } else {
+    LEGACY
+  }
+
+  def fromString(str: String): CometEvalMode.Value = CometEvalMode.withName(str)
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 448c4ff0f..ed3f2fae6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -19,8 +19,6 @@
 
 package org.apache.comet.serde
 
-import java.util.Locale
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
@@ -45,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo}
-import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported}
+import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible, Unsupported}
 import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
 import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
@@ -578,6 +576,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     }
   }
 
+  def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = {
+    evalMode match {
+      case CometEvalMode.LEGACY => ExprOuterClass.EvalMode.LEGACY
+      case CometEvalMode.TRY => ExprOuterClass.EvalMode.TRY
+      case CometEvalMode.ANSI => ExprOuterClass.EvalMode.ANSI
+      case _ => throw new IllegalStateException(s"Invalid evalMode $evalMode")
+    }
+  }
+
   /**
    * Convert a Spark expression to protobuf.
* @@ -590,18 +597,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim * @return * The protobuf representation of the expression, or None if the expression is not supported */ - - def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode = - evalModeStr.toUpperCase(Locale.ROOT) match { - case "LEGACY" => ExprOuterClass.EvalMode.LEGACY - case "TRY" => ExprOuterClass.EvalMode.TRY - case "ANSI" => ExprOuterClass.EvalMode.ANSI - case invalid => - throw new IllegalArgumentException( - s"Invalid eval mode '$invalid' " - ) // Assuming we want to catch errors strictly - } - def exprToProto( expr: Expression, input: Seq[Attribute], @@ -610,15 +605,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim timeZoneId: Option[String], dt: DataType, childExpr: Option[Expr], - evalMode: String): Option[Expr] = { + evalMode: CometEvalMode.Value): Option[Expr] = { val dataType = serializeDataType(dt) - val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum if (childExpr.isDefined && dataType.isDefined) { val castBuilder = ExprOuterClass.Cast.newBuilder() castBuilder.setChild(childExpr.get) castBuilder.setDatatype(dataType.get) - castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf + castBuilder.setEvalMode(evalModeToProto(evalMode)) val timeZone = timeZoneId.getOrElse("UTC") castBuilder.setTimezone(timeZone) @@ -646,26 +640,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim inputs: Seq[Attribute], dt: DataType, timeZoneId: Option[String], - actualEvalModeStr: String): Option[Expr] = { + evalMode: CometEvalMode.Value): Option[Expr] = { val childExpr = exprToProtoInternal(child, inputs) if (childExpr.isDefined) { val castSupport = - CometCast.isSupported(child.dataType, dt, timeZoneId, actualEvalModeStr) + CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode) def getIncompatMessage(reason: Option[String]): String = "Comet does not guarantee correct results for cast " + s"from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $actualEvalModeStr" + + s"with timezone $timeZoneId and evalMode $evalMode" + reason.map(str => s" ($str)").getOrElse("") castSupport match { case Compatible(_) => - castToProto(timeZoneId, dt, childExpr, actualEvalModeStr) + castToProto(timeZoneId, dt, childExpr, evalMode) case Incompatible(reason) => if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) { logWarning(getIncompatMessage(reason)) - castToProto(timeZoneId, dt, childExpr, actualEvalModeStr) + castToProto(timeZoneId, dt, childExpr, evalMode) } else { withInfo( expr, @@ -677,7 +671,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo( expr, s"Unsupported cast from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $actualEvalModeStr") + s"with timezone $timeZoneId and evalMode $evalMode") None } } else { @@ -701,17 +695,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case UnaryExpression(child) if expr.prettyName == "trycast" => val timeZoneId = SQLConf.get.sessionLocalTimeZone - handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY") + handleCast(child, inputs, expr.dataType, Some(timeZoneId), CometEvalMode.TRY) - case Cast(child, dt, timeZoneId, evalMode) => - val evalModeStr = if (evalMode.isInstanceOf[Boolean]) { - // Spark 3.2 & 3.3 has ansiEnabled boolean - if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY" - } else { - // Spark 3.4+ has 
EvalMode enum with values LEGACY, ANSI, and TRY - evalMode.toString - } - handleCast(child, inputs, dt, timeZoneId, evalModeStr) + case c @ Cast(child, dt, timeZoneId, _) => + handleCast(child, inputs, dt, timeZoneId, evalMode(c)) case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) @@ -2006,7 +1993,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // TODO: Remove this once we have new DataFusion release which includes // the fix: https://github.com/apache/arrow-datafusion/pull/9459 if (childExpr.isDefined) { - castToProto(None, a.dataType, childExpr, "LEGACY") + castToProto(None, a.dataType, childExpr, CometEvalMode.LEGACY) } else { withInfo(expr, a.children: _*) None diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala index f5a578f82..2c6f6ccf4 100644 --- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,10 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. */ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) } + diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala index f5a578f82..150656c23 100644 --- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,9 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. */ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(false)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 3f2301f0a..5f4e3fba2 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -27,7 +28,19 @@ trait CometExprShim { /** * Returns a tuple of expressions for the `unhex` function. 
*/ - def unhexSerde(unhex: Unhex): (Expression, Expression) = { + protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(c.evalMode) } + +object CometEvalModeUtil { + def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match { + case EvalMode.LEGACY => CometEvalMode.LEGACY + case EvalMode.TRY => CometEvalMode.TRY + case EvalMode.ANSI => CometEvalMode.ANSI + } +} + diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 01f923206..5f4e3fba2 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -18,6 +18,7 @@ */ package org.apache.comet.shims +import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ /** @@ -30,4 +31,16 @@ trait CometExprShim { protected def unhexSerde(unhex: Unhex): (Expression, Expression) = { (unhex.child, Literal(unhex.failOnError)) } + + protected def evalMode(c: Cast): CometEvalMode.Value = + CometEvalModeUtil.fromSparkEvalMode(c.evalMode) } + +object CometEvalModeUtil { + def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match { + case EvalMode.LEGACY => CometEvalMode.LEGACY + case EvalMode.TRY => CometEvalMode.TRY + case EvalMode.ANSI => CometEvalMode.ANSI + } +} + diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index fd2218965..25343f933 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType} -import org.apache.comet.expressions.{CometCast, Compatible} +import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible} class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { @@ -76,7 +76,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } else { val testIgnored = tags.get(expectedTestName).exists(s => s.contains("org.scalatest.Ignore")) - CometCast.isSupported(fromType, toType, None, "LEGACY") match { + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match { case Compatible(_) => if (testIgnored) { fail(
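
The eval-mode plumbing introduced in patches 7 and 8 can be summarized in a short, self-contained sketch. Note that `SparkEvalMode` and `EvalModeDemo` below are hypothetical stand-ins invented for illustration (Spark's real `EvalMode` lives in a version-dependent package, which is exactly why `CometEvalMode` exists); only `fromBoolean` and the value-by-value match in `fromSpark` mirror the patched code.

    // Stand-in for Spark's EvalMode, whose package differs across Spark versions.
    object SparkEvalMode extends Enumeration {
      val LEGACY, ANSI, TRY = Value
    }

    // Mirror of the CometEvalMode enum added in PATCH 8.
    object CometEvalMode extends Enumeration {
      val LEGACY, ANSI, TRY = Value

      // Spark 3.2/3.3 only expose Cast.ansiEnabled, a boolean.
      def fromBoolean(ansiEnabled: Boolean): Value =
        if (ansiEnabled) ANSI else LEGACY

      // Spark 3.4+ expose a full EvalMode enum; map it value by value,
      // as CometEvalModeUtil.fromSparkEvalMode does in the version shims.
      def fromSpark(evalMode: SparkEvalMode.Value): Value = evalMode match {
        case SparkEvalMode.LEGACY => LEGACY
        case SparkEvalMode.TRY    => TRY
        case SparkEvalMode.ANSI   => ANSI
      }
    }

    object EvalModeDemo extends App {
      assert(CometEvalMode.fromBoolean(ansiEnabled = true) == CometEvalMode.ANSI)
      assert(CometEvalMode.fromBoolean(ansiEnabled = false) == CometEvalMode.LEGACY)
      assert(CometEvalMode.fromSpark(SparkEvalMode.TRY) == CometEvalMode.TRY)
      println("eval mode mapping checks passed")
    }

Keeping the match total over all three values means that if a future Spark version added a fourth eval mode, the shim would fail to compile rather than silently mistranslate it.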