From 30ded81794194af27842870a4ece736b266889f9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 10 Mar 2024 14:07:52 -0700 Subject: [PATCH] fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange --- .../comet/CometSparkSessionExtensions.scala | 41 ++++++++++++--- .../apache/comet/serde/QueryPlanSerde.scala | 8 ++- .../exec/CometColumnarShuffleSuite.scala | 51 ++++++++++++++++++- 3 files changed, 89 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 87c2265fcb..5f12c22b83 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -26,18 +26,18 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.SparkSession import org.apache.spark.sql.SparkSessionExtensions -import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -222,12 +222,16 @@ class CometSparkSessionExtensions */ // spotless:on private def transform(plan: SparkPlan): SparkPlan = { - def transform1(op: UnaryExecNode): Option[Operator] = { - op.child match { - case childNativeOp: CometNativeExec => - QueryPlanSerde.operator2Proto(op, childNativeOp.nativeOp) - case _ => - None + def transform1(op: SparkPlan): Option[Operator] = { + val allNativeExec = op.children.map { + case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp) + case _ => None + } + + if (allNativeExec.forall(_.isDefined)) { + QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*) + } else { + None } } @@ -378,6 +382,27 @@ class CometSparkSessionExtensions case None => b } + case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + CometSinkPlaceHolder(nativeOp, s, s) + case None => + s + } + + case s @ ShuffleQueryStageExec( + _, + ReusedExchangeExec(_, _: CometShuffleExchangeExec), + _) => + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + CometSinkPlaceHolder(nativeOp, s, s) + case None => + s + } + // Native shuffle for Comet operators case s: ShuffleExchangeExec if isCometShuffleEnabled(conf) && diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b27fa3a754..a4b7e10977 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1883,6 +1885,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case _: CollectLimitExec => true case _: UnionExec => true case _: ShuffleExchangeExec => true + case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true case _: TakeOrderedAndProjectExec => true case _: BroadcastExchangeExec => true case _ => false diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index a9b29e6b7e..55d0ec1bc1 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -26,7 +26,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkConf} import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -925,6 +926,54 @@ class CometAsyncShuffleSuite extends CometColumnarShuffleSuite { override protected val asyncShuffleEnable: Boolean = true protected val adaptiveExecutionEnabled: Boolean = true + + import testImplicits._ + + test("Comet native operator after ShuffleQueryStage") { + withSQLConf( + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + val df = sql("SELECT * FROM tbl_a") + val shuffled = df + .select($"_1" + 1 as ("a")) + .filter($"a" > 4) + .repartition(10) + .sortWithinPartitions($"a") + checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec]) + } + } + } + } + + test("Comet native operator after ShuffleQueryStage + ReusedExchange") { + withSQLConf( + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + val df = sql("SELECT * FROM tbl_a") + val left = df + .select($"_1" + 1 as ("a")) + .filter($"a" > 4) + val right = left.select($"a" as ("b")) + val join = left.join(right, $"a" === $"b") + checkSparkAnswerAndOperator( + join, + classOf[ShuffleQueryStageExec], + classOf[SortMergeJoinExec], + classOf[AQEShuffleReadExec]) + } + } + } + } } class CometShuffleSuite extends CometColumnarShuffleSuite {