diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 26fc708ff..172a5b5b6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -202,7 +202,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       inputs: Seq[Attribute],
       binding: Boolean): Option[AggExpr] = {
     aggExpr.aggregateFunction match {
-      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) =>
+      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) && isLegacyMode(s) =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)
@@ -220,7 +220,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
-      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) =>
+      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) && isLegacyMode(s) =>
        val childExpr = exprToProto(child, inputs, binding)
        val dataType = serializeDataType(s.dataType)
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
index 7bdf2c0ef..b92d3fc6a 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -45,6 +45,24 @@ trait ShimQueryPlanSerde {
     }
   }
 
+  // TODO: delete after drop Spark 3.2/3.3 support
+  // This method is used to check if the aggregate function is in legacy mode.
+  // EvalMode is an enum object in Spark 3.4.
+  def isLegacyMode(aggregate: DeclarativeAggregate): Boolean = {
+    val evalMode = aggregate.getClass.getDeclaredMethods
+      .flatMap(m =>
+        m.getName match {
+          case "evalMode" => Some(m.invoke(aggregate))
+          case _ => None
+        })
+
+    if (evalMode.isEmpty) {
+      true
+    } else {
+      "legacy".equalsIgnoreCase(evalMode.head.toString)
+    }
+  }
+
   // TODO: delete after drop Spark 3.2 support
   def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
     binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
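Note on the shim above: Spark 3.4 added an `evalMode` field (`LEGACY`, `ANSI`, or `TRY`) to `Sum` and `Average`, while Spark 3.2/3.3 have no such accessor, so the shim probes for it reflectively instead of calling it directly. Below is a minimal, self-contained sketch of that probe against stand-in classes; `LegacyOnlySum` and `EvalModeSum` are illustrative names, not Spark's:

```scala
// Sketch of the reflection pattern used by isLegacyMode, with toy classes
// standing in for Spark's aggregates (class names here are hypothetical).
object EvalModeProbe extends App {
  // Pre-3.4 shape: no evalMode accessor at all.
  class LegacyOnlySum
  // 3.4+ shape: exposes an evalMode accessor.
  class EvalModeSum(mode: String) { def evalMode: String = mode }

  def isLegacyMode(agg: AnyRef): Boolean = {
    val evalMode = agg.getClass.getDeclaredMethods.collectFirst {
      case m if m.getName == "evalMode" => m.invoke(agg)
    }
    // No evalMode method (Spark 3.2/3.3) => always legacy semantics.
    evalMode.forall(mode => "legacy".equalsIgnoreCase(mode.toString))
  }

  assert(isLegacyMode(new LegacyOnlySum))          // 3.2/3.3: trivially legacy
  assert(isLegacyMode(new EvalModeSum("LEGACY")))  // 3.4+: eligible for Comet
  assert(!isLegacyMode(new EvalModeSum("TRY")))    // 3.4+: e.g. try_sum
  assert(!isLegacyMode(new EvalModeSum("ANSI")))   // 3.4+: ANSI mode
}
```

On 3.4+, the new `&& isLegacyMode(s)` guards therefore keep `TRY` and `ANSI` aggregates out of the native plan, so they fall back to Spark's own implementation.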
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0bb21aba7..a8b05cc98 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,8 @@
 package org.apache.comet.exec
 
+import java.time.{Duration, Period}
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
 
@@ -38,13 +40,13 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr, sum}
+import org.apache.spark.sql.functions.{col, date_add, expr, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
-import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus}
 
 class CometExecSuite extends CometTestBase {
   import testImplicits._
@@ -58,6 +60,20 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("try_sum should return null if overflow happens before merging") {
+    assume(isSpark33Plus, "try_sum is available in Spark 3.3+")
+    val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+    val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+      .map(Period.ofMonths)
+      .toDF("v")
+    val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+      .map(Duration.ofDays)
+      .toDF("v")
+    Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
+      checkSparkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"))
+    }
+  }
+
   test("Fix corrupted AggregateMode when transforming plan parameters") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
       val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
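The test relies on `repartitionByRange(2, col("v"))` to cluster both maximum values into the same partition, so the overflow occurs in the partial (pre-merge) aggregate rather than at merge time. A hedged sketch of the long-typed case against plain Spark 3.3+ (no Comet involved; the local session setup is an assumption of this sketch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Illustrative standalone reproduction of the behavior the test checks.
object TrySumOverflowSketch extends App {
  val spark = SparkSession.builder().master("local[2]").getOrCreate()
  import spark.implicits._

  val df = Seq(Long.MaxValue, Long.MaxValue, 2L).toDF("v")

  // Range partitioning on v puts the two Long.MaxValue rows together,
  // so their partial sum overflows before the merge phase.
  val result = df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)").collect()

  // try_sum surfaces the overflow as null, not a wrapped-around Long.
  assert(result.head.isNullAt(0))

  spark.stop()
}
```

The interval inputs probe the same edge: `Int.MaxValue` months is the largest year-month interval (stored as an `Int` of months), and 106751991 days is roughly the largest day-time interval a `Long` of microseconds can hold, so two of either also overflow during partial aggregation.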