Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move iteration out of the UniformSolver (PR 2/2) #83

Merged
merged 7 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/HerbSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using MLStyle
include("sampling_grammar.jl")

include("program_iterator.jl")
include("count_expressions.jl")
include("uniform_iterator.jl")

include("heuristics.jl")

Expand All @@ -39,7 +39,6 @@ include("genetic_search_iterator.jl")
include("random_iterator.jl")

export
count_expressions,
ProgramIterator,
@programiterator,

Expand All @@ -57,7 +56,9 @@ export
optimal_program,
suboptimal_program,

FixedShapedIterator,
FixedShapedIterator, #TODO: deprecated after the cp thesis
UniformIterator,
next_solution!,

TopDownIterator,
RandomIterator,
Expand Down
20 changes: 0 additions & 20 deletions src/count_expressions.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/fixed_shaped_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function Base.iterate(iter::FixedShapedIterator)
pq :: PriorityQueue{SolverState, Union{Real, Tuple{Vararg{Real}}}} = PriorityQueue()

solver = iter.solver
@assert !contains_variable_shaped_hole(get_tree(iter.solver)) "A FixedShapedIterator cannot iterate partial programs with Holes"
@assert !contains_nonuniform_hole(get_tree(iter.solver)) "A FixedShapedIterator cannot iterate partial programs with Holes"

if isfeasible(solver)
enqueue!(pq, get_state(solver), priority_function(iter, get_grammar(solver), get_tree(solver), 0))
Expand Down
16 changes: 16 additions & 0 deletions src/program_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@ Base.IteratorSize(::ProgramIterator) = Base.SizeUnknown()

Base.eltype(::ProgramIterator) = Union{RuleNode, StateHole}


"""
Base.length(iter::ProgramIterator)

Counts and returns the number of possible programs without storing all the programs.
!!! warning: modifies and exhausts the iterator
"""
function Base.length(iter::ProgramIterator)
l = 0
for _ ∈ iter
l += 1
end
return l
end


"""
@programiterator

Expand Down
65 changes: 34 additions & 31 deletions src/top_down_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ function priority_function(
end

"""
function derivation_heuristic(::TopDownIterator)
function derivation_heuristic(::TopDownIterator, indices::Vector{Int})

