diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 3fe830f04..490e0b0f2 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -462,10 +462,6 @@ class CometSparkSessionExtensions withInfo(op, "BroadcastHashJoin is not enabled") op - case op: BroadcastHashJoinExec if !op.children.forall(isCometNative(_)) => - withInfo(op, "BroadcastHashJoin disabled because not all child plans are native") - op - case op: SortMergeJoinExec if isCometOperatorEnabled(conf, "sort_merge_join") && op.children.forall(isCometNative(_)) => @@ -573,8 +569,7 @@ class CometSparkSessionExtensions case b: BroadcastExchangeExec if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") => - val newOp = QueryPlanSerde.operator2Proto(b) - newOp match { + QueryPlanSerde.operator2Proto(b) match { case Some(nativeOp) => val cometOp = CometBroadcastExchangeExec(b, b.child) CometSinkPlaceHolder(nativeOp, b, cometOp) @@ -589,7 +584,7 @@ class CometSparkSessionExtensions } else { if (!isCometOperatorEnabled( conf, - "broadcastExchangeExec") || !isCometBroadCastForceEnabled(conf)) { + "broadcastExchangeExec") && !isCometBroadCastForceEnabled(conf)) { withInfo(plan, "Native Broadcast is not enabled") } plan @@ -598,6 +593,13 @@ class CometSparkSessionExtensions plan } + // this case should be checked only after the previous case checking for a + // child BroadcastExchange has been applied, otherwise that transform + // never gets applied + case op: BroadcastHashJoinExec if !op.children.forall(isCometNative(_)) => + withInfo(op, "BroadcastHashJoin disabled because not all child plans are native") + op + // For AQE shuffle stage on a Comet shuffle exchange case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => val newOp = transform1(s) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b18ed4b69..ca893f185 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1296,6 +1296,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } else { val unboundRef = ExprOuterClass.UnboundReference .newBuilder() + .setName(attr.name) .setDatatype(dataType.get) .build()