From f4d38693095aef6d2495b2dec1b3b0330281ee70 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sun, 21 Apr 2024 10:10:29 -0700
Subject: [PATCH] fix: Comet columnar shuffle should not be on top of another
 Comet shuffle operator (#296)

* fix: Comet columnar shuffle should not be on top of another Comet shuffle operator

* Fix 3.2 build
---
 .../comet/CometSparkSessionExtensions.scala | 22 +++++++++++++---
 .../apache/comet/exec/CometExecSuite.scala  | 25 ++++++++++++++++++-
 2 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index d794f4017..9b5889770 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -151,7 +151,8 @@ class CometSparkSessionExtensions
       case s: ShuffleExchangeExec
           if (!s.child.supportsColumnar || isCometPlan(
             s.child)) && isCometColumnarShuffleEnabled(conf) &&
-            QueryPlanSerde.supportPartitioningTypes(s.child.output) =>
+            QueryPlanSerde.supportPartitioningTypes(s.child.output) &&
+            !isShuffleOperator(s.child) =>
         logInfo("Comet extension enabled for JVM Columnar Shuffle")
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
     }
@@ -538,10 +539,13 @@
          }
 
        // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
-        // (if configured)
+        // (if configured).
+        // If the child of ShuffleExchangeExec is another Comet shuffle operator, we should
+        // not convert it to CometColumnarShuffle.
        case s: ShuffleExchangeExec
            if isCometShuffleEnabled(conf) && isCometColumnarShuffleEnabled(conf) &&
-              QueryPlanSerde.supportPartitioningTypes(s.child.output) =>
+              QueryPlanSerde.supportPartitioningTypes(s.child.output) &&
+              !isShuffleOperator(s.child) =>
          logInfo("Comet extension enabled for JVM Columnar Shuffle")
 
          val newOp = QueryPlanSerde.operator2Proto(s)
@@ -627,6 +631,18 @@
        case s: ShuffleQueryStageExec => findPartialAgg(s.plan)
      }.flatten
   }
+
+  /**
+   * Returns true if a given Spark plan is a Comet shuffle operator.
+   */
+  def isShuffleOperator(op: SparkPlan): Boolean = {
+    op match {
+      case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
+      case _: CometShuffleExchangeExec => true
+      case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
+      case _ => false
+    }
+  }
 }
 
 // This rule is responsible for eliminating redundant transitions between row-based and
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index a8b05cc98..5e9907368 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -40,7 +40,8 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{col, date_add, expr, sum}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions.{col, date_add, expr, lead, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
@@ -60,6 +61,28 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("Repeated shuffle exchanges don't fail") {
+    assume(isSpark33Plus)
+    Seq("true", "false").foreach { aqeEnabled =>
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled,
+        // `REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION` is a new config in Spark 3.3+.
+        "spark.sql.requireAllClusterKeysForDistribution" -> "true",
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+        val df =
+          Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
+        val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")
+
+        val windowed = df
+          // repartition by a subset of the window partitionBy keys, which satisfies ClusteredDistribution
+          .repartition($"key1")
+          .select(lead($"key1", 1).over(windowSpec), lead($"value", 1).over(windowSpec))
+
+        checkSparkAnswer(windowed)
+      }
+    }
+  }
+
   test("try_sum should return null if overflow happens before merging") {
     assume(isSpark33Plus, "try_sum is available in Spark 3.3+")
     val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")