fix: CometExec's outputPartitioning might not be the same as Spark expects after AQE interferes
viirya committed Apr 22, 2024
1 parent f4d3869 commit 45b4588
Showing 2 changed files with 167 additions and 7 deletions.
149 changes: 143 additions & 6 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.comet
import java.io.{ByteArrayOutputStream, DataInputStream}
import java.nio.channels.Channels

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
@@ -30,13 +31,14 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, PartitioningPreservingUnaryExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
@@ -69,6 +71,10 @@ abstract class CometExec extends CometPlan {

override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering

// `CometExec` reuses the outputPartitioning of the original SparkPlan.
// Note that if the outputPartitioning of the original SparkPlan depends on its children,
// we should override this method in the specific CometExec, because Spark AQE may change the
// outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec.
override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
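The comment above is the crux of the fix: an operator that keeps reporting the partitioning of the plan as it looked before adaptive execution can disagree with what AQE actually runs. A minimal, dependency-free sketch of the failure mode, using toy stand-ins rather than the real Spark/Comet classes:

sealed trait Partitioning { def numPartitions: Int }
case class HashPartitioning(keys: Seq[String], numPartitions: Int) extends Partitioning
case class CoalescedPartitioning(numPartitions: Int) extends Partitioning

trait Plan { def outputPartitioning: Partitioning }
case class Shuffle(keys: Seq[String], parts: Int) extends Plan {
  def outputPartitioning: Partitioning = HashPartitioning(keys, parts)
}
// AQE may wrap a materialized shuffle and coalesce its partitions at runtime.
case class AQEShuffleRead(child: Plan, coalesced: Int) extends Plan {
  def outputPartitioning: Partitioning = CoalescedPartitioning(coalesced)
}
// Stale: reports the partitioning of the plan as it looked at planning time.
case class SortStale(original: Plan, child: Plan) extends Plan {
  def outputPartitioning: Partitioning = original.outputPartitioning
}
// Fixed: re-derives the partitioning from the current child, which is what
// the overrides added in this commit do.
case class SortFixed(child: Plan) extends Plan {
  def outputPartitioning: Partitioning = child.outputPartitioning
}

object StalePartitioningDemo extends App {
  val planned = Shuffle(Seq("data"), 200)
  val executed = AQEShuffleRead(planned, 10)
  println(SortStale(planned, executed).outputPartitioning) // HashPartitioning(List(data),200) -- wrong
  println(SortFixed(executed).outputPartitioning)          // CoalescedPartitioning(10) -- matches runtime
}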

/**
@@ -377,7 +383,8 @@ case class CometProjectExec(
override val output: Seq[Attribute],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
extends CometUnaryExec
with PartitioningPreservingUnaryExecNode {
override def producedAttributes: AttributeSet = outputSet
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -396,6 +403,8 @@ case class CometProjectExec(
}

override def hashCode(): Int = Objects.hashCode(projectList, output, child)

override protected def outputExpressions: Seq[NamedExpression] = projectList
}
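For CometProjectExec (and CometHashAggregateExec below), plain delegation to the child would lose information, because a projection may alias the partitioning keys. Spark's PartitioningPreservingUnaryExecNode instead derives the partitioning from outputExpressions, so hash partitioning on a column survives renaming. A rough standalone sketch of that idea, with plain strings instead of Catalyst expressions:

object AliasSketch extends App {
  case class Hash(keys: Seq[String], numPartitions: Int)

  // aliases: input column name -> output name induced by the project list.
  def preserve(p: Hash, aliases: Map[String, String]): Option[Hash] =
    // A partitioning key dropped by the projection invalidates the partitioning.
    if (p.keys.forall(aliases.contains)) Some(Hash(p.keys.map(aliases), p.numPartitions))
    else None

  // SELECT data AS d: hash(data, 10) becomes hash(d, 10).
  println(preserve(Hash(Seq("data"), 10), Map("data" -> "d"))) // Some(Hash(List(d),10))
  println(preserve(Hash(Seq("id"), 10), Map("data" -> "d")))   // None
}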

case class CometFilterExec(
@@ -405,6 +414,9 @@ case class CometFilterExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -439,6 +451,9 @@ case class CometSortExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -471,6 +486,9 @@ case class CometLocalLimitExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -498,6 +516,9 @@ case class CometGlobalLimitExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -586,7 +607,8 @@ case class CometHashAggregateExec(
mode: Option[AggregateMode],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
extends CometUnaryExec
with PartitioningPreservingUnaryExecNode {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -618,6 +640,9 @@ case class CometHashAggregateExec(

override def hashCode(): Int =
Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)

override protected def outputExpressions: Seq[NamedExpression] =
originalPlan.asInstanceOf[HashAggregateExec].resultExpressions
}

case class CometHashJoinExec(
@@ -632,6 +657,18 @@ case class CometHashJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {

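// Matches Spark's ShuffledJoin derivation (see the error message below): an
// inner join satisfies both sides' hash partitionings, so both are exposed via
// a PartitioningCollection; outer joins preserve only the row-preserving side;
// a full outer join can pad either side with nulls, hence UnknownPartitioning.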
override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case LeftExistence(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
}

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

@@ -669,6 +706,94 @@ case class CometBroadcastHashJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {

protected lazy val streamedPlan = buildSide match {
case BuildLeft => right
case BuildRight => left
}

override lazy val outputPartitioning: Partitioning = {
joinType match {
case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
case h: HashPartitioningLike => expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
case other => other
}
case _ => streamedPlan.outputPartitioning
}
}

protected lazy val (buildKeys, streamedKeys) = {
require(
leftKeys.length == rightKeys.length &&
leftKeys
.map(_.dataType)
.zip(rightKeys.map(_.dataType))
.forall(types => types._1.sameType(types._2)),
"Join keys from two sides should have same length and types")
buildSide match {
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}
}

// A one-to-many mapping from a streamed key to build keys.
private lazy val streamedKeyToBuildKeyMapping = {
val mapping = mutable.Map.empty[Expression, Seq[Expression]]
streamedKeys.zip(buildKeys).foreach { case (streamedKey, buildKey) =>
val key = streamedKey.canonicalized
mapping.get(key) match {
case Some(v) => mapping.put(key, v :+ buildKey)
case None => mapping.put(key, Seq(buildKey))
}
}
mapping.toMap
}

// Expands the given partitioning collection recursively.
private def expandOutputPartitioning(
partitioning: PartitioningCollection): PartitioningCollection = {
PartitioningCollection(partitioning.partitionings.flatMap {
case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
case other => Seq(other)
})
}

// Expands the given hash partitioning by substituting streamed keys with build keys.
// For example, if the expressions for the given partitioning are Seq("a", "b", "c")
// where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
// the expanded partitioning will have the following expressions:
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
private def expandOutputPartitioning(
partitioning: HashPartitioningLike): PartitioningCollection = {
val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
var currentNumCombinations = 0

def generateExprCombinations(
current: Seq[Expression],
accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
if (currentNumCombinations >= maxNumCombinations) {
Nil
} else if (current.isEmpty) {
currentNumCombinations += 1
Seq(accumulated)
} else {
val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
generateExprCombinations(current.tail, accumulated :+ current.head) ++
buildKeysOpt
.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
.getOrElse(Nil)
}
}

PartitioningCollection(
generateExprCombinations(partitioning.expressions, Nil)
.map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[HashPartitioningLike]))
}

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)
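The recursion above is bounded by broadcastHashJoinOutputPartitioningExpandLimit because each streamed key with build-key equivalents multiplies the number of candidate partitionings. A standalone sketch of the same combination logic over plain strings (names and the limit value are illustrative):

object ExpandSketch extends App {
  // Streamed key -> equivalent build keys, as in streamedKeyToBuildKeyMapping.
  val streamedToBuild = Map("b" -> Seq("x"), "c" -> Seq("y"))
  val maxNumCombinations = 8
  var count = 0

  def combos(current: Seq[String], acc: Seq[String]): Seq[Seq[String]] =
    if (count >= maxNumCombinations) Nil
    else if (current.isEmpty) { count += 1; Seq(acc) }
    else
      combos(current.tail, acc :+ current.head) ++
        streamedToBuild
          .get(current.head)
          .map(_.flatMap(b => combos(current.tail, acc :+ b)))
          .getOrElse(Nil)

  // Prints List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y),
  // matching the worked example in the comment above.
  combos(Seq("a", "b", "c"), Nil).foreach(println)
}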

@@ -705,6 +830,18 @@ case class CometSortMergeJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {

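// Same JoinType-based derivation as in CometHashJoinExec above.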
override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case LeftExistence(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
}

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

25 changes: 24 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,7 @@

package org.apache.comet.exec

import java.sql.Date
import java.time.{Duration, Period}

import scala.collection.JavaConverters._
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -61,6 +62,28 @@ class CometExecSuite extends CometTestBase {
}
}

test("Ensure that the correct outputPartitioning of CometSort") {
withTable("test_data") {
val tableDF = spark.sparkContext
.parallelize(
(1 to 10).map { i =>
(if (i > 4) 5 else i, i.toString, Date.valueOf(s"${2020 + i}-$i-$i"))
},
3)
.toDF("id", "data", "day")
tableDF.write.saveAsTable("test_data")

val df = sql("SELECT * FROM test_data")
.repartition($"data")
.sortWithinPartitions($"id", $"data", $"day")
df.collect()
val sort = stripAQEPlan(df.queryExecution.executedPlan).collect { case s: CometSortExec =>
s
}.head
assert(sort.outputPartitioning == sort.child.outputPartitioning)
}
}
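The assertion only holds because CometSortExec now delegates to its child. To watch AQE rewrite a shuffle's partitioning outside the test harness, a spark-shell sketch (the configs are standard Spark settings; a SparkSession named spark and import spark.implicits._ are assumed):

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

val df = spark.range(100).toDF("id").repartition($"id").sortWithinPartitions($"id")
df.collect() // materialize so AQE can replan the shuffle read

// With coalescing, the executed plan typically reads fewer partitions than
// spark.sql.shuffle.partitions, so an operator that cached its planning-time
// partitioning would now disagree with its child.
println(df.rdd.getNumPartitions)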

test("Repeated shuffle exchange don't fail") {
assume(isSpark33Plus)
Seq("true", "false").foreach { aqeEnabled =>