diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 7ddc950ea..7c269c411 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar, withInfo, withInfos}
+import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
@@ -576,11 +576,13 @@ class CometSparkSessionExtensions
       // exchange. It is only used for Comet native execution. We only transform Spark broadcast
       // exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the
       // broadcast exchange is forced to be enabled by Comet config.
+      // Note that `CometBroadcastExchangeExec` is only supported for Spark 3.4+.
       case plan if plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) =>
         val newChildren = plan.children.map {
           case b: BroadcastExchangeExec
               if isCometNative(b.child) &&
-                isCometOperatorEnabled(conf, "broadcastExchangeExec") =>
+                isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
+                isSpark34Plus =>
             QueryPlanSerde.operator2Proto(b) match {
               case Some(nativeOp) =>
                 val cometOp = CometBroadcastExchangeExec(b, b.child)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index 06c5898f7..7bd34debb 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -238,13 +238,13 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan)
     obj match {
       case other: CometBroadcastExchangeExec =>
         this.originalPlan == other.originalPlan &&
-          this.output == other.output && this.child == other.child
+          this.child == other.child
       case _ =>
         false
     }
   }
 
-  override def hashCode(): Int = Objects.hashCode(output, child)
+  override def hashCode(): Int = Objects.hashCode(child)
 
   override def stringArgs: Iterator[Any] = Iterator(output, child)
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index ad07ff0e2..63587af32 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -406,14 +406,14 @@ case class CometProjectExec(
     obj match {
       case other: CometProjectExec =>
         this.projectList == other.projectList &&
-          this.output == other.output && this.child == other.child &&
+          this.child == other.child &&
           this.serializedPlanOpt == other.serializedPlanOpt
       case _ =>
         false
     }
   }
 
-  override def hashCode(): Int = Objects.hashCode(projectList, output, child)
+  override def hashCode(): Int = Objects.hashCode(projectList, child)
 
   override protected def outputExpressions: Seq[NamedExpression] = projectList
 }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 8f022988f..2e1444281 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -38,8 +38,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, CometHashJoinExec, CometProjectExec, CometRowToColumnarExec, CometScanExec, CometSortExec, CometSortMergeJoinExec, CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecution, UnionExec}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.{col, date_add, expr, lead, sum}
@@ -62,6 +63,40 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("ReusedExchangeExec should work on CometBroadcastExchangeExec") {
+    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4")
+    withSQLConf(
+      CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key -> "true",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      withTempPath { path =>
+        spark
+          .range(5)
+          .withColumn("p", $"id" % 2)
+          .write
+          .mode("overwrite")
+          .partitionBy("p")
+          .parquet(path.toString)
+        withTempView("t") {
+          spark.read.parquet(path.toString).createOrReplaceTempView("t")
+          val df = sql("""
+              |SELECT t1.id, t2.id, t3.id
+              |FROM t AS t1
+              |JOIN t AS t2 ON t2.id = t1.id
+              |JOIN t AS t3 ON t3.id = t2.id
+              |WHERE t1.p = 1 AND t2.p = 1 AND t3.p = 1
+              |""".stripMargin)
+          val reusedPlan = ReuseExchangeAndSubquery.apply(df.queryExecution.executedPlan)
+          val reusedExchanges = collect(reusedPlan) { case r: ReusedExchangeExec =>
+            r
+          }
+          assert(reusedExchanges.size == 1)
+          assert(reusedExchanges.head.child.isInstanceOf[CometBroadcastExchangeExec])
+        }
+      }
+    }
+  }
+
   test("CometShuffleExchangeExec logical link should be correct") {
     withTempView("v") {
       spark.sparkContext
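
Why dropping `output` from `equals`/`hashCode` matters: Spark's `ReuseExchangeAndSubquery` rule replaces a duplicate exchange with a `ReusedExchangeExec` by looking up previously seen exchanges in a map, so two exchanges over the same child must compare equal even though each plan instance materializes its own output attributes with fresh expression IDs. The dependency-free Scala sketch below illustrates that principle only; `ToyAttribute`, `ToyExchange`, `freshOutput`, and `reuseMap` are hypothetical names for illustration, not Spark or Comet APIs.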
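
object ReuseSketch extends App {
  // Hypothetical stand-in for Spark's per-instance ExprId allocation:
  // every plan instance gets output attributes with fresh IDs.
  private var nextId = 0
  case class ToyAttribute(name: String, id: Int)
  def freshOutput(): Seq[ToyAttribute] = { nextId += 1; Seq(ToyAttribute("id", nextId)) }

  // equals/hashCode deliberately ignore `output`, mirroring the diff above.
  class ToyExchange(val child: String, val output: Seq[ToyAttribute]) {
    override def equals(obj: Any): Boolean = obj match {
      case other: ToyExchange => this.child == other.child
      case _ => false
    }
    override def hashCode(): Int = child.hashCode
  }

  // Exchange reuse keys a map by the plan: two exchanges over the same child
  // should collapse to one entry even though their output IDs differ.
  val reuseMap = scala.collection.mutable.Map.empty[ToyExchange, ToyExchange]
  val first = new ToyExchange("scan t", freshOutput())
  val second = new ToyExchange("scan t", freshOutput()) // same child, new IDs

  reuseMap.getOrElseUpdate(first, first)
  val reused = reuseMap.getOrElseUpdate(second, second)
  assert(reused eq first) // holds only because equals/hashCode skip `output`
  println("second exchange reused the first: " + (reused eq first))
}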
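
Running the sketch prints `second exchange reused the first: true`; adding `output` back into `equals` makes the map lookup miss and the assertion fail. The diff makes the same trade-off: per-instance state stays out of `equals`/`hashCode`, while `stringArgs` in `CometBroadcastExchangeExec` still includes `output` so debug strings remain informative.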