Skip to content

Commit

Permalink
fix: Comet columnar shuffle should not be on top of another Comet shu…
Browse files Browse the repository at this point in the history
…ffle operator (#296)

* fix: Comet columnar shuffle should not be on top of another Comet shuffle operator

* Fix 3.2 build
  • Loading branch information
viirya authored Apr 21, 2024
1 parent 138b062 commit f4d3869
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ class CometSparkSessionExtensions
case s: ShuffleExchangeExec
if (!s.child.supportsColumnar || isCometPlan(
s.child)) && isCometColumnarShuffleEnabled(conf) &&
QueryPlanSerde.supportPartitioningTypes(s.child.output) =>
QueryPlanSerde.supportPartitioningTypes(s.child.output) &&
!isShuffleOperator(s.child) =>
logInfo("Comet extension enabled for JVM Columnar Shuffle")
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
}
Expand Down Expand Up @@ -538,10 +539,13 @@ class CometSparkSessionExtensions
}

// Columnar shuffle for regular Spark operators (not Comet) and Comet operators
// (if configured)
// (if configured).
// If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
// convert it to CometColumnarShuffle,
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) && isCometColumnarShuffleEnabled(conf) &&
QueryPlanSerde.supportPartitioningTypes(s.child.output) =>
QueryPlanSerde.supportPartitioningTypes(s.child.output) &&
!isShuffleOperator(s.child) =>
logInfo("Comet extension enabled for JVM Columnar Shuffle")

val newOp = QueryPlanSerde.operator2Proto(s)
Expand Down Expand Up @@ -627,6 +631,18 @@ class CometSparkSessionExtensions
case s: ShuffleQueryStageExec => findPartialAgg(s.plan)
}.flatten
}

/**
* Returns true if a given spark plan is Comet shuffle operator.
*/
def isShuffleOperator(op: SparkPlan): Boolean = {
op match {
case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
case _: CometShuffleExchangeExec => true
case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
case _ => false
}
}
}

// This rule is responsible for eliminating redundant transitions between row-based and
Expand Down
25 changes: 24 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.functions.{col, date_add, expr, sum}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, date_add, expr, lead, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.unsafe.types.UTF8String
Expand All @@ -60,6 +61,28 @@ class CometExecSuite extends CometTestBase {
}
}

test("Repeated shuffle exchange don't fail") {
assume(isSpark33Plus)
Seq("true", "false").foreach { aqeEnabled =>
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled,
// `REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION` is a new config in Spark 3.3+.
"spark.sql.requireAllClusterKeysForDistribution" -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
val df =
Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")

val windowed = df
// repartition by subset of window partitionBy keys which satisfies ClusteredDistribution
.repartition($"key1")
.select(lead($"key1", 1).over(windowSpec), lead($"value", 1).over(windowSpec))

checkSparkAnswer(windowed)
}
}
}

test("try_sum should return null if overflow happens before merging") {
assume(isSpark33Plus, "try_sum is available in Spark 3.3+")
val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
Expand Down

0 comments on commit f4d3869

Please sign in to comment.