diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index e6cb992fc..a6b61e648 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -446,9 +446,7 @@ class CometSparkSessionExtensions // exchange. It is only used for Comet native execution. We only transform Spark broadcast // exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the // broadcast exchange is forced to be enabled by Comet config. - case plan - if (isCometNative(plan) || isCometBroadCastForceEnabled(conf)) && - plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) => + case plan if plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) => val newChildren = plan.children.map { case b: BroadcastExchangeExec if isCometNative(b.child) && @@ -461,7 +459,12 @@ class CometSparkSessionExtensions } case other => other } - plan.withNewChildren(newChildren) + val newPlan = transform(plan.withNewChildren(newChildren)) + if (isCometNative(newPlan) || isCometBroadCastForceEnabled(conf)) { + newPlan + } else { + plan + } // For AQE shuffle stage on a Comet shuffle exchange case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index 6f479e3bb..54c0baf16 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -45,7 +45,6 @@ class CometJoinSuite extends CometTestBase { withSQLConf( CometConf.COMET_BATCH_SIZE.key -> "100", SQLConf.PREFER_SORTMERGEJOIN.key -> "false", - "spark.comet.exec.broadcast.enabled" -> "true", "spark.sql.join.forceApplyShuffledHashJoin" -> "true", SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { @@ -74,7 +73,6 @@ class CometJoinSuite extends CometTestBase { withSQLConf( CometConf.COMET_BATCH_SIZE.key -> "100", SQLConf.PREFER_SORTMERGEJOIN.key -> "false", - "spark.comet.exec.broadcast.enabled" -> "true", "spark.sql.join.forceApplyShuffledHashJoin" -> "true", SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {