From 175f6af730e9377d0ee3240286564d54f9672e36 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Mar 2024 08:49:31 -0800 Subject: [PATCH 1/2] fix: coalesce should return correct datatype --- .../org/apache/comet/CometExpressionSuite.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index cbb77b9ff..b3e60f58a 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -34,6 +34,19 @@ import org.apache.comet.CometSparkSessionExtensions.{isSpark32, isSpark33Plus, i class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ + test("coalesce should return correct datatype") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "tbl") { + checkSparkAnswerAndOperator( + "SELECT coalesce(cast(_18 as date), cast(_19 as date), _20) FROM tbl") + } + } + } + } + test("bitwise shift with different left/right types") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { From 05f82d38fa7f7de8d3ad13257eac87cb1d854feb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Mar 2024 10:58:55 -0800 Subject: [PATCH 2/2] Fix --- .../apache/comet/serde/QueryPlanSerde.scala | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8dc2b1596..8b4eaa609 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -349,6 +349,29 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { expr: Expression, input: Seq[Attribute], binding: Boolean = true): Option[Expr] = { + def castToProto( + timeZoneId: Option[String], + dt: DataType, + childExpr: Option[Expr]): Option[Expr] = { + val dataType = serializeDataType(dt) + + if (childExpr.isDefined && dataType.isDefined) { + val castBuilder = ExprOuterClass.Cast.newBuilder() + castBuilder.setChild(childExpr.get) + castBuilder.setDatatype(dataType.get) + + val timeZone = timeZoneId.getOrElse("UTC") + castBuilder.setTimezone(timeZone) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setCast(castBuilder) + .build()) + } else { + None + } + } def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = { SQLConf.get @@ -363,24 +386,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case Cast(child, dt, timeZoneId, _) => val childExpr = exprToProtoInternal(child, inputs) - val dataType = serializeDataType(dt) - - if (childExpr.isDefined && dataType.isDefined) { - val castBuilder = ExprOuterClass.Cast.newBuilder() - castBuilder.setChild(childExpr.get) - castBuilder.setDatatype(dataType.get) - - val timeZone = timeZoneId.getOrElse("UTC") - castBuilder.setTimezone(timeZone) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setCast(castBuilder) - .build()) - } else { - None - } + castToProto(timeZoneId, dt, childExpr) case add @ Add(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) @@ -1494,7 +1500,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case a @ Coalesce(_) => val exprChildren = a.children.map(exprToProtoInternal(_, inputs)) - scalarExprToProto("coalesce", exprChildren: _*) + val childExpr = scalarExprToProto("coalesce", exprChildren: _*) + // TODO: Remove this once we have new DataFusion release which includes + // the fix: https://github.com/apache/arrow-datafusion/pull/9459 + castToProto(None, a.dataType, childExpr) // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for // char types. Use rpad to achieve the behavior.