diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 6ea8a37ff9e86..2695fe4b54393 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -43,14 +43,13 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic } if (!replaceWithTempType || newPlan.fastEquals(plan)) { - newPlan.unsetTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER) newPlan } else { // Due to how tree transformations work and StringType object being equal to // StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice // to ensure the correct results for occurrences of default string type. val finalPlan = ResolveDefaultStringTypesWithoutTempType.apply(newPlan) - finalPlan.setTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER, ()) + RuleExecutor.forceAdditionalIteration(finalPlan) finalPlan } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index ba1971aac6dfc..bdbf698db2e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -22,6 +22,7 @@ import org.apache.spark.internal.{Logging, MessageWithContext} import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.QueryPlanningTracker +import org.apache.spark.sql.catalyst.rules.RuleExecutor.getForceIterationValue import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.sideBySide @@ -32,11 +33,24 @@ import org.apache.spark.util.Utils object RuleExecutor { /** - * A tag to indicate that we should do another batch iteration even if the plan - * hasn't changed between the start and end of the batch. - * Use with caution as it can lead to infinite loops. + * A tag used to explicitly request an additional iteration of the current batch during + * rule execution, even if the query plan remains unchanged. Increment the tag's value + * to enforce another iteration. */ - private[spark] val FORCE_ANOTHER_BATCH_ITER = TreeNodeTag[Unit]("forceAnotherBatchIter") + private val FORCE_ADDITIONAL_ITERATION = TreeNodeTag[Int]("forceAdditionalIteration") + + /** + * Increments the value of the FORCE_ADDITIONAL_ITERATION tag on the given plan to + * explicitly force another iteration of the current batch during rule execution. + */ + def forceAdditionalIteration(plan: TreeNode[_]): Unit = { + val oldValue = getForceIterationValue(plan) + plan.setTagValue(FORCE_ADDITIONAL_ITERATION, oldValue + 1) + } + + private def getForceIterationValue(plan: TreeNode[_]): Int = { + plan.getTagValue(FORCE_ADDITIONAL_ITERATION).getOrElse(0) + } protected val queryExecutionMeter = QueryExecutionMetering() @@ -326,8 +340,8 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { curPlan } - def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = { + private def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = { oldPlan.fastEquals(newPlan) && - newPlan.getTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER).isEmpty + getForceIterationValue(newPlan) <= getForceIterationValue(oldPlan) } }