From f789c21736819c0ffa5ba56aaa2c5ec4bcb7127a Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Tue, 23 Apr 2024 09:04:14 -0700
Subject: [PATCH] fix: CometExec's outputPartitioning might not be same as
 Spark expects after AQE interferes (#299)

* fix: CometExec's outputPartitioning might not be same as Spark expects after AQE interferes

* Add compatibility with Spark 3.2 and 3.3

* Remove unused import
---
 .../ShimCometBroadcastHashJoinExec.scala      |  39 +++++
 .../apache/spark/sql/comet/operators.scala    | 157 +++++++++++++++++-
 .../plans/AliasAwareOutputExpression.scala    | 150 +++++++++++++++++
 .../PartitioningPreservingUnaryExecNode.scala |  76 +++++++++
 .../apache/comet/exec/CometExecSuite.scala    |  25 ++-
 5 files changed, 440 insertions(+), 7 deletions(-)
 create mode 100644 spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
 create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala
 create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala

diff --git a/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala b/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
new file mode 100644
index 000000000..eef0ee9d5
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/shims/ShimCometBroadcastHashJoinExec.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+
+trait ShimCometBroadcastHashJoinExec {
+
+  /**
+   * Returns the expressions that are used for hash partitioning, including `HashPartitioning`
+   * and `CoalescedHashPartitioning`. They share the same trait `HashPartitioningLike` since
+   * Spark 3.4, but Spark 3.2/3.3 don't have `HashPartitioningLike` and
+   * `CoalescedHashPartitioning`.
+   *
+   * TODO: remove after dropping Spark 3.2 and 3.3 support.
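+   *
+   * For example, given `HashPartitioning(Seq(a, b), 10)` this returns `Seq(a, b)`; the
+   * `expressions` method is looked up and invoked reflectively, so the same code compiles
+   * against Spark 3.2 through 3.4+.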
+ */ + def getHashPartitioningLikeExpressions(partitioning: Partitioning): Seq[Expression] = { + partitioning.getClass.getDeclaredMethods + .filter(_.getName == "expressions") + .flatMap(_.invoke(partitioning).asInstanceOf[Seq[Expression]]) + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 1065367c2..571ec226d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -22,6 +22,7 @@ package org.apache.spark.sql.comet import java.io.{ByteArrayOutputStream, DataInputStream} import java.nio.channels.Channels +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkEnv, TaskContext} @@ -30,13 +31,15 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode} -import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf @@ -47,6 +50,7 @@ import com.google.common.base.Objects import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException} import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.shims.ShimCometBroadcastHashJoinExec /** * A Comet physical operator @@ -69,6 +73,10 @@ abstract class CometExec extends CometPlan { override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering + // `CometExec` reuses the outputPartitioning of the original SparkPlan. + // Note that if the outputPartitioning of the original SparkPlan depends on its children, + // we should override this method in the specific CometExec, because Spark AQE may change the + // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec. 
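+  // Operators whose partitioning simply follows the child (e.g. `CometFilterExec`,
+  // `CometSortExec` and the limit operators) override this to `child.outputPartitioning`, while
+  // operators that may rename output attributes (e.g. `CometProjectExec`,
+  // `CometHashAggregateExec`) mix in `PartitioningPreservingUnaryExecNode` instead.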
override def outputPartitioning: Partitioning = originalPlan.outputPartitioning /** @@ -377,7 +385,8 @@ case class CometProjectExec( override val output: Seq[Attribute], child: SparkPlan, override val serializedPlanOpt: SerializedPlan) - extends CometUnaryExec { + extends CometUnaryExec + with PartitioningPreservingUnaryExecNode { override def producedAttributes: AttributeSet = outputSet override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -396,6 +405,8 @@ case class CometProjectExec( } override def hashCode(): Int = Objects.hashCode(projectList, output, child) + + override protected def outputExpressions: Seq[NamedExpression] = projectList } case class CometFilterExec( @@ -405,6 +416,9 @@ case class CometFilterExec( child: SparkPlan, override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { + + override def outputPartitioning: Partitioning = child.outputPartitioning + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -439,6 +453,9 @@ case class CometSortExec( child: SparkPlan, override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { + + override def outputPartitioning: Partitioning = child.outputPartitioning + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -471,6 +488,9 @@ case class CometLocalLimitExec( child: SparkPlan, override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { + + override def outputPartitioning: Partitioning = child.outputPartitioning + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -498,6 +518,9 @@ case class CometGlobalLimitExec( child: SparkPlan, override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { + + override def outputPartitioning: Partitioning = child.outputPartitioning + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -586,7 +609,8 @@ case class CometHashAggregateExec( mode: Option[AggregateMode], child: SparkPlan, override val serializedPlanOpt: SerializedPlan) - extends CometUnaryExec { + extends CometUnaryExec + with PartitioningPreservingUnaryExecNode { override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -618,6 +642,9 @@ case class CometHashAggregateExec( override def hashCode(): Int = Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, child) + + override protected def outputExpressions: Seq[NamedExpression] = + originalPlan.asInstanceOf[HashAggregateExec].resultExpressions } case class CometHashJoinExec( @@ -632,6 +659,18 @@ case class CometHashJoinExec( override val right: SparkPlan, override val serializedPlanOpt: SerializedPlan) extends CometBinaryExec { + + override def outputPartitioning: Partitioning = joinType match { + case _: InnerLike => + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case LeftExistence(_) => left.outputPartitioning + case x => + throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType") + } + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = this.copy(left = newLeft, right = newRight) @@ -668,7 +707,101 @@ case class 
CometBroadcastHashJoinExec( override val left: SparkPlan, override val right: SparkPlan, override val serializedPlanOpt: SerializedPlan) - extends CometBinaryExec { + extends CometBinaryExec + with ShimCometBroadcastHashJoinExec { + + // The following logic of `outputPartitioning` is copied from Spark `BroadcastHashJoinExec`. + protected lazy val streamedPlan: SparkPlan = buildSide match { + case BuildLeft => right + case BuildRight => left + } + + override lazy val outputPartitioning: Partitioning = { + joinType match { + case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => + streamedPlan.outputPartitioning match { + case h: HashPartitioning => expandOutputPartitioning(h) + case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") => + expandOutputPartitioning(h) + case c: PartitioningCollection => expandOutputPartitioning(c) + case other => other + } + case _ => streamedPlan.outputPartitioning + } + } + + protected lazy val (buildKeys, streamedKeys) = { + require( + leftKeys.length == rightKeys.length && + leftKeys + .map(_.dataType) + .zip(rightKeys.map(_.dataType)) + .forall(types => types._1.sameType(types._2)), + "Join keys from two sides should have same length and types") + buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + } + + // An one-to-many mapping from a streamed key to build keys. + private lazy val streamedKeyToBuildKeyMapping = { + val mapping = mutable.Map.empty[Expression, Seq[Expression]] + streamedKeys.zip(buildKeys).foreach { case (streamedKey, buildKey) => + val key = streamedKey.canonicalized + mapping.get(key) match { + case Some(v) => mapping.put(key, v :+ buildKey) + case None => mapping.put(key, Seq(buildKey)) + } + } + mapping.toMap + } + + // Expands the given partitioning collection recursively. + private def expandOutputPartitioning( + partitioning: PartitioningCollection): PartitioningCollection = { + PartitioningCollection(partitioning.partitionings.flatMap { + case h: HashPartitioning => expandOutputPartitioning(h).partitionings + case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") => + expandOutputPartitioning(h).partitionings + case c: PartitioningCollection => Seq(expandOutputPartitioning(c)) + case other => Seq(other) + }) + } + + // Expands the given hash partitioning by substituting streamed keys with build keys. + // For example, if the expressions for the given partitioning are Seq("a", "b", "c") + // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"), + // the expanded partitioning will have the following expressions: + // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). + // The expanded expressions are returned as PartitioningCollection. 
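+  // The number of generated combinations is capped by
+  // `conf.broadcastHashJoinOutputPartitioningExpandLimit` (checked in
+  // `generateExprCombinations` below), so the expansion cannot blow up for joins with many keys.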
+ private def expandOutputPartitioning( + partitioning: Partitioning with Expression): PartitioningCollection = { + val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit + var currentNumCombinations = 0 + + def generateExprCombinations( + current: Seq[Expression], + accumulated: Seq[Expression]): Seq[Seq[Expression]] = { + if (currentNumCombinations >= maxNumCombinations) { + Nil + } else if (current.isEmpty) { + currentNumCombinations += 1 + Seq(accumulated) + } else { + val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized) + generateExprCombinations(current.tail, accumulated :+ current.head) ++ + buildKeysOpt + .map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b))) + .getOrElse(Nil) + } + } + + PartitioningCollection( + generateExprCombinations(getHashPartitioningLikeExpressions(partitioning), Nil) + .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[Partitioning])) + } + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = this.copy(left = newLeft, right = newRight) @@ -705,6 +838,18 @@ case class CometSortMergeJoinExec( override val right: SparkPlan, override val serializedPlanOpt: SerializedPlan) extends CometBinaryExec { + + override def outputPartitioning: Partitioning = joinType match { + case _: InnerLike => + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => left.outputPartitioning + case RightOuter => right.outputPartitioning + case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) + case LeftExistence(_) => left.outputPartitioning + case x => + throw new IllegalArgumentException(s"ShuffledJoin should not take $x as the JoinType") + } + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = this.copy(left = newLeft, right = newRight) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala new file mode 100644 index 000000000..6e5b44c8f --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.plans + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.trees.CurrentOrigin + +/** + * A trait that provides functionality to handle aliases in the `outputExpressions`. 
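+ *
+ * For example, for a projection containing `a + b AS c`, `aliasMap` records `a + b -> c`, which
+ * lets `projectExpression` rewrite a child expression such as `a + b` into `c`.
+ *
+ * This is adapted from Spark 3.4's `AliasAwareOutputExpression` so that the same alias handling
+ * is available when running on Spark 3.2 and 3.3.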
+ */
+trait AliasAwareOutputExpression extends SQLConfHelper {
+  // `SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT` is Spark 3.4+ only.
+  // Use a default value for now.
+  protected val aliasCandidateLimit = 100
+  protected def outputExpressions: Seq[NamedExpression]
+
+  /**
+   * This method can be used to strip expressions which do not affect the result, for example:
+   * strip the expression which is ordering agnostic for output ordering.
+   */
+  protected def strip(expr: Expression): Expression = expr
+
+  // Build an `Expression` -> `Attribute` alias map.
+  // There can be multiple aliases defined for the same expression, but it doesn't make sense to
+  // store more than `aliasCandidateLimit` attributes for an expression. In those cases the old
+  // logic handled only the last alias, so we need to make sure that we give precedence to that.
+  // If the `outputExpressions` contain simple attributes we need to add those to the map as well.
+  @transient
+  private lazy val aliasMap = {
+    val aliases = mutable.Map[Expression, mutable.ArrayBuffer[Attribute]]()
+    outputExpressions.reverse.foreach {
+      case a @ Alias(child, _) =>
+        val buffer =
+          aliases.getOrElseUpdate(strip(child).canonicalized, mutable.ArrayBuffer.empty)
+        if (buffer.size < aliasCandidateLimit) {
+          buffer += a.toAttribute
+        }
+      case _ =>
+    }
+    outputExpressions.foreach {
+      case a: Attribute if aliases.contains(a.canonicalized) =>
+        val buffer = aliases(a.canonicalized)
+        if (buffer.size < aliasCandidateLimit) {
+          buffer += a
+        }
+      case _ =>
+    }
+    aliases
+  }
+
+  protected def hasAlias: Boolean = aliasMap.nonEmpty
+
+  /**
+   * Returns a stream of expressions in which the original expression is projected with `aliasMap`.
+   */
+  protected def projectExpression(expr: Expression): Stream[Expression] = {
+    val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
+    multiTransformDown(expr) {
+      // Mapping with aliases
+      case e: Expression if aliasMap.contains(e.canonicalized) =>
+        aliasMap(e.canonicalized).toSeq ++ (if (e.containsChild.nonEmpty) Seq(e) else Seq.empty)
+
+      // Prune if we encounter an attribute that we can't map and it is not in the output set.
+      // This prune will go up to the closest `multiTransformDown()` call and return `Stream.empty`
+      // there.
+      case a: Attribute if !outputSet.contains(a) => Seq.empty
+    }
+  }
+
+  // Copied from Spark 3.4+ to make it available in Spark 3.2+.
+  def multiTransformDown(expr: Expression)(
+      rule: PartialFunction[Expression, Seq[Expression]]): Stream[Expression] = {
+
+    // We could return `Seq(this)` if the `rule` doesn't apply and handle both
+    //   - the case where the rule doesn't apply
+    //   - and the case where the rule returns a one-element `Seq(originalNode)`
+    // together. The returned `Seq` can be a `Stream`, and unfortunately it doesn't seem like
+    // there is a way to match on a one-element stream without eagerly computing the tail's head.
+    // This contradicts the purpose of only taking the necessary elements from the alternatives,
+    // i.e. the "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
+    // Please note that this behaviour has a downside as well: we can only mark the rule on the
+    // original node ineffective if the rule didn't match.
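+    // `ruleApplied` distinguishes the two empty-result cases below: if the rule matched but
+    // returned no alternatives we prune this branch, whereas if the rule simply didn't match we
+    // keep the original node.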
+ var ruleApplied = true + val afterRules = CurrentOrigin.withOrigin(expr.origin) { + rule.applyOrElse( + expr, + (_: Expression) => { + ruleApplied = false + Seq.empty + }) + } + + val afterRulesStream = if (afterRules.isEmpty) { + if (ruleApplied) { + // If the rule returned with empty alternatives then prune + Stream.empty + } else { + // If the rule was not applied then keep the original node + Stream(expr) + } + } else { + // If the rule was applied then use the returned alternatives + afterRules.toStream.map { afterRule => + if (expr fastEquals afterRule) { + expr + } else { + afterRule.copyTagsFrom(expr) + afterRule + } + } + } + + afterRulesStream.flatMap { afterRule => + if (afterRule.containsChild.nonEmpty) { + generateCartesianProduct(afterRule.children.map(c => () => multiTransformDown(c)(rule))) + .map(afterRule.withNewChildren) + } else { + Stream(afterRule) + } + } + } + + def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): Stream[Seq[T]] = { + elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) => + for { + elementTail <- elementTails + element <- elements() + } yield element +: elementTail) + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala new file mode 100644 index 000000000..8c6f0af18 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/plans/PartitioningPreservingUnaryExecNode.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.plans + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning} +import org.apache.spark.sql.execution.UnaryExecNode + +/** + * A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that + * satisfies distribution requirements. + * + * This is copied from Spark's `PartitioningPreservingUnaryExecNode` because it is only available + * in Spark 3.4+. This is a workaround to make it available in Spark 3.2+. 
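+ *
+ * For example, if the child reports `HashPartitioning(Seq(id), 10)` and this node projects
+ * `id AS a`, the node reports `HashPartitioning(Seq(a), 10)`, so a parent requiring a hash
+ * distribution on `a` does not need an extra shuffle.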
+ */ +trait PartitioningPreservingUnaryExecNode extends UnaryExecNode with AliasAwareOutputExpression { + final override def outputPartitioning: Partitioning = { + val partitionings: Seq[Partitioning] = if (hasAlias) { + flattenPartitioning(child.outputPartitioning).flatMap { + case e: Expression => + // We need unique partitionings but if the input partitioning is + // `HashPartitioning(Seq(id + id))` and we have `id -> a` and `id -> b` aliases then after + // the projection we have 4 partitionings: + // `HashPartitioning(Seq(a + a))`, `HashPartitioning(Seq(a + b))`, + // `HashPartitioning(Seq(b + a))`, `HashPartitioning(Seq(b + b))`, but + // `HashPartitioning(Seq(a + b))` is the same as `HashPartitioning(Seq(b + a))`. + val partitioningSet = mutable.Set.empty[Expression] + projectExpression(e) + .filter(e => partitioningSet.add(e.canonicalized)) + .take(aliasCandidateLimit) + .asInstanceOf[Stream[Partitioning]] + case o => Seq(o) + } + } else { + // Filter valid partitiongs (only reference output attributes of the current plan node) + val outputSet = AttributeSet(outputExpressions.map(_.toAttribute)) + flattenPartitioning(child.outputPartitioning).filter { + case e: Expression => e.references.subsetOf(outputSet) + case _ => true + } + } + partitionings match { + case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions) + case Seq(p) => p + case ps => PartitioningCollection(ps) + } + } + + private def flattenPartitioning(partitioning: Partitioning): Seq[Partitioning] = { + partitioning match { + case PartitioningCollection(childPartitionings) => + childPartitionings.flatMap(flattenPartitioning) + case rest => + rest +: Nil + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index cc968a638..264ea4cd8 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -19,6 +19,7 @@ package org.apache.comet.exec +import java.sql.Date import java.time.{Duration, Period} import scala.collection.JavaConverters._ @@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Hex import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec @@ -61,6 +62,28 @@ class CometExecSuite extends CometTestBase { } } + test("Ensure that the correct outputPartitioning of CometSort") { + withTable("test_data") { + val tableDF = spark.sparkContext + .parallelize( + (1 to 10).map { i => + (if (i > 4) 5 else i, i.toString, Date.valueOf(s"${2020 + i}-$i-$i")) + }, + 3) + .toDF("id", "data", "day") + 
tableDF.write.saveAsTable("test_data") + + val df = sql("SELECT * FROM test_data") + .repartition($"data") + .sortWithinPartitions($"id", $"data", $"day") + df.collect() + val sort = stripAQEPlan(df.queryExecution.executedPlan).collect { case s: CometSortExec => + s + }.head + assert(sort.outputPartitioning == sort.child.outputPartitioning) + } + } + test("Repeated shuffle exchange don't fail") { assume(isSpark33Plus) Seq("true", "false").foreach { aqeEnabled =>