Skip to content

Commit

Permalink
fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 10, 2024
1 parent 488c523 commit 30ded81
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -222,12 +222,16 @@ class CometSparkSessionExtensions
*/
// spotless:on
private def transform(plan: SparkPlan): SparkPlan = {
def transform1(op: UnaryExecNode): Option[Operator] = {
op.child match {
case childNativeOp: CometNativeExec =>
QueryPlanSerde.operator2Proto(op, childNativeOp.nativeOp)
case _ =>
None
def transform1(op: SparkPlan): Option[Operator] = {
val allNativeExec = op.children.map {
case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp)
case _ => None
}

if (allNativeExec.forall(_.isDefined)) {
QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*)
} else {
None
}
}

Expand Down Expand Up @@ -378,6 +382,27 @@ class CometSparkSessionExtensions
case None => b
}

case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
val newOp = transform1(s)
newOp match {
case Some(nativeOp) =>
CometSinkPlaceHolder(nativeOp, s, s)
case None =>
s
}

case s @ ShuffleQueryStageExec(
_,
ReusedExchangeExec(_, _: CometShuffleExchangeExec),
_) =>
val newOp = transform1(s)
newOp match {
case Some(nativeOp) =>
CometSinkPlaceHolder(nativeOp, s, s)
case None =>
s
}

// Native shuffle for Comet operators
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -1883,6 +1885,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case _: CollectLimitExec => true
case _: UnionExec => true
case _: ShuffleExchangeExec => true
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case _: BroadcastExchangeExec => true
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{Partitioner, SparkConf}
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -925,6 +926,54 @@ class CometAsyncShuffleSuite extends CometColumnarShuffleSuite {
override protected val asyncShuffleEnable: Boolean = true

protected val adaptiveExecutionEnabled: Boolean = true

import testImplicits._

test("Comet native operator after ShuffleQueryStage") {
withSQLConf(
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
val df = sql("SELECT * FROM tbl_a")
val shuffled = df
.select($"_1" + 1 as ("a"))
.filter($"a" > 4)
.repartition(10)
.sortWithinPartitions($"a")
checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec])
}
}
}
}

test("Comet native operator after ShuffleQueryStage + ReusedExchange") {
withSQLConf(
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
val df = sql("SELECT * FROM tbl_a")
val left = df
.select($"_1" + 1 as ("a"))
.filter($"a" > 4)
val right = left.select($"a" as ("b"))
val join = left.join(right, $"a" === $"b")
checkSparkAnswerAndOperator(
join,
classOf[ShuffleQueryStageExec],
classOf[SortMergeJoinExec],
classOf[AQEShuffleReadExec])
}
}
}
}
}

class CometShuffleSuite extends CometColumnarShuffleSuite {
Expand Down

0 comments on commit 30ded81

Please sign in to comment.