From a44747160f26b6f9aab58e56f8ee586a9c507169 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 16 Apr 2024 16:00:40 -0700
Subject: [PATCH] For Spark 3.2 and 3.3

---
 .../apache/comet/serde/QueryPlanSerde.scala   |  8 ++------
 .../comet/shims/ShimQueryPlanSerde.scala      | 15 +++++++++++++++
 .../apache/comet/exec/CometExecSuite.scala    |  3 ++-
 3 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index c51a4b47d..172a5b5b6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -202,9 +202,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       inputs: Seq[Attribute],
       binding: Boolean): Option[AggExpr] = {
     aggExpr.aggregateFunction match {
-      case s @ Sum(child, evalMode)
-          if sumDataTypeSupported(s.dataType) &&
-            evalMode == EvalMode.LEGACY =>
+      case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) && isLegacyMode(s) =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)
 
@@ -222,9 +220,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
-      case s @ Average(child, evalMode)
-          if avgDataTypeSupported(s.dataType) &&
-            evalMode == EvalMode.LEGACY =>
+      case s @ Average(child, _) if avgDataTypeSupported(s.dataType) && isLegacyMode(s) =>
         val childExpr = exprToProto(child, inputs, binding)
         val dataType = serializeDataType(s.dataType)
 
diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
index 7bdf2c0ef..28b08305b 100644
--- a/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimQueryPlanSerde.scala
@@ -45,6 +45,21 @@ trait ShimQueryPlanSerde {
     }
   }
 
+  // TODO: delete after dropping Spark 3.2/3.3 support
+  // Checks whether the aggregate function is in legacy evaluation mode.
+  // EvalMode is an enum object in Spark 3.4, so the `evalMode` accessor is looked up
+  // reflectively; an aggregate that declares no such accessor is treated as legacy.
+  // NOTE(review): when no `evalMode` accessor exists this always answers true; confirm
+  // that non-legacy behaviour on Spark 3.2/3.3 is ruled out elsewhere.
+  def isLegacyMode(aggregate: DeclarativeAggregate): Boolean = {
+    aggregate.getClass.getDeclaredMethods
+      // Only a zero-argument `evalMode` can be invoked as a plain accessor.
+      .find(m => m.getName == "evalMode" && m.getParameterCount == 0)
+      .map(_.invoke(aggregate))
+      // equalsIgnoreCase is locale-independent, unlike the default toLowerCase.
+      .forall(mode => "legacy".equalsIgnoreCase(String.valueOf(mode)))
+  }
+
   // TODO: delete after drop Spark 3.2 support
   def isBloomFilterMightContain(binary: BinaryExpression): Boolean = {
     binary.getClass.getName == "org.apache.spark.sql.catalyst.expressions.BloomFilterMightContain"
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index b0deb3bf1..a8b05cc98 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
-import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
+import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus}
 
 class CometExecSuite extends CometTestBase {
   import testImplicits._
@@ -61,6 +61,7 @@ class CometExecSuite extends CometTestBase {
   }
 
   test("try_sum should return null if overflow happens before merging") {
+    assume(isSpark33Plus, "try_sum is available in Spark 3.3+")
     val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
     val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
       .map(Period.ofMonths)