Skip to content

Commit

Permalink
fix: Only trigger Comet Final aggregation on Comet partial aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Apr 12, 2024
1 parent 421f0e0 commit d5eeec8
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
Expand Down Expand Up @@ -319,26 +320,42 @@ class CometSparkSessionExtensions
}

case op @ HashAggregateExec(_, _, _, groupingExprs, aggExprs, _, _, _, child) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val modes = aggExprs.map(_.mode).distinct
// The aggExprs could be empty. For example, if the aggregate functions only have
// distinct aggregate functions or only have group by, the aggExprs is empty and
// modes is empty too. If aggExprs is not empty, we need to verify all the aggregates
// have the same mode.
assert(modes.length == 1 || modes.length == 0)
CometHashAggregateExec(
nativeOp,
op,
groupingExprs,
aggExprs,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
SerializedPlan(None))
case None =>
val modes = aggExprs.map(_.mode).distinct

if (!modes.isEmpty && modes.size != 1) {
// This shouldn't happen, as all aggregation expressions should share the same mode.
// Nevertheless, fall back to Spark here.
op
} else {
val sparkFinalMode = {
!modes.isEmpty && modes.head == Final && findPartialAgg(child).isEmpty
}

if (sparkFinalMode) {
op
} else {
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
val modes = aggExprs.map(_.mode).distinct
// The aggExprs could be empty. For example, if the aggregate functions only have
// distinct aggregate functions or only have group by, the aggExprs is empty and
// modes is empty too. If aggExprs is not empty, we need to verify all the
// aggregates have the same mode.
assert(modes.length == 1 || modes.length == 0)
CometHashAggregateExec(
nativeOp,
op,
groupingExprs,
aggExprs,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
SerializedPlan(None))
case None =>
op
}
}
}

case op: ShuffledHashJoinExec
Expand Down Expand Up @@ -596,6 +613,20 @@ class CometSparkSessionExtensions
}
}
}

/**
 * Looks for a Comet partial aggregate in `plan`.
 *
 * The plan is scanned with `collectFirst`, so the first node that matches one of the cases
 * below decides the outcome:
 *   - a `CometHashAggregateExec` whose aggregate expressions are all in `Partial` mode is
 *     returned;
 *   - a Spark `HashAggregateExec` whose aggregate expressions are all in `Partial` mode
 *     produces `None`;
 *   - an `AQEShuffleReadExec` or a `ShuffleQueryStageExec` delegates the search to its child
 *     or its wrapped plan, respectively.
 *
 * @param plan
 *   the plan to search
 * @return
 *   the Comet partial aggregate that was found, or `None` if a Spark partial aggregate was
 *   reached first or no node matched at all
 */
def findPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
  // Outer Option: did any node match? Inner Option: what did the matching node resolve to?
  val firstMatch: Option[Option[CometHashAggregateExec]] = plan.collectFirst {
    case cometAgg: CometHashAggregateExec
        if cometAgg.aggregateExpressions.forall(_.mode == Partial) =>
      Some(cometAgg)
    case sparkAgg: HashAggregateExec
        if sparkAgg.aggregateExpressions.forall(_.mode == Partial) =>
      None
    case shuffleRead: AQEShuffleReadExec =>
      findPartialAgg(shuffleRead.child)
    case queryStage: ShuffleQueryStageExec =>
      findPartialAgg(queryStage.plan)
  }
  firstMatch.flatten
}
}

// This rule is responsible for eliminating redundant transitions between row-based and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._

// Runs a LAST aggregate over the `lowerCaseData` temp view with Comet and both Comet shuffle
// flags turned on, then validates the result through `checkSparkAnswer`.
test("Only trigger Comet Final aggregation on Comet partial aggregation") {
  withTempView("lowerCaseData") {
    lowerCaseData.createOrReplaceTempView("lowerCaseData")

    // Comet, Comet exec shuffle and Comet columnar shuffle are all set to "true" for the query.
    val cometShuffleConfs = Seq(
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true")

    withSQLConf(cometShuffleConfs: _*) {
      val lastValueQuery = sql("SELECT LAST(n) FROM lowerCaseData")
      checkSparkAnswer(lastValueQuery)
    }
  }
}

test(
"Average expression in Comet Final should handle " +
"all null inputs from partial Spark aggregation") {
Expand Down

0 comments on commit d5eeec8

Please sign in to comment.