Commit

fix: CometExec's outputPartitioning might not be same as Spark expects after AQE interferes (#299)

* fix: CometExec's outputPartitioning might not be same as Spark expects after AQE interferes

* Add compatibility with Spark 3.2 and 3.3

* Remove unused import
viirya authored Apr 23, 2024
1 parent 21717eb commit f789c21
Showing 5 changed files with 440 additions and 7 deletions.
New file (39 additions):
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.Partitioning

trait ShimCometBroadcastHashJoinExec {

/**
* Returns the expressions used for hash partitioning, covering both `HashPartitioning` and
* `CoalescedHashPartitioning`. They share the same trait `HashPartitioningLike` since Spark 3.4,
* but Spark 3.2/3.3 has neither `HashPartitioningLike` nor `CoalescedHashPartitioning`.
*
* TODO: remove after dropping Spark 3.2 and 3.3 support.
*/
def getHashPartitioningLikeExpressions(partitioning: Partitioning): Seq[Expression] = {
partitioning.getClass.getDeclaredMethods
.filter(_.getName == "expressions")
.flatMap(_.invoke(partitioning).asInstanceOf[Seq[Expression]])
}
}
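
A hypothetical usage sketch (the demo object and attributes are not from the commit): because the lookup is reflective, the same call compiles against Spark 3.2, 3.3, and 3.4 even though the concrete partitioning classes differ across versions.

import org.apache.comet.shims.ShimCometBroadcastHashJoinExec
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

object ShimDemo extends ShimCometBroadcastHashJoinExec {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    // Reflection finds the `expressions` accessor declared on HashPartitioning
    // (and on CoalescedHashPartitioning in Spark 3.4+), avoiding any compile-time
    // reference to HashPartitioningLike, which does not exist before 3.4.
    val exprs = getHashPartitioningLikeExpressions(HashPartitioning(Seq(a, b), 8))
    assert(exprs == Seq(a, b))
  }
}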
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala (151 additions & 6 deletions):
@@ -22,6 +22,7 @@ package org.apache.spark.sql.comet
import java.io.{ByteArrayOutputStream, DataInputStream}
import java.nio.channels.Channels

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
@@ -30,13 +31,15 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
-import org.apache.spark.sql.catalyst.optimizer.BuildSide
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
@@ -47,6 +50,7 @@ import com.google.common.base.Objects

import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.shims.ShimCometBroadcastHashJoinExec

/**
* A Comet physical operator
@@ -69,6 +73,10 @@ abstract class CometExec extends CometPlan {

override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering

// `CometExec` reuses the outputPartitioning of the original SparkPlan.
// Note that if the outputPartitioning of the original SparkPlan depends on its children,
// the concrete CometExec node should override this method, because Spark AQE may change
// a plan's outputPartitioning, e.g., via AQEShuffleReadExec.
override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
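
// Illustration (hypothetical numbers, not in the diff): `originalPlan` is captured
// before AQE re-optimizes the stage, so if an AQEShuffleReadExec later coalesces,
// say, 200 shuffle partitions down to 7, the partitioning reported by the original
// plan is stale. Operators whose partitioning derives from their children (filter,
// sort, limit, joins) therefore recompute it from the actual, post-AQE children in
// the overrides below.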

/**
@@ -377,7 +385,8 @@ case class CometProjectExec(
override val output: Seq[Attribute],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
-extends CometUnaryExec {
+extends CometUnaryExec
+  with PartitioningPreservingUnaryExecNode {
override def producedAttributes: AttributeSet = outputSet
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -396,6 +405,8 @@ case class CometProjectExec(
}

override def hashCode(): Int = Objects.hashCode(projectList, output, child)

override protected def outputExpressions: Seq[NamedExpression] = projectList
}
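
// Illustration (hypothetical query, not in the diff): with the child hash-partitioned
// on `a` and projectList containing `a AS a2`, PartitioningPreservingUnaryExecNode
// rewrites HashPartitioning(Seq(a), n) through the alias to HashPartitioning(Seq(a2), n),
// so a downstream requirement clustered on `a2` does not trigger an extra shuffle.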

case class CometFilterExec(
Expand All @@ -405,6 +416,9 @@ case class CometFilterExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -439,6 +453,9 @@ case class CometSortExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -471,6 +488,9 @@ case class CometLocalLimitExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -498,6 +518,9 @@ case class CometGlobalLimitExec(
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -586,7 +609,8 @@ case class CometHashAggregateExec(
mode: Option[AggregateMode],
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
-extends CometUnaryExec {
+extends CometUnaryExec
+  with PartitioningPreservingUnaryExecNode {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

@@ -618,6 +642,9 @@ case class CometHashAggregateExec(

override def hashCode(): Int =
Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)

override protected def outputExpressions: Seq[NamedExpression] =
originalPlan.asInstanceOf[HashAggregateExec].resultExpressions
}

case class CometHashJoinExec(
@@ -632,6 +659,18 @@ case class CometHashJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {

override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case LeftExistence(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
}
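
// Illustration (hypothetical keys, not in the diff): for an inner equi-join with the
// left side hash-partitioned on `a` and the right on `x`, the PartitioningCollection
// advertises both layouts, so either ClusteredDistribution(Seq(a)) or
// ClusteredDistribution(Seq(x)) is satisfied downstream without another shuffle.
// The same match, mirroring Spark's ShuffledJoin, reappears in CometSortMergeJoinExec
// below.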

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

@@ -668,7 +707,101 @@ case class CometBroadcastHashJoinExec(
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
-extends CometBinaryExec {
+extends CometBinaryExec
+  with ShimCometBroadcastHashJoinExec {

// The following logic of `outputPartitioning` is copied from Spark `BroadcastHashJoinExec`.
protected lazy val streamedPlan: SparkPlan = buildSide match {
case BuildLeft => right
case BuildRight => left
}

override lazy val outputPartitioning: Partitioning = {
joinType match {
case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
case h: HashPartitioning => expandOutputPartitioning(h)
case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") =>
expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
case other => other
}
case _ => streamedPlan.outputPartitioning
}
}

protected lazy val (buildKeys, streamedKeys) = {
require(
leftKeys.length == rightKeys.length &&
leftKeys
.map(_.dataType)
.zip(rightKeys.map(_.dataType))
.forall(types => types._1.sameType(types._2)),
"Join keys from two sides should have same length and types")
buildSide match {
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}
}

// A one-to-many mapping from a streamed key to build keys.
private lazy val streamedKeyToBuildKeyMapping = {
val mapping = mutable.Map.empty[Expression, Seq[Expression]]
streamedKeys.zip(buildKeys).foreach { case (streamedKey, buildKey) =>
val key = streamedKey.canonicalized
mapping.get(key) match {
case Some(v) => mapping.put(key, v :+ buildKey)
case None => mapping.put(key, Seq(buildKey))
}
}
mapping.toMap
}
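
// Illustration (hypothetical keys, not in the diff): joining ON s.k = b.x AND s.k = b.y
// yields streamedKeys = Seq(k, k) and buildKeys = Seq(x, y), hence the one-to-many
// entry k.canonicalized -> Seq(x, y).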

// Expands the given partitioning collection recursively.
private def expandOutputPartitioning(
partitioning: PartitioningCollection): PartitioningCollection = {
PartitioningCollection(partitioning.partitionings.flatMap {
case h: HashPartitioning => expandOutputPartitioning(h).partitionings
case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") =>
expandOutputPartitioning(h).partitionings
case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
case other => Seq(other)
})
}

// Expands the given hash partitioning by substituting streamed keys with build keys.
// For example, if the expressions for the given partitioning are Seq("a", "b", "c")
// where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
// the expanded partitioning will have the following expressions:
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
private def expandOutputPartitioning(
partitioning: Partitioning with Expression): PartitioningCollection = {
val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
var currentNumCombinations = 0

def generateExprCombinations(
current: Seq[Expression],
accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
if (currentNumCombinations >= maxNumCombinations) {
Nil
} else if (current.isEmpty) {
currentNumCombinations += 1
Seq(accumulated)
} else {
val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
generateExprCombinations(current.tail, accumulated :+ current.head) ++
buildKeysOpt
.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
.getOrElse(Nil)
}
}

PartitioningCollection(
generateExprCombinations(getHashPartitioningLikeExpressions(partitioning), Nil)
.map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[Partitioning]))
}
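
// Scale note (illustration, not in the diff): with k streamed keys that each map to a
// single build key, the recursion can enumerate up to 2^k partitionings, so
// generateExprCombinations stops once it has produced
// conf.broadcastHashJoinOutputPartitioningExpandLimit combinations (Spark's
// spark.sql.execution.broadcastHashJoinOutputPartitioningExpandLimit, 8 by default).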

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

@@ -705,6 +838,18 @@ case class CometSortMergeJoinExec(
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {

override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case LeftExistence(_) => left.outputPartitioning
case x =>
throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
}

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