Returns a sorted sublist of the `indices`, based on which rules are most promising to fill a hole.
By default, this is the identity function.
"""
function derivation_heuristic(::TopDownIterator)
return function (indices)
return indices;
end
function derivation_heuristic(::TopDownIterator, indices::Vector{Int})
return indices;
ReubenJ marked this conversation as resolved.
Show resolved Hide resolved
end

"""
Expand Down Expand Up @@ -76,14 +74,12 @@ function priority_function(
end

"""
function derivation_heuristic(::RandomIterator)
function derivation_heuristic(::RandomIterator, indices::Vector{Int})

Randomly shuffles the rules.
"""
function derivation_heuristic(::RandomIterator)
return function (indices)
return Random.shuffle!(indices);
end
function derivation_heuristic(::RandomIterator, indices::Vector{Int})
return Random.shuffle!(indices);
end


Expand Down Expand Up @@ -174,14 +170,32 @@ Currently, there are two possible causes of the expansion failing:
"""
@enum ExpandFailureReason limit_reached=1 already_complete=2


"""
function Base.collect(iter::TopDownIterator)

Return an array of all programs in the TopDownIterator.
!!! warning
This requires deepcopying programs from type StateHole to type RuleNode.
If it is not needed to save all programs, iterate over the iterator manually.
"""
function Base.collect(iter::TopDownIterator)
@warn "Collecting all programs of a TopDownIterator requires freeze_state"
programs = Vector{RuleNode}()
for program ∈ iter
push!(programs, freeze_state(program))
end
return programs
end

"""
Base.iterate(iter::TopDownIterator)

Describes the iteration for a given [`TopDownIterator`](@ref) over the grammar. The iteration constructs a [`PriorityQueue`](@ref) first and then prunes it propagating the active constraints. Recursively returns the result for the priority queue.
"""
function Base.iterate(iter::TopDownIterator)
# Priority queue with `SolverState`s (for variable shaped trees) and `UniformSolver`s (for fixed shaped trees)
pq :: PriorityQueue{Union{SolverState, UniformSolver}, Union{Real, Tuple{Vararg{Real}}}} = PriorityQueue()
# Priority queue with `SolverState`s (for variable shaped trees) and `UniformIterator`s (for fixed shaped trees)
pq :: PriorityQueue{Union{SolverState, UniformIterator}, Union{Real, Tuple{Vararg{Real}}}} = PriorityQueue()

#TODO: instantiating the solver should be in the program iterator macro
if isnothing(iter.solver)
Expand All @@ -199,18 +213,6 @@ function Base.iterate(iter::TopDownIterator)
return _find_next_complete_tree(iter.solver, pq, iter)
end


# """
# Base.iterate(iter::TopDownIterator, pq::DataStructures.PriorityQueue)

# Describes the iteration for a given [`TopDownIterator`](@ref) and a [`PriorityQueue`](@ref) over the grammar without enqueueing new items to the priority queue. Recursively returns the result for the priority queue.
# """
# function Base.iterate(iter::TopDownIterator, pq::DataStructures.PriorityQueue)
# solver, max_depth, max_size = iter.solver, iter.max_depth, iter.max_size

# return _find_next_complete_tree(solver, max_depth, max_size, pq, iter)
# end

"""
Base.iterate(iter::TopDownIterator, pq::DataStructures.PriorityQueue)

Expand Down Expand Up @@ -245,12 +247,12 @@ function _find_next_complete_tree(
)#::Union{Tuple{RuleNode, Tuple{Vector{AbstractRuleNode}, PriorityQueue}}, Nothing} #@TODO Fix this comment
while length(pq) ≠ 0
(item, priority_value) = dequeue_pair!(pq)
if item isa UniformSolver
if item isa UniformIterator
#the item is a fixed shaped solver, we should get the next solution and re-enqueue it with a new priority value
fixed_shaped_solver = item
solution = next_solution!(fixed_shaped_solver)
uniform_iterator = item
solution = next_solution!(uniform_iterator)
if !isnothing(solution)
enqueue!(pq, fixed_shaped_solver, priority_function(iter, get_grammar(solver), solution, priority_value, true))
enqueue!(pq, uniform_iterator, priority_function(iter, get_grammar(solver), solution, priority_value, true))
return (solution, pq)
end
elseif item isa SolverState
Expand All @@ -263,10 +265,11 @@ function _find_next_complete_tree(
track!(solver.statistics, "#FixedShapedTrees")
if solver.use_uniformsolver
#TODO: use_uniformsolver should be the default case
fixed_shaped_solver = UniformSolver(get_grammar(solver), get_tree(solver), with_statistics=solver.statistics, derivation_heuristic=derivation_heuristic(iter))
solution = next_solution!(fixed_shaped_solver)
uniform_solver = UniformSolver(get_grammar(solver), get_tree(solver), with_statistics=solver.statistics)
uniform_iterator = UniformIterator(uniform_solver, iter)
solution = next_solution!(uniform_iterator)
if !isnothing(solution)
enqueue!(pq, fixed_shaped_solver, priority_function(iter, get_grammar(solver), solution, priority_value, true))
enqueue!(pq, uniform_iterator, priority_function(iter, get_grammar(solver), solution, priority_value, true))
return (solution, pq)
end
else
Expand Down
157 changes: 157 additions & 0 deletions src/uniform_iterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#Branching constraint, the `StateHole` hole must be filled with rule_index `Int`.
Branch = Tuple{StateHole, Int}

#Shared reference to an empty vector to reduce memory allocations.
NOBRANCHES = Vector{Branch}()

"""
mutable struct UniformIterator

Inner iterator that enumerates all candidate programs of a uniform tree.
- `solver`: the uniform solver.
- `outeriter`: outer iterator that is responsible for producing uniform trees. This field is used to dispatch on the [`derivation_heuristic`](@ref).
- `unvisited_branches`: for each search-node from the root to the current search-node, a list of unviisted branches.
- `nsolutions`: number of solutions found so far.
"""
mutable struct UniformIterator
solver::UniformSolver
outeriter::Union{ProgramIterator, Nothing}
unvisited_branches::Stack{Vector{Branch}}
nsolutions::Int
end

"""
UniformIterator(solver::UniformSolver, outeriter::ProgramIterator)

Constructs a new UniformIterator that traverses solutions of the [`UniformSolver`](@ref) and is an inner iterator of an outer [`ProgramIterator`](@ref).
"""
function UniformIterator(solver::UniformSolver, outeriter::Union{ProgramIterator, Nothing})
iter = UniformIterator(solver, outeriter, Stack{Vector{Branch}}(), 0)
if isfeasible(solver)
# create search-branches for the root search-node
save_state!(solver)
push!(iter.unvisited_branches, generate_branches(iter))
end
return iter
end

"""
Returns a vector of disjoint branches to expand the search tree at its current state.
Example:
```
# pseudo code
Hole(domain=[2, 4, 5], children=[
Hole(domain=[1, 6]),
Hole(domain=[1, 6])
])
```
If we split on the first hole, this function will create three branches.
- `(firsthole, 2)`
- `(firsthole, 4)`
- `(firsthole, 5)`
"""
function generate_branches(iter::UniformIterator)::Vector{Branch}
@assert isfeasible(iter.solver)
function _dfs(node::Union{StateHole, RuleNode})
if node isa StateHole && size(node.domain) > 1
#skip the derivation_heuristic if the parent_iterator is not set up
if isnothing(iter.outeriter)
return [(node, rule) for rule ∈ node.domain]
end
#reversing is needed because we pop and consider the rightmost branch first
return reverse!([(node, rule) for rule ∈ derivation_heuristic(iter.outeriter, findall(node.domain))])
end
for child ∈ node.children
branches = _dfs(child)
if !isempty(branches)
return branches
end
end
return NOBRANCHES
end
return _dfs(get_tree(iter.solver))
end

"""
next_solution!(iter::UniformIterator)::Union{RuleNode, StateHole, Nothing}

Searches for the next unvisited solution.
Returns nothing if all solutions have been found already.
"""
function next_solution!(iter::UniformIterator)::Union{RuleNode, StateHole, Nothing}
solver = iter.solver
if iter.nsolutions == 1000000 @warn "UniformSolver is iterating over more than 1000000 solutions..." end
ReubenJ marked this conversation as resolved.
Show resolved Hide resolved
if iter.nsolutions > 0
# backtrack from the previous solution
restore!(solver)
end
while length(iter.unvisited_branches) > 0
branches = first(iter.unvisited_branches)
if length(branches) > 0
# current depth has unvisted branches, pick a branch to explore
(hole, rule) = pop!(branches)
save_state!(solver)
remove_all_but!(solver, solver.node_to_path[hole], rule)
if isfeasible(solver)
# generate new branches for the new search node
branches = generate_branches(iter)
if length(branches) == 0
# search node is a solution leaf node, return the solution
iter.nsolutions += 1
track!(solver.statistics, "#CompleteTrees")
return solver.tree
else
# search node is an (non-root) internal node, store the branches to visit
track!(solver.statistics, "#InternalSearchNodes")
push!(iter.unvisited_branches, branches)
end
else
# search node is an infeasible leaf node, backtrack
track!(solver.statistics, "#InfeasibleTrees")
restore!(solver)
end
else
# search node is an exhausted internal node, backtrack
restore!(solver)
pop!(iter.unvisited_branches)
end
end
if iter.nsolutions == 0 && isfeasible(solver)
_isfilledrecursive(node) = isfilled(node) && all(_isfilledrecursive(c) for c ∈ node.children)
if _isfilledrecursive(solver.tree)
# search node is the root and the only solution, return the solution.
iter.nsolutions += 1
track!(solver.statistics, "#CompleteTrees")
return solver.tree
end
end
return nothing
end

"""
Base.length(iter::UniformIterator)

Counts and returns the number of programs without storing all the programs.
!!! warning: modifies and exhausts the iterator
"""
function Base.length(iter::UniformIterator)
count = 0
s = next_solution!(iter)
while !isnothing(s)
count += 1
s = next_solution!(iter)
end
return count
end

Base.eltype(::UniformIterator) = Union{RuleNode, StateHole}

function Base.iterate(iter::UniformIterator)
solution = next_solution!(iter)
if solution
return solution, nothing
end
return nothing
end

Base.iterate(iter::UniformIterator, _) = iterate(iter)
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Random.seed!(1234)
include("test_genetic.jl")
include("test_programiterator_macro.jl")

include("test_uniform_iterator.jl")
include("test_forbidden.jl")
include("test_ordered.jl")
include("test_contains.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/test_contains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ using HerbCore, HerbGrammar, HerbConstraints

# There are 5! = 120 permutations of 5 distinct elements
iter = BFSIterator(grammar, :Permutation, solver=GenericSolver(grammar, :Permutation))
@test length(collect(iter)) == 120
@test length(iter) == 120
end
end
Loading
Loading