Skip to content

Commit

Permalink
change for another iteration to be an integer
Browse files Browse the repository at this point in the history
  • Loading branch information
stefankandic committed Nov 27, 2024
1 parent d92a5cf commit 26909ea
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ class ResolveDefaultStringTypes(replaceWithTempType: Boolean) extends Rule[Logic
}

if (!replaceWithTempType || newPlan.fastEquals(plan)) {
newPlan.unsetTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER)
newPlan
} else {
// Due to how tree transformations work and StringType object being equal to
// StringType("UTF8_BINARY"), we need to run `ResolveDefaultStringType` twice
// to ensure the correct results for occurrences of default string type.
val finalPlan = ResolveDefaultStringTypesWithoutTempType.apply(newPlan)
finalPlan.setTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER, ())
RuleExecutor.forceAdditionalIteration(finalPlan)
finalPlan
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.internal.{Logging, MessageWithContext}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.MDC
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.rules.RuleExecutor.getForceIterationValue
import org.apache.spark.sql.catalyst.trees.{TreeNode, TreeNodeTag}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.catalyst.util.sideBySide
Expand All @@ -32,11 +33,24 @@ import org.apache.spark.util.Utils
object RuleExecutor {

/**
* A tag to indicate that we should do another batch iteration even if the plan
* hasn't changed between the start and end of the batch.
* Use with caution as it can lead to infinite loops.
* A tag used to explicitly request an additional iteration of the current batch during
* rule execution, even if the query plan remains unchanged. Increment the tag's value
* to enforce another iteration.
*/
private[spark] val FORCE_ANOTHER_BATCH_ITER = TreeNodeTag[Unit]("forceAnotherBatchIter")
private val FORCE_ADDITIONAL_ITERATION = TreeNodeTag[Int]("forceAdditionalIteration")

/**
* Increments the value of the FORCE_ADDITIONAL_ITERATION tag on the given plan to
* explicitly force another iteration of the current batch during rule execution.
*/
def forceAdditionalIteration(plan: TreeNode[_]): Unit = {
val oldValue = getForceIterationValue(plan)
plan.setTagValue(FORCE_ADDITIONAL_ITERATION, oldValue + 1)
}

private def getForceIterationValue(plan: TreeNode[_]): Int = {
plan.getTagValue(FORCE_ADDITIONAL_ITERATION).getOrElse(0)
}

protected val queryExecutionMeter = QueryExecutionMetering()

Expand Down Expand Up @@ -326,8 +340,8 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
curPlan
}

def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = {
private def isFixedPointReached(oldPlan: TreeType, newPlan: TreeType): Boolean = {
oldPlan.fastEquals(newPlan) &&
newPlan.getTagValue(RuleExecutor.FORCE_ANOTHER_BATCH_ITER).isEmpty
getForceIterationValue(newPlan) <= getForceIterationValue(oldPlan)
}
}

0 comments on commit 26909ea

Please sign in to comment.