diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 8ef8cb83e8..d3f197cc75 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -741,9 +741,32 @@ class CometSparkSessionExtensions
       }
 
       // Set up logical links
-      newPlan = newPlan.transform { case op: CometExec =>
-        op.originalPlan.logicalLink.foreach(op.setLogicalLink)
-        op
+      newPlan = newPlan.transform {
+        case op: CometExec =>
+          op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+          op
+        case op: CometShuffleExchangeExec =>
+          // The original Spark shuffle exchange operator might have an empty logical
+          // link. But the `setLogicalLink` call above on the downstream operators of
+          // `CometShuffleExchangeExec` will set its logical link to that of the
+          // downstream operators, which causes incorrect AQE behavior. So we need to
+          // unset the logical link here.
+          if (op.originalPlan.logicalLink.isEmpty) {
+            op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+            op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+          } else {
+            op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+          }
+          op
+
+        case op: CometBroadcastExchangeExec =>
+          if (op.originalPlan.logicalLink.isEmpty) {
+            op.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
+            op.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
+          } else {
+            op.originalPlan.logicalLink.foreach(op.setLogicalLink)
+          }
+          op
       }
 
       // Convert native execution block by linking consecutive native operators.
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 232b6bf17f..faa52cd06b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -61,6 +61,7 @@ import org.apache.comet.shims.ShimCometShuffleExchangeExec
 case class CometShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
+    originalPlan: ShuffleExchangeLike,
     shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
     shuffleType: ShuffleType = CometNativeShuffle,
     advisoryPartitionSize: Option[Long] = None)
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
index f89dbb8dba..6b4fad974f 100644
--- a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
@@ -32,6 +32,7 @@ trait ShimCometShuffleExchangeExec {
     CometShuffleExchangeExec(
       s.outputPartitioning,
       s.child,
+      s,
       s.shuffleOrigin,
       shuffleType,
       advisoryPartitionSize)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index e5b3523dcf..5e86757f64 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.expressions.Window
@@ -62,6 +62,37 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("CometShuffleExchangeExec logical link should be correct") {
+    withTempView("v") {
+      spark.sparkContext
+        .parallelize((1 to 4).map(i => TestData(i, i.toString)), 2)
+        .toDF("c1", "c2")
+        .createOrReplaceTempView("v")
+
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+        val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
+        val shuffle = find(df.queryExecution.executedPlan) {
+          case _: CometShuffleExchangeExec => true
+          case _ => false
+        }.get.asInstanceOf[CometShuffleExchangeExec]
+        assert(shuffle.logicalLink.isEmpty)
+      }
+
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+        val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
+        val shuffle = find(df.queryExecution.executedPlan) {
+          case _: ShuffleExchangeExec => true
+          case _ => false
+        }.get.asInstanceOf[ShuffleExchangeExec]
+        assert(shuffle.logicalLink.isEmpty)
+      }
+    }
+  }
+
   test("Ensure that the correct outputPartitioning of CometSort") {
     withTable("test_data") {
       val tableDF = spark.sparkContext
@@ -302,7 +333,8 @@ class CometExecSuite extends CometTestBase {
     withSQLConf(
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
-      "spark.sql.autoBroadcastJoinThreshold" -> "0",
+      "spark.sql.adaptive.autoBroadcastJoinThreshold" -> "-1",
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1",
       "spark.sql.join.preferSortMergeJoin" -> "true") {
       withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl1") {
         withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl2") {
@@ -1306,3 +1338,5 @@ case class BucketedTableTestSpec(
     expectedShuffle: Boolean = true,
     expectedSort: Boolean = true,
     expectedNumOutputPartitions: Option[Int] = None)
+
+case class TestData(key: Int, value: String)
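
For reference, the logical-link handling added in both exchange cases above amounts to the rule sketched below. This is a minimal illustrative sketch, not part of the patch; the object and method names (`LogicalLinkUtil`, `copyOrClearLogicalLink`) are hypothetical, and the only Spark APIs assumed are the `logicalLink`/`setLogicalLink`/`unsetTagValue` calls and `SparkPlan` tag values already used in the diff.

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan

object LogicalLinkUtil {
  // Mirror the original operator's logical link onto the Comet replacement,
  // or clear both logical-plan tags when the original had none, so the Comet
  // exchange does not keep a link inherited from a downstream operator
  // (which is what makes AQE behave incorrectly).
  def copyOrClearLogicalLink(cometOp: SparkPlan, originalOp: SparkPlan): SparkPlan = {
    originalOp.logicalLink match {
      case Some(link: LogicalPlan) => cometOp.setLogicalLink(link)
      case None =>
        cometOp.unsetTagValue(SparkPlan.LOGICAL_PLAN_TAG)
        cometOp.unsetTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)
    }
    cometOp
  }
}

With a helper like this, the `CometShuffleExchangeExec` and `CometBroadcastExchangeExec` cases in the transform could share a single body, e.g. `copyOrClearLogicalLink(op, op.originalPlan)`.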