diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index a10ac573e..275b9ebff 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -26,13 +26,14 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.SparkSession import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -319,26 +320,42 @@ class CometSparkSessionExtensions } case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, child) => - val newOp = transform1(op) - newOp match { - case Some(nativeOp) => - val modes = aggExprs.map(_.mode).distinct - // The aggExprs could be empty. For example, if the aggregate functions only have - // distinct aggregate functions or only have group by, the aggExprs is empty and - // modes is empty too. If aggExprs is not empty, we need to verify all the aggregates - // have the same mode. - assert(modes.length == 1 || modes.length == 0) - CometHashAggregateExec( - nativeOp, - op, - groupingExprs, - aggExprs, - child.output, - if (modes.nonEmpty) Some(modes.head) else None, - child, - SerializedPlan(None)) - case None => + val modes = aggExprs.map(_.mode).distinct + + if (!modes.isEmpty && modes.size != 1) { + // This shouldn't happen as all aggregation expressions should share the same mode. + // Fallback to Spark nevertheless here. + op + } else { + val sparkFinalMode = { + !modes.isEmpty && modes.head == Final && findPartialAgg(child).isEmpty + } + + if (sparkFinalMode) { op + } else { + val newOp = transform1(op) + newOp match { + case Some(nativeOp) => + val modes = aggExprs.map(_.mode).distinct + // The aggExprs could be empty. For example, if the aggregate functions only have + // distinct aggregate functions or only have group by, the aggExprs is empty and + // modes is empty too. If aggExprs is not empty, we need to verify all the + // aggregates have the same mode. + assert(modes.length == 1 || modes.length == 0) + CometHashAggregateExec( + nativeOp, + op, + groupingExprs, + aggExprs, + child.output, + if (modes.nonEmpty) Some(modes.head) else None, + child, + SerializedPlan(None)) + case None => + op + } + } } case op: ShuffledHashJoinExec @@ -596,6 +613,20 @@ class CometSparkSessionExtensions } } } + + /** + * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate + * with partial mode, it will return None. + */ + def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = { + plan.collectFirst { + case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => + Some(agg) + case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None + case a: AQEShuffleReadExec => findPartialAgg(a.child) + case s: ShuffleQueryStageExec => findPartialAgg(s.plan) + }.flatten + } } // This rule is responsible for eliminating redundant transitions between row-based and diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index b95ce9b19..89681d3df 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -40,6 +40,19 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ + test("Only trigger Comet Final aggregation on Comet partial aggregation") { + withTempView("lowerCaseData") { + lowerCaseData.createOrReplaceTempView("lowerCaseData") + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + val df = sql("SELECT LAST(n) FROM lowerCaseData") + checkSparkAnswer(df) + } + } + } + test( "Average expression in Comet Final should handle " + "all null inputs from partial Spark aggregation") {