From ee349ad47a7f18409a411890274ab61d48144453 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 16 Apr 2024 13:56:56 -0700
Subject: [PATCH] fix: Comet should not translate try_sum to native sum
 expression

---
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  8 ++++++--
 .../org/apache/comet/exec/CometExecSuite.scala     | 17 ++++++++++++++++-
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 26fc708ff..c51a4b47d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -202,7 +202,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       inputs: Seq[Attribute],
       binding: Boolean): Option[AggExpr] = {
     aggExpr.aggregateFunction match {
-      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) =>
+      case s @ Sum(child, evalMode)
+          if sumDataTypeSupported(s.dataType) &&
+            evalMode == EvalMode.LEGACY =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)
 
@@ -220,7 +222,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       } else {
         None
       }
-      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) =>
+      case s @ Average(child, evalMode)
+          if avgDataTypeSupported(s.dataType) &&
+            evalMode == EvalMode.LEGACY =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0bb21aba7..b0deb3bf1 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,8 @@
 
 package org.apache.comet.exec
 
+import java.time.{Duration, Period}
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
@@ -38,7 +40,7 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr, sum}
+import org.apache.spark.sql.functions.{col, date_add, expr, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
@@ -58,6 +60,19 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("try_sum should return null if overflow happens before merging") {
+    val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+    val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+      .map(Period.ofMonths)
+      .toDF("v")
+    val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+      .map(Duration.ofDays)
+      .toDF("v")
+    Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
+      checkSparkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"))
+    }
+  }
+
   test("Fix corrupted AggregateMode when transforming plan parameters") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
       val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))