From 81a641f30844d76b417a600392a0cc97a74a919b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Mar 2024 11:14:45 -0700 Subject: [PATCH] fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange (#186) * fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange * fix * Add comment and move tests * Remove unused table in test. --- .../comet/CometSparkSessionExtensions.scala | 44 ++++++++++++++--- .../apache/comet/serde/QueryPlanSerde.scala | 6 ++- .../exec/CometColumnarShuffleSuite.scala | 48 ++++++++++++++++++- .../comet/exec/CometNativeShuffleSuite.scala | 5 +- 4 files changed, 92 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 5720b6935..39c83ae53 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -31,12 +31,13 @@ import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -221,12 +222,16 @@ class 
CometSparkSessionExtensions */ // spotless:on private def transform(plan: SparkPlan): SparkPlan = { - def transform1(op: UnaryExecNode): Option[Operator] = { - op.child match { - case childNativeOp: CometNativeExec => - QueryPlanSerde.operator2Proto(op, childNativeOp.nativeOp) - case _ => - None + def transform1(op: SparkPlan): Option[Operator] = { + val allNativeExec = op.children.map { + case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp) + case _ => None + } + + if (allNativeExec.forall(_.isDefined)) { + QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*) + } else { + None } } @@ -377,6 +382,31 @@ class CometSparkSessionExtensions case None => b } + // For AQE shuffle stage on a Comet shuffle exchange + case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + CometSinkPlaceHolder(nativeOp, s, s) + case None => + s + } + + // For AQE shuffle stage on a reused Comet shuffle exchange + // Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because + // the query plan won't be re-optimized/planned in non-AQE mode. 
+ case s @ ShuffleQueryStageExec( + _, + ReusedExchangeExec(_, _: CometShuffleExchangeExec), + _) => + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + CometSinkPlaceHolder(nativeOp, s, s) + case None => + s + } + // Native shuffle for Comet operators case s: ShuffleExchangeExec if isCometShuffleEnabled(conf) && diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 902f7037f..5da926e38 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -29,10 +29,12 @@ import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1883,6 +1885,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case _: CollectLimitExec => true case _: UnionExec => true case _: ShuffleExchangeExec => true + case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true case _: TakeOrderedAndProjectExec => 
true case _: BroadcastExchangeExec => true case _ => false diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 1a92f71db..216b6900e 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkConf} import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions.col @@ -933,6 +933,52 @@ class CometShuffleSuite extends CometColumnarShuffleSuite { override protected val asyncShuffleEnable: Boolean = false protected val adaptiveExecutionEnabled: Boolean = true + + import testImplicits._ + + test("Comet native operator after ShuffleQueryStage") { + withSQLConf( + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + val df = sql("SELECT * FROM tbl_a") + val shuffled = df + .select($"_1" + 1 as ("a")) + .filter($"a" > 4) + .repartition(10) + .sortWithinPartitions($"a") + checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec]) + } + } + } + + test("Comet native operator after ShuffleQueryStage + 
ReusedExchange") { + withSQLConf( + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") { + val df = sql("SELECT * FROM tbl_a") + val left = df + .select($"_1" + 1 as ("a")) + .filter($"a" > 4) + val right = left.select($"a" as ("b")) + val join = left.join(right, $"a" === $"b") + checkSparkAnswerAndOperator( + join, + classOf[ShuffleQueryStageExec], + classOf[SortMergeJoinExec], + classOf[AQEShuffleReadExec]) + } + } + } + } } class DisableAQECometShuffleSuite extends CometColumnarShuffleSuite { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index c35763c34..59e27fd0f 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -64,8 +64,9 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 1000) var allTypes: Seq[Int] = (1 to 20) - if (isSpark34Plus) { - allTypes = allTypes.filterNot(Set(14, 17).contains) + if (!isSpark34Plus) { + // TODO: Remove this once https://github.com/apache/arrow/issues/40038 is fixed + allTypes = allTypes.filterNot(Set(14).contains) } allTypes.map(i => s"_$i").foreach { c => withSQLConf(