fix: Comet should not translate try_sum to native sum expression (#277)
* fix: Comet should not translate try_sum to native sum expression

* For Spark 3.2 and 3.3

* Update spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala

Co-authored-by: advancedxy <[email protected]>

* Fix format

---------

Co-authored-by: advancedxy <[email protected]>
viirya and advancedxy authored Apr 17, 2024
1 parent 5c768de commit 9321be6
Showing 3 changed files with 38 additions and 4 deletions.
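For context: try_sum evaluates in TRY mode and returns NULL when its accumulator overflows, while Comet's native sum implements only the default legacy semantics (silent wraparound for longs). With the new isLegacyMode guard, non-legacy aggregates no longer match the conversion cases, so Comet falls back to Spark for them. A minimal spark-shell sketch of the semantic gap, assuming Spark 3.3+ with the default non-ANSI config:

import spark.implicits._

val df = Seq(Long.MaxValue, 1L).toDF("v")
// Legacy-mode sum overflows silently and wraps around:
df.selectExpr("sum(v)").show()     // -9223372036854775808
// TRY-mode sum must return NULL on overflow instead:
df.selectExpr("try_sum(v)").show() // NULL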
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -202,7 +202,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       inputs: Seq[Attribute],
       binding: Boolean): Option[AggExpr] = {
     aggExpr.aggregateFunction match {
-      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) =>
+      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) && isLegacyMode(s) =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)

@@ -220,7 +220,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
-      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) =>
+      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) && isLegacyMode(s) =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)

spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -45,6 +45,24 @@ trait ShimQueryPlanSerde {
     }
   }

+  // TODO: delete after drop Spark 3.2/3.3 support
+  // This method is used to check if the aggregate function is in legacy mode.
+  // EvalMode is an enum object in Spark 3.4.
+  def isLegacyMode(aggregate: DeclarativeAggregate): Boolean = {
+    val evalMode = aggregate.getClass.getDeclaredMethods
+      .flatMap(m =>
+        m.getName match {
+          case "evalMode" => Some(m.invoke(aggregate))
+          case _ => None
+        })
+
+    if (evalMode.isEmpty) {
+      true
+    } else {
+      "legacy".equalsIgnoreCase(evalMode.head.toString)
+    }
+  }
+
   // TODO: delete after drop Spark 3.2 support
   def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
     binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
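The reflection above is needed because EvalMode (and the evalMode field on Sum and Average) only exists from Spark 3.4 on; on 3.2/3.3 the method is absent and every aggregate is legacy by definition. On Spark 3.4+, the shim's check should reduce to a direct comparison like this sketch (for illustration only, not part of the patch; the helper name is hypothetical):

import org.apache.spark.sql.catalyst.expressions.EvalMode
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum

// Only legacy-mode aggregates are safe to hand to the native engine;
// TRY (try_sum) and ANSI-mode sums must stay on Spark.
def isLegacyModeSpark34Plus(s: Sum): Boolean = s.evalMode == EvalMode.LEGACY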
20 changes: 18 additions & 2 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,8 @@

 package org.apache.comet.exec
 
+import java.time.{Duration, Period}
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Random
@@ -38,13 +40,13 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr, sum}
+import org.apache.spark.sql.functions.{col, date_add, expr, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
-import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus}

 class CometExecSuite extends CometTestBase {
   import testImplicits._
@@ -58,6 +60,20 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("try_sum should return null if overflow happens before merging") {
+    assume(isSpark33Plus, "try_sum is available in Spark 3.3+")
+    val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+    val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+      .map(Period.ofMonths)
+      .toDF("v")
+    val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+      .map(Duration.ofDays)
+      .toDF("v")
+    Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
+      checkSparkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"))
+    }
+  }
+
   test("Fix corrupted AggregateMode when transforming plan parameters") {
     withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
       val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
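A note on the test's constants, assuming Spark's physical interval encodings (day-time intervals are stored as microseconds in a Long, year-month intervals as months in an Int):

// Long.MaxValue microseconds is exactly 106751991 whole days:
assert(Long.MaxValue / (86400L * 1000L * 1000L) == 106751991L)
// So two Duration.ofDays(106751991) rows overflow the Long accumulator,
// just as two Long.MaxValue rows or two Period.ofMonths(Int.MaxValue) rows
// overflow theirs; repartitionByRange(2, col("v")) should range-partition the
// two large values together, forcing the overflow into the partial
// (pre-merge) aggregation that the test name refers to.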
