Skip to content

Commit

Permalink
Move if check for sum_iter state for the new_programs iterator ou…
Browse files Browse the repository at this point in the history
…tside the if then else

Remove unused `programs` field
  • Loading branch information
nicolaefilat committed May 12, 2024
1 parent f3688eb commit 685f6ce
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 38 deletions.
2 changes: 0 additions & 2 deletions src/probe/guided_search_optimized.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
level::Int64
bank::Vector{Vector{RuleNode}}
eval_cache::Set
programs::Vector{RuleNode}
iter::NewProgramsIterator
next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing}
end
Expand All @@ -17,7 +16,6 @@ function Base.iterate(iter::GuidedSearchIteratorOptimzed)
level=-1,
bank=[],
eval_cache=Set(),
programs=[],
iter=NewProgramsIterator(0, [], iter.grammar),
next_iter=nothing
))
Expand Down
72 changes: 36 additions & 36 deletions src/probe/new_program_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,48 @@ function Base.iterate(iter::NewProgramsIterator)
end
function Base.iterate(iter::NewProgramsIterator, state::NewProgramsState)
while state.rule_index <= length(iter.grammar.rules)
nr_children = nchildren(iter.grammar, state.rule_index)
rule_cost = calculate_rule_cost(state.rule_index, iter.grammar)
if rule_cost == iter.level && nr_children == 0
# if one rule is enough and has no children just return that tree
program = RuleNode(state.rule_index)
state.rule_index += 1
state.sum_iter = nothing
return program, state
elseif rule_cost < iter.level && nr_children > 0
# outer for loop not started -> start it
if state.sum_iter === nothing
if state.sum_iter === nothing
nr_children = nchildren(iter.grammar, state.rule_index)
rule_cost = calculate_rule_cost(state.rule_index, iter.grammar)
if rule_cost == iter.level && nr_children == 0
# if one rule is enough and has no children just return that tree
program = RuleNode(state.rule_index)
state.rule_index += 1
state.sum_iter = nothing
return program, state
elseif rule_cost < iter.level && nr_children > 0
# outer for loop not started -> start it
state.sum_iter = SumIterator(nr_children, iter.level - rule_cost, iter.level - rule_cost)
state.sum_iter_state = iterate(state.sum_iter)
state.cartesian_iter = nothing
end
# if the outerfor loop is not done
while state.sum_iter_state !== nothing
costs, _ = state.sum_iter_state
end
# if the outerfor loop is not done
while state.sum_iter_state !== nothing
costs, _ = state.sum_iter_state

# if the inner for loop is not started
if state.cartesian_iter === nothing
# create inner for loop
bank_indexed = [iter.bank[cost+1] for cost costs]
state.cartesian_iter = Iterators.product(bank_indexed...)
state.cartesian_iter_state = iterate(state.cartesian_iter)
end
# if the inner for loop is not started
if state.cartesian_iter === nothing
# create inner for loop
bank_indexed = [iter.bank[cost+1] for cost costs]
state.cartesian_iter = Iterators.product(bank_indexed...)
state.cartesian_iter_state = iterate(state.cartesian_iter)
end

if state.cartesian_iter_state === nothing
# move one step outer for loop
_, next_state = state.sum_iter_state
state.sum_iter_state = iterate(state.sum_iter, next_state)
# reset inner loop
state.cartesian_iter = nothing
else
# save current values
children, _ = state.cartesian_iter_state
rulenode = RuleNode(state.rule_index, collect(children))
# move to next cartesian
_, next_state = state.cartesian_iter_state
state.cartesian_iter_state = iterate(state.cartesian_iter, next_state)
return rulenode, state
end
if state.cartesian_iter_state === nothing
# move one step outer for loop
_, next_state = state.sum_iter_state
state.sum_iter_state = iterate(state.sum_iter, next_state)
# reset inner loop
state.cartesian_iter = nothing
else
# save current values
children, _ = state.cartesian_iter_state
rulenode = RuleNode(state.rule_index, collect(children))
# move to next cartesian
_, next_state = state.cartesian_iter_state
state.cartesian_iter_state = iterate(state.cartesian_iter, next_state)
return rulenode, state
end
end
state.rule_index += 1
Expand Down

0 comments on commit 685f6ce

Please sign in to comment.