diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 8545eee90d..4d6fd0c531 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.comet
 import java.io.{ByteArrayOutputStream, DataInputStream}
 import java.nio.channels.Channels
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{SparkEnv, TaskContext}
@@ -30,13 +31,14 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
-import org.apache.spark.sql.catalyst.optimizer.BuildSide
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, PartitioningCollection, UnknownPartitioning}
 import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
 import org.apache.spark.sql.comet.util.Utils
-import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, PartitioningPreservingUnaryExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
@@ -69,6 +71,10 @@ abstract class CometExec extends CometPlan {
 
   override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering
 
+  // `CometExec` reuses the outputPartitioning of the original SparkPlan.
+  // Note that if the outputPartitioning of the original SparkPlan depends on its children,
+  // we should override this method in the specific CometExec, because Spark AQE may change the
+  // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec.
   override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
 
   /**
@@ -377,7 +383,8 @@ case class CometProjectExec(
     override val output: Seq[Attribute],
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
-    extends CometUnaryExec {
+    extends CometUnaryExec
+    with PartitioningPreservingUnaryExecNode {
   override def producedAttributes: AttributeSet = outputSet
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -396,6 +403,8 @@ case class CometProjectExec(
   }
 
   override def hashCode(): Int = Objects.hashCode(projectList, output, child)
+
+  override protected def outputExpressions: Seq[NamedExpression] = projectList
 }
 
 case class CometFilterExec(
@@ -405,6 +414,9 @@ case class CometFilterExec(
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -439,6 +451,9 @@ case class CometSortExec(
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -471,6 +486,9 @@ case class CometLocalLimitExec(
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -498,6 +516,9 @@ case class CometGlobalLimitExec(
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -586,7 +607,8 @@ case class CometHashAggregateExec(
     mode: Option[AggregateMode],
     child: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
-    extends CometUnaryExec {
+    extends CometUnaryExec
+    with PartitioningPreservingUnaryExecNode {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
 
@@ -618,6 +640,9 @@ case class CometHashAggregateExec(
 
   override def hashCode(): Int =
     Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child)
+
+  override protected def outputExpressions: Seq[NamedExpression] =
+    originalPlan.asInstanceOf[HashAggregateExec].resultExpressions
 }
 
 case class CometHashJoinExec(
@@ -632,6 +657,18 @@ case class CometHashJoinExec(
     override val right: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometBinaryExec {
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case _: InnerLike =>
+      PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftExistence(_) => left.outputPartitioning
+    case x =>
+      throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
+  }
+
   override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
     this.copy(left = newLeft, right = newRight)
 
@@ -669,6 +706,94 @@ case class CometBroadcastHashJoinExec(
     override val right: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometBinaryExec {
+
+  protected lazy val streamedPlan = buildSide match {
+    case BuildLeft => right
+    case BuildRight => left
+  }
+
+  override lazy val outputPartitioning: Partitioning = {
+    joinType match {
+      case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
+        streamedPlan.outputPartitioning match {
+          case h: HashPartitioningLike => expandOutputPartitioning(h)
+          case c: PartitioningCollection => expandOutputPartitioning(c)
+          case other => other
+        }
+      case _ => streamedPlan.outputPartitioning
+    }
+  }
+
+  protected lazy val (buildKeys, streamedKeys) = {
+    require(
+      leftKeys.length == rightKeys.length &&
+        leftKeys
+          .map(_.dataType)
+          .zip(rightKeys.map(_.dataType))
+          .forall(types => types._1.sameType(types._2)),
+      "Join keys from two sides should have same length and types")
+    buildSide match {
+      case BuildLeft => (leftKeys, rightKeys)
+      case BuildRight => (rightKeys, leftKeys)
+    }
+  }
+
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeys.zip(buildKeys).foreach { case (streamedKey, buildKey) =>
+      val key = streamedKey.canonicalized
+      mapping.get(key) match {
+        case Some(v) => mapping.put(key, v :+ buildKey)
+        case None => mapping.put(key, Seq(buildKey))
+      }
+    }
+    mapping.toMap
+  }
+
+  // Expands the given partitioning collection recursively.
+  private def expandOutputPartitioning(
+      partitioning: PartitioningCollection): PartitioningCollection = {
+    PartitioningCollection(partitioning.partitionings.flatMap {
+      case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
+      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+      case other => Seq(other)
+    })
+  }
+
+  // Expands the given hash partitioning by substituting streamed keys with build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+  // The expanded expressions are returned as PartitioningCollection.
+  private def expandOutputPartitioning(
+      partitioning: HashPartitioningLike): PartitioningCollection = {
+    val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
+    var currentNumCombinations = 0
+
+    def generateExprCombinations(
+        current: Seq[Expression],
+        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+      if (currentNumCombinations >= maxNumCombinations) {
+        Nil
+      } else if (current.isEmpty) {
+        currentNumCombinations += 1
+        Seq(accumulated)
+      } else {
+        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        generateExprCombinations(current.tail, accumulated :+ current.head) ++
+          buildKeysOpt
+            .map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
+            .getOrElse(Nil)
+      }
+    }
+
+    PartitioningCollection(
+      generateExprCombinations(partitioning.expressions, Nil)
+        .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[HashPartitioningLike]))
+  }
+
   override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
     this.copy(left = newLeft, right = newRight)
 
@@ -705,6 +830,18 @@ case class CometSortMergeJoinExec(
     override val right: SparkPlan,
     override val serializedPlanOpt: SerializedPlan)
     extends CometBinaryExec {
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case _: InnerLike =>
+      PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftExistence(_) => left.outputPartitioning
+    case x =>
+      throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType")
+  }
+
   override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
     this.copy(left = newLeft, right = newRight)
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 5e9907368b..31c364a989 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,7 @@
 package org.apache.comet.exec
 
+import java.sql.Date
 import java.time.{Duration, Period}
 
 import scala.collection.JavaConverters._
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -61,6 +62,28 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("Ensure that CometSort has the correct outputPartitioning") {
+    withTable("test_data") {
+      val tableDF = spark.sparkContext
+        .parallelize(
+          (1 to 10).map { i =>
+            (if (i > 4) 5 else i, i.toString, Date.valueOf(s"${2020 + i}-$i-$i"))
+          },
+          3)
+        .toDF("id", "data", "day")
+      tableDF.write.saveAsTable("test_data")
+
+      val df = sql("SELECT * FROM test_data")
+        .repartition($"data")
+        .sortWithinPartitions($"id", $"data", $"day")
+      df.collect()
+      val sort = stripAQEPlan(df.queryExecution.executedPlan).collect { case s: CometSortExec =>
+        s
+      }.head
+      assert(sort.outputPartitioning == sort.child.outputPartitioning)
+    }
+  }
+
   test("Repeated shuffle exchange don't fail") {
     assume(isSpark33Plus)
     Seq("true", "false").foreach { aqeEnabled =>
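
Reviewer note (not part of the patch): the key-substitution logic in CometBroadcastHashJoinExec.expandOutputPartitioning can be hard to follow through the Catalyst types, so here is a standalone, Spark-free sketch of the same combination expansion using plain strings in place of Expression. The object ExpandPartitioningSketch and the names expand, keyMapping, and limit are made up for this illustration; limit stands in for conf.broadcastHashJoinOutputPartitioningExpandLimit.

// Illustrative sketch only; mirrors generateExprCombinations in the hunk above.
object ExpandPartitioningSketch {

  // Substitute each streamed key with itself or any build key it maps to,
  // capping the number of generated combinations at `limit`.
  def expand(
      exprs: Seq[String],
      keyMapping: Map[String, Seq[String]],
      limit: Int): Seq[Seq[String]] = {
    var count = 0
    def go(current: Seq[String], acc: Seq[String]): Seq[Seq[String]] = {
      if (count >= limit) {
        Nil
      } else if (current.isEmpty) {
        count += 1
        Seq(acc)
      } else {
        // Keep the streamed key as-is, then try each build key it maps to.
        go(current.tail, acc :+ current.head) ++
          keyMapping
            .get(current.head)
            .map(_.flatMap(b => go(current.tail, acc :+ b)))
            .getOrElse(Nil)
      }
    }
    go(exprs, Nil)
  }

  def main(args: Array[String]): Unit = {
    // Streamed keys "b" and "c" map to build keys "x" and "y" respectively.
    val combos = expand(Seq("a", "b", "c"), Map("b" -> Seq("x"), "c" -> Seq("y")), limit = 8)
    combos.foreach(println)
  }
}

Running it prints the four combinations listed in the comment inside the patch: List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y).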