Skip to content

Commit

Permalink
fix: CometShuffleExchangeExec logical link should be correct
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Apr 25, 2024
1 parent ef94c55 commit 3a9fdb9
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -741,9 +741,32 @@ class CometSparkSessionExtensions
}

// Set up logical links
newPlan = newPlan.transform { case op: CometExec =>
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
op
newPlan = newPlan.transform {
case op: CometExec =>
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
op
case op: CometShuffleExchangeExec =>
// Original Spark shuffle exchange operator might have empty logical link.
// But the `setLogicalLink` call above on downstream operator of
// `CometShuffleExchangeExec` will set its logical link to the downstream
// operators which cause AQE behavior to be incorrect. So we need to unset
// the logical link here.
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op

case op: CometBroadcastExchangeExec =>
if (op.originalPlan.logicalLink.isEmpty) {
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
} else {
op.originalPlan.logicalLink.foreach(op.setLogicalLink)
}
op
}

// Convert native execution block by linking consecutive native operators.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ import org.apache.comet.shims.ShimCometShuffleExchangeExec
case class CometShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
originalPlan: ShuffleExchangeLike,
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
shuffleType: ShuffleType = CometNativeShuffle,
advisoryPartitionSize: Option[Long] = None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ trait ShimCometShuffleExchangeExec {
CometShuffleExchangeExec(
s.outputPartitioning,
s.child,
s,
s.shuffleOrigin,
shuffleType,
advisoryPartitionSize)
Expand Down
38 changes: 36 additions & 2 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.expressions.Window
Expand All @@ -62,6 +62,37 @@ class CometExecSuite extends CometTestBase {
}
}

test("CometShuffleExchangeExec logical link should be correct") {
withTempView("v") {
spark.sparkContext
.parallelize((1 to 4).map(i => TestData(i, i.toString)), 2)
.toDF("c1", "c2")
.createOrReplaceTempView("v")

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
val shuffle = find(df.queryExecution.executedPlan) {
case _: CometShuffleExchangeExec => true
case _ => false
}.get.asInstanceOf[CometShuffleExchangeExec]
assert(shuffle.logicalLink.isEmpty)
}

withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
val shuffle = find(df.queryExecution.executedPlan) {
case _: ShuffleExchangeExec => true
case _ => false
}.get.asInstanceOf[ShuffleExchangeExec]
assert(shuffle.logicalLink.isEmpty)
}
}
}

test("Ensure that the correct outputPartitioning of CometSort") {
withTable("test_data") {
val tableDF = spark.sparkContext
Expand Down Expand Up @@ -302,7 +333,8 @@ class CometExecSuite extends CometTestBase {
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
"spark.sql.autoBroadcastJoinThreshold" -> "0",
"spark.sql.adaptive.autoBroadcastJoinThreshold" -> "-1",
"spark.sql.autoBroadcastJoinThreshold" -> "-1",
"spark.sql.join.preferSortMergeJoin" -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl1") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl2") {
Expand Down Expand Up @@ -1306,3 +1338,5 @@ case class BucketedTableTestSpec(
expectedShuffle: Boolean = true,
expectedSort: Boolean = true,
expectedNumOutputPartitions: Option[Int] = None)

case class TestData(key: Int, value: String)

0 comments on commit 3a9fdb9

Please sign in to comment.