diff --git a/src/probe/guided_search_optimized.jl b/src/probe/guided_search_optimized.jl index 4221c70..8e04f2c 100644 --- a/src/probe/guided_search_optimized.jl +++ b/src/probe/guided_search_optimized.jl @@ -7,7 +7,6 @@ level::Int64 bank::Vector{Vector{RuleNode}} eval_cache::Set - programs::Vector{RuleNode} iter::NewProgramsIterator next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing} end @@ -17,7 +16,6 @@ function Base.iterate(iter::GuidedSearchIteratorOptimzed) level=-1, bank=[], eval_cache=Set(), - programs=[], iter=NewProgramsIterator(0, [], iter.grammar), next_iter=nothing )) diff --git a/src/probe/new_program_iterator.jl b/src/probe/new_program_iterator.jl index 33c793e..e4a95df 100644 --- a/src/probe/new_program_iterator.jl +++ b/src/probe/new_program_iterator.jl @@ -18,48 +18,48 @@ function Base.iterate(iter::NewProgramsIterator) end function Base.iterate(iter::NewProgramsIterator, state::NewProgramsState) while state.rule_index <= length(iter.grammar.rules) - nr_children = nchildren(iter.grammar, state.rule_index) - rule_cost = calculate_rule_cost(state.rule_index, iter.grammar) - if rule_cost == iter.level && nr_children == 0 - # if one rule is enough and has no children just return that tree - program = RuleNode(state.rule_index) - state.rule_index += 1 - state.sum_iter = nothing - return program, state - elseif rule_cost < iter.level && nr_children > 0 - # outer for loop not started -> start it - if state.sum_iter === nothing + if state.sum_iter === nothing + nr_children = nchildren(iter.grammar, state.rule_index) + rule_cost = calculate_rule_cost(state.rule_index, iter.grammar) + if rule_cost == iter.level && nr_children == 0 + # if one rule is enough and has no children just return that tree + program = RuleNode(state.rule_index) + state.rule_index += 1 + state.sum_iter = nothing + return program, state + elseif rule_cost < iter.level && nr_children > 0 + # outer for loop not started -> start it state.sum_iter = SumIterator(nr_children, iter.level - rule_cost, iter.level - rule_cost) state.sum_iter_state = iterate(state.sum_iter) state.cartesian_iter = nothing end - # if the outerfor loop is not done - while state.sum_iter_state !== nothing - costs, _ = state.sum_iter_state + end + # if the outerfor loop is not done + while state.sum_iter_state !== nothing + costs, _ = state.sum_iter_state - # if the inner for loop is not started - if state.cartesian_iter === nothing - # create inner for loop - bank_indexed = [iter.bank[cost+1] for cost ∈ costs] - state.cartesian_iter = Iterators.product(bank_indexed...) - state.cartesian_iter_state = iterate(state.cartesian_iter) - end + # if the inner for loop is not started + if state.cartesian_iter === nothing + # create inner for loop + bank_indexed = [iter.bank[cost+1] for cost ∈ costs] + state.cartesian_iter = Iterators.product(bank_indexed...) + state.cartesian_iter_state = iterate(state.cartesian_iter) + end - if state.cartesian_iter_state === nothing - # move one step outer for loop - _, next_state = state.sum_iter_state - state.sum_iter_state = iterate(state.sum_iter, next_state) - # reset inner loop - state.cartesian_iter = nothing - else - # save current values - children, _ = state.cartesian_iter_state - rulenode = RuleNode(state.rule_index, collect(children)) - # move to next cartesian - _, next_state = state.cartesian_iter_state - state.cartesian_iter_state = iterate(state.cartesian_iter, next_state) - return rulenode, state - end + if state.cartesian_iter_state === nothing + # move one step outer for loop + _, next_state = state.sum_iter_state + state.sum_iter_state = iterate(state.sum_iter, next_state) + # reset inner loop + state.cartesian_iter = nothing + else + # save current values + children, _ = state.cartesian_iter_state + rulenode = RuleNode(state.rule_index, collect(children)) + # move to next cartesian + _, next_state = state.cartesian_iter_state + state.cartesian_iter_state = iterate(state.cartesian_iter, next_state) + return rulenode, state end end state.rule_index += 1