Skip to content

Commit

Permalink
For Spark 3.2 and 3.3
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Apr 16, 2024
1 parent ee349ad commit a447471
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
inputs: Seq[Attribute],
binding: Boolean): Option[AggExpr] = {
aggExpr.aggregateFunction match {
case s @ Sum(child, evalMode)
if sumDataTypeSupported(s.dataType) &&
evalMode == EvalMode.LEGACY =>
case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) && isLegacyMode(s) =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(s.dataType)

Expand All @@ -222,9 +220,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
} else {
None
}
case s @ Average(child, evalMode)
if avgDataTypeSupported(s.dataType) &&
evalMode == EvalMode.LEGACY =>
case s @ Average(child, _) if avgDataTypeSupported(s.dataType) && isLegacyMode(s) =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(s.dataType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@ trait ShimQueryPlanSerde {
}
}

// TODO: delete after drop Spark 3.2/3.3 support
// Reflectively checks whether the aggregate function evaluates in legacy mode.
// Spark 3.4 exposes an `evalMode` accessor (EvalMode became an enum there);
// Spark 3.2/3.3 aggregates have no such method, in which case legacy behavior
// is the only mode and we return true.
def isLegacyMode(aggregate: DeclarativeAggregate): Boolean = {
  val maybeEvalMode = aggregate.getClass.getDeclaredMethods
    .find(_.getName == "evalMode")
    .map(_.invoke(aggregate))

  maybeEvalMode match {
    // No `evalMode` accessor: pre-3.4 Spark, which is always legacy.
    case None => true
    case Some(mode) =>
      // scalastyle:off caselocale
      mode.toString.toLowerCase == "legacy"
    // scalastyle:on caselocale
  }
}

// TODO: delete after drop Spark 3.2 support
def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus}

class CometExecSuite extends CometTestBase {
import testImplicits._
Expand All @@ -61,6 +61,7 @@ class CometExecSuite extends CometTestBase {
}

test("try_sum should return null if overflow happens before merging") {
assume(isSpark33Plus, "try_sum is available in Spark 3.3+")
val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
.map(Period.ofMonths)
Expand Down

0 comments on commit a447471

Please sign in to comment.