fix: Reuse CometBroadcastExchangeExec with Spark ReuseExchangeAndSubquery rule (#441)
viirya authored May 18, 2024
1 parent 414e7a3 commit ec8da30
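
Summary (inferred from the diff): Spark's `ReuseExchangeAndSubquery` rule replaces duplicate exchanges with `ReusedExchangeExec` nodes, but it only recognizes duplicates whose plan nodes compare equal. `CometBroadcastExchangeExec` and `CometProjectExec` previously folded their `output` attributes into `equals`/`hashCode`, which plausibly kept semantically identical exchanges from matching; both now compare on their remaining fields only. The `BroadcastExchangeExec`-to-`CometBroadcastExchangeExec` rewrite is additionally gated on Spark 3.4+, and a new test asserts that the reuse rule produces a `ReusedExchangeExec` over a `CometBroadcastExchangeExec`.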
Showing 4 changed files with 44 additions and 7 deletions.
```diff
@@ -46,7 +46,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar, withInfo, withInfos}
+import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
```
```diff
@@ -576,11 +576,13 @@ class CometSparkSessionExtensions
       // exchange. It is only used for Comet native execution. We only transform Spark broadcast
       // exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the
       // broadcast exchange is forced to be enabled by Comet config.
+      // Note that `CometBroadcastExchangeExec` is only supported for Spark 3.4+.
       case plan if plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) =>
         val newChildren = plan.children.map {
           case b: BroadcastExchangeExec
               if isCometNative(b.child) &&
-                isCometOperatorEnabled(conf, "broadcastExchangeExec") =>
+                isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
+                isSpark34Plus => // Spark 3.4+ only
             QueryPlanSerde.operator2Proto(b) match {
               case Some(nativeOp) =>
                 val cometOp = CometBroadcastExchangeExec(b, b.child)
```
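
For context on the rule named in the commit title: `ReuseExchangeAndSubquery` walks the physical plan and swaps any exchange whose canonicalized form it has already seen for a `ReusedExchangeExec` pointing at the first occurrence. A minimal sketch of that matching logic, simplified from Spark's rule (the real implementation also deduplicates subqueries and honors `spark.sql.exchange.reuse`; `reuseExchanges` is a hypothetical name):

```scala
import scala.collection.mutable

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}

// Simplified sketch: the cache is keyed by the canonicalized plan, so whether a
// broadcast gets reused hinges on the exchange node's equals/hashCode --
// exactly what the CometBroadcastExchangeExec change below adjusts.
def reuseExchanges(plan: SparkPlan): SparkPlan = {
  val cache = mutable.Map.empty[SparkPlan, Exchange]
  plan.transformUp { case exchange: Exchange =>
    val cached = cache.getOrElseUpdate(exchange.canonicalized, exchange)
    if (cached ne exchange) ReusedExchangeExec(exchange.output, cached) else exchange
  }
}
```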
```diff
@@ -238,13 +238,13 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
     obj match {
       case other: CometBroadcastExchangeExec =>
         this.originalPlan == other.originalPlan &&
-          this.output == other.output && this.child == other.child
+          this.child == other.child
       case _ =>
         false
     }
   }
 
-  override def hashCode(): Int = Objects.hashCode(output, child)
+  override def hashCode(): Int = Objects.hashCode(child)
 
   override def stringArgs: Iterator[Any] = Iterator(output, child)
 
```
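
A plausible reading of why `output` is dropped here and in `CometProjectExec` below (the diff does not say so explicitly): output attributes embed expression IDs that differ between otherwise identical plan instances, so including `output` in `equals`/`hashCode` can make duplicate exchanges compare unequal and miss the reuse cache. A toy, self-contained illustration of the adjusted contract (`ToyExchange` is hypothetical, not Comet code):

```scala
import java.util.Objects

// Hypothetical stand-in for a plan node whose `output` embeds unstable
// expression IDs. Equality and hashing deliberately ignore `output`,
// mirroring the CometBroadcastExchangeExec/CometProjectExec change.
final case class ToyExchange(output: Seq[String], child: String) {
  override def equals(obj: Any): Boolean = obj match {
    case other: ToyExchange => this.child == other.child
    case _ => false
  }
  override def hashCode(): Int = Objects.hash(child)
}

object ToyExchangeDemo extends App {
  val a = ToyExchange(Seq("id#1"), "scan t")
  val b = ToyExchange(Seq("id#2"), "scan t") // same child, different exprIds
  assert(a == b && a.hashCode == b.hashCode) // equal keys -> cache hit -> reuse
}
```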
```diff
@@ -406,14 +406,14 @@ case class CometProjectExec(
     obj match {
       case other: CometProjectExec =>
         this.projectList == other.projectList &&
-          this.output == other.output && this.child == other.child &&
+          this.child == other.child &&
           this.serializedPlanOpt == other.serializedPlanOpt
       case _ =>
         false
     }
   }
 
-  override def hashCode(): Int = Objects.hashCode(projectList, output, child)
+  override def hashCode(): Int = Objects.hashCode(projectList, child)
 
   override protected def outputExpressions: Seq[NamedExpression] = projectList
 }
```
37 changes: 36 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
```diff
@@ -38,8 +38,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.{col, date_add, expr, lead, sum}
```
```diff
@@ -62,6 +63,40 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("ReusedExchangeExec should work on CometBroadcastExchangeExec") {
+    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
+    withSQLConf(
+      CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key -> "true",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      withTempPath { path =>
+        spark
+          .range(5)
+          .withColumn("p", $"id" % 2)
+          .write
+          .mode("overwrite")
+          .partitionBy("p")
+          .parquet(path.toString)
+        withTempView("t") {
+          spark.read.parquet(path.toString).createOrReplaceTempView("t")
+          val df = sql("""
+              |SELECT t1.id, t2.id, t3.id
+              |FROM t AS t1
+              |JOIN t AS t2 ON t2.id = t1.id
+              |JOIN t AS t3 ON t3.id = t2.id
+              |WHERE t1.p = 1 AND t2.p = 1 AND t3.p = 1
+              |""".stripMargin)
+          val reusedPlan = ReuseExchangeAndSubquery.apply(df.queryExecution.executedPlan)
+          val reusedExchanges = collect(reusedPlan) { case r: ReusedExchangeExec =>
+            r
+          }
+          assert(reusedExchanges.size == 1)
+          assert(reusedExchanges.head.child.isInstanceOf[CometBroadcastExchangeExec])
+        }
+      }
+    }
+  }
+
   test("CometShuffleExchangeExec logical link should be correct") {
     withTempView("v") {
       spark.sparkContext
```
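
A note on the test's configuration (an inference; the commit does not explain it): `SQLConf.ADAPTIVE_EXECUTION_ENABLED` is disabled because `ReuseExchangeAndSubquery` is Spark's non-AQE reuse path; with adaptive execution on, exchange reuse happens through AQE's own stage-reuse mechanism instead. Applying the rule directly to `executedPlan` therefore exercises exactly the code path this fix targets.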