diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a8579757d..adbe412de 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -239,6 +239,7 @@ abstract class CometNativeExec extends CometExec { val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find { case (_: CometBroadcastExchangeExec, _) => false case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false + case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false case _ => true } @@ -263,6 +264,13 @@ abstract class CometNativeExec extends CometExec { inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar() case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar() + case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => + inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar() + case BroadcastQueryStageExec( + _, + ReusedExchangeExec(_, c: CometBroadcastExchangeExec), + _) => + inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar() case _ if idx == firstNonBroadcastPlan.get._2 => inputs += firstNonBroadcastPlanRDD case _ =>