Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FixedShapedSolver #76

Merged
merged 22 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d26d98d
Add the Solver as an optional argument to the program iterator
Whebon Feb 25, 2024
c2f5225
Update heuristics to only search for VariableShapedHoles
Whebon Feb 25, 2024
4ad771f
Rewrite TopDownIteration for the Solver
Whebon Feb 25, 2024
a35ea2d
Move the creation of the Solver outside the iterator
Whebon Feb 26, 2024
696c54e
Add basic implementation of a `FixedShapedIterator`
Whebon Mar 1, 2024
50cc556
Add a test for the new Forbidden constraint
Whebon Mar 2, 2024
211e517
check if the solver state is still feasible after a tree manipulation
Whebon Mar 5, 2024
c45e9ce
Reduce the number of `save_state!` calls
Whebon Mar 6, 2024
6e8f6f9
Track the number of fixed shaped trees
Whebon Mar 7, 2024
6386396
Add tests for searches with the `Ordered` constraint
Whebon Mar 9, 2024
51aa284
Add tests for Forbidden
Whebon Mar 9, 2024
5f0b17f
Enable old tests
Whebon Mar 11, 2024
4a5dfec
Remove `test_context_sensitive_iterators`
Whebon Mar 11, 2024
ef2b51b
Rename `Solver` to `GenericSolver`
Whebon Mar 14, 2024
6639d3c
Add the `FixedShapedSolver` to the top down iterator
Whebon Mar 15, 2024
1d8296f
Pass the `SolverStatistics` object to the inner solver
Whebon Mar 18, 2024
f44eb0f
Add a test for a DomainRuleNode in a Forbidden Constraint
Whebon Mar 22, 2024
7f9a199
Add a test for a DomainRuleNode in an Ordered constraint
Whebon Mar 26, 2024
74592e4
Add a test for the `Contains` constraint
Whebon Mar 26, 2024
fcf8adc
Rename is_feasible
Whebon Mar 27, 2024
50238da
Grammar -> AbstractGrammar; Update version
THinnerichs Apr 3, 2024
c7d4b3a
Merge branch 'solver' into dfs-solver
THinnerichs Apr 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HerbSearch"
uuid = "3008d8e8-f9aa-438a-92ed-26e9c7b4829f"
authors = ["Sebastijan Dumancic <[email protected]>", "Jaap de Jong <[email protected]>", "Nicolae Filat <[email protected]>", "Piotr Cichoń <[email protected]>", "Tilman Hinnerichs <[email protected]>"]
version = "0.1.1"
version = "0.2.1"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -17,13 +17,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
DataStructures = "0.17,0.18"
HerbConstraints = "0.1.0"
HerbCore = "0.1.1"
HerbGrammar = "0.1.0"
HerbInterpret = "0.1.0"
HerbSpecification = "0.1.0"
HerbConstraints = "^0.2.0"
HerbCore = "^0.2.0"
HerbGrammar = "^0.2.1"
HerbInterpret = "0.1.2"
HerbSpecification = "^0.1.0"
MLStyle = "^0.4.17"
StatsBase = "0.34"
julia = "1.8"
julia = "^1.8"

[extras]
LegibleLambdas = "f1f30506-32fe-5131-bd72-7c197988f9e5"
Expand Down
3 changes: 3 additions & 0 deletions src/HerbSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include("count_expressions.jl")

include("heuristics.jl")

include("fixed_shaped_iterator.jl")
include("top_down_iterator.jl")

include("evaluate.jl")
Expand Down Expand Up @@ -52,6 +53,8 @@ export
optimal_program,
suboptimal_program,

FixedShapedIterator,

TopDownIterator,
BFSIterator,
DFSIterator,
Expand Down
4 changes: 2 additions & 2 deletions src/count_expressions.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
count_expressions(grammar::Grammar, max_depth::Int, max_size::Int, sym::Symbol)
count_expressions(grammar::AbstractGrammar, max_depth::Int, max_size::Int, sym::Symbol)

Counts and returns the number of possible expressions of a grammar up to max_depth with start symbol sym.
"""
function count_expressions(grammar::Grammar, max_depth::Int, max_size::Int, sym::Symbol)
function count_expressions(grammar::AbstractGrammar, max_depth::Int, max_size::Int, sym::Symbol)
l = 0
# Calculate length without storing all expressions
for _ ∈ BFSIterator(grammar, sym, max_depth=max_depth, max_size=max_size)
Expand Down
110 changes: 110 additions & 0 deletions src/fixed_shaped_iterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
Base.@doc """
@programiterator FixedShapedIterator()

Enumerates all programs that extend from the provided fixed shaped tree.
The [Solver](@ref) is required to be in a state without any [VariableShapedHole](@ref)s
""" FixedShapedIterator
@programiterator FixedShapedIterator()

"""
priority_function(::FixedShapedIterator, g::AbstractGrammar, tree::AbstractRuleNode, parent_value::Union{Real, Tuple{Vararg{Real}}})

Assigns a priority value to a `tree` that needs to be considered later in the search. Trees with the lowest priority value are considered first.

- `g`: The grammar used for enumeration
- `tree`: The tree that is about to be stored in the priority queue
- `parent_value`: The priority value of the parent [`State`](@ref)
"""
function priority_function(
::FixedShapedIterator,
g::AbstractGrammar,
tree::AbstractRuleNode,
parent_value::Union{Real, Tuple{Vararg{Real}}}
)
parent_value + 1;
end


"""
hole_heuristic(::FixedShapedIterator, node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}

Defines a heuristic over fixed shaped holes. Returns a [`HoleReference`](@ref) once a hole is found.
"""
function hole_heuristic(::FixedShapedIterator, node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
return heuristic_leftmost_fixed_shaped_hole(node, max_depth);
end

"""
Base.iterate(iter::FixedShapedIterator)

Describes the iteration for a given [`TopDownIterator`](@ref) over the grammar. The iteration constructs a [`PriorityQueue`](@ref) first and then prunes it propagating the active constraints. Recursively returns the result for the priority queue.
"""
function Base.iterate(iter::FixedShapedIterator)
# Priority queue with number of nodes in the program
pq :: PriorityQueue{State, Union{Real, Tuple{Vararg{Real}}}} = PriorityQueue()

solver = iter.solver
@assert !contains_variable_shaped_hole(get_tree(iter.solver)) "A FixedShapedIterator cannot iterate partial programs with VariableShapedHoles"

if isfeasible(solver)
enqueue!(pq, get_state(solver), priority_function(iter, get_grammar(solver), get_tree(solver), 0))
end
return _find_next_complete_tree(solver, pq, iter)
end


"""
Base.iterate(iter::FixedShapedIterator, pq::DataStructures.PriorityQueue)

Describes the iteration for a given [`TopDownIterator`](@ref) and a [`PriorityQueue`](@ref) over the grammar without enqueueing new items to the priority queue. Recursively returns the result for the priority queue.
"""
function Base.iterate(iter::FixedShapedIterator, pq::DataStructures.PriorityQueue)
return _find_next_complete_tree(iter.solver, pq, iter)
end

"""
_find_next_complete_tree(solver::Solver, pq::PriorityQueue, iter::FixedShapedIterator)::Union{Tuple{RuleNode, PriorityQueue}, Nothing}

Takes a priority queue and returns the smallest AST from the grammar it can obtain from the queue or by (repeatedly) expanding trees that are in the queue.
Returns `nothing` if there are no trees left within the depth limit.
"""
function _find_next_complete_tree(
solver::Solver,
pq::PriorityQueue,
iter::FixedShapedIterator
)::Union{Tuple{RuleNode, PriorityQueue}, Nothing}
while length(pq) ≠ 0
(state, priority_value) = dequeue_pair!(pq)
load_state!(solver, state)

hole_res = hole_heuristic(iter, get_tree(solver), typemax(Int))
if hole_res ≡ already_complete
#the tree is complete
return (get_tree(solver), pq)
elseif hole_res ≡ limit_reached
# The maximum depth is reached
continue
elseif hole_res isa HoleReference
# Fixed Shaped Hole was found
# TODO: problem. this 'hole' is tied to a target state. it should be state independent
(; hole, path) = hole_res

rules = findall(hole.domain)
number_of_rules = length(rules)
for (i, rule_index) ∈ enumerate(findall(hole.domain))
if i < number_of_rules
state = save_state!(solver)
end
@assert isfeasible(solver) "Attempting to expand an infeasible tree: $(get_tree(solver))"
remove_all_but!(solver, path, rule_index)
if isfeasible(solver)
enqueue!(pq, get_state(solver), priority_function(iter, get_grammar(solver), get_tree(solver), priority_value))
end
if i < number_of_rules
load_state!(solver, state)
end
end
end
end
return nothing
end
4 changes: 2 additions & 2 deletions src/genetic_functions/mutation.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
mutate_random!(program::RuleNode, grammar::Grammar, max_depth::Int64 = 2)
mutate_random!(program::RuleNode, grammar::AbstractGrammar, max_depth::Int64 = 2)

Mutates the given program by inserting a randomly generated sub-program at a random location.
"""
function mutate_random!(program::RuleNode, grammar::Grammar, max_depth::Int64 = 2)
function mutate_random!(program::RuleNode, grammar::AbstractGrammar, max_depth::Int64 = 2)
node_location::NodeLoc = sample(NodeLoc, program)
subprogram = get(program, node_location)
symbol = return_type(grammar, subprogram)
Expand Down
4 changes: 2 additions & 2 deletions src/genetic_search_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ cross_over(::GeneticSearchIterator, parent_1::RuleNode, parent_2::RuleNode) = cr


"""
mutate!(::GeneticSearchIterator, program::RuleNode, grammar::Grammar, max_depth::Int = 2)
mutate!(::GeneticSearchIterator, program::RuleNode, grammar::AbstractGrammar, max_depth::Int = 2)

Mutates the program of an invididual.
"""
mutate!(::GeneticSearchIterator, program::RuleNode, grammar::Grammar, max_depth::Int = 2) = mutate_random!(program, grammar, max_depth)
mutate!(::GeneticSearchIterator, program::RuleNode, grammar::AbstractGrammar, max_depth::Int = 2) = mutate_random!(program, grammar, max_depth)

"""
select_parents(::GeneticSearchIterator, population::Array{RuleNode}, fitness_array::Array{<:Real})
Expand Down
45 changes: 37 additions & 8 deletions src/heuristics.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,42 @@
using Random

"""
heuristic_leftmost_fixed_shaped_hole(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}

Defines a heuristic over [FixedShapeHole](@ref)s, where the left-most hole always gets considered first. Returns a [`HoleReference`](@ref) once a hole is found. This is the default option for enumerators.
"""
function heuristic_leftmost_fixed_shaped_hole(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function leftmost(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

for (i, child) in enumerate(node.children)
new_path = push!(copy(path), i)
hole_res = leftmost(child, max_depth-1, new_path)
if (hole_res == limit_reached) || (hole_res isa HoleReference)
return hole_res
end
end

return already_complete
end

#TODO: refactor this. this method should be merged with `heuristic_leftmost`. The only difference is the `FixedShapedHole` typing in the signature below:
function leftmost(hole::FixedShapedHole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end
return HoleReference(hole, path)
end

return leftmost(node, max_depth, Vector{Int}())
end


"""
heuristic_leftmost(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}

Defines a heuristic over holes, where the left-most hole always gets considered first. Returns a [`HoleReference`](@ref) once a hole is found. This is the default option for enumerators.
"""
function heuristic_leftmost(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function leftmost(node::RuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function leftmost(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

for (i, child) in enumerate(node.children)
Expand All @@ -21,7 +50,7 @@ function heuristic_leftmost(node::AbstractRuleNode, max_depth::Int)::Union{Expan
return already_complete
end

function leftmost(hole::Hole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function leftmost(hole::VariableShapedHole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end
return HoleReference(hole, path)
end
Expand All @@ -35,7 +64,7 @@ end
Defines a heuristic over holes, where the right-most hole always gets considered first. Returns a [`HoleReference`](@ref) once a hole is found.
"""
function heuristic_rightmost(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function rightmost(node::RuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function rightmost(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

for (i, child) in Iterators.reverse(enumerate(node.children))
Expand All @@ -49,7 +78,7 @@ function heuristic_rightmost(node::AbstractRuleNode, max_depth::Int)::Union{Expa
return already_complete
end

function rightmost(hole::Hole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function rightmost(hole::VariableShapedHole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end
return HoleReference(hole, path)
end
Expand All @@ -64,7 +93,7 @@ end
Defines a heuristic over holes, where random holes get chosen randomly using random exploration. Returns a [`HoleReference`](@ref) once a hole is found.
"""
function heuristic_random(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function random(node::RuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function random(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

for (i, child) in shuffle(collect(enumerate(node.children)))
Expand All @@ -78,7 +107,7 @@ function heuristic_random(node::AbstractRuleNode, max_depth::Int)::Union{ExpandF
return already_complete
end

function random(hole::Hole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function random(hole::VariableShapedHole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end
return HoleReference(hole, path)
end
Expand All @@ -92,7 +121,7 @@ end
Defines a heuristic over all available holes in the unfinished AST, by considering the size of their respective domains. A domain here describes the number of possible derivations with respect to the constraints. Returns a [`HoleReference`](@ref) once a hole is found.
"""
function heuristic_smallest_domain(node::AbstractRuleNode, max_depth::Int)::Union{ExpandFailureReason, HoleReference}
function smallest_domain(node::RuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function smallest_domain(node::AbstractRuleNode, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end

smallest_size::Int = typemax(Int)
Expand All @@ -119,7 +148,7 @@ function heuristic_smallest_domain(node::AbstractRuleNode, max_depth::Int)::Unio
return smallest_result
end

function smallest_domain(hole::Hole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
function smallest_domain(hole::VariableShapedHole, max_depth::Int, path::Vector{Int})::Union{ExpandFailureReason, HoleReference}
if max_depth == 0 return limit_reached end
return HoleReference(hole, path)
end
Expand Down
9 changes: 6 additions & 3 deletions src/program_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ abstract type ProgramIterator end

Base.IteratorSize(::ProgramIterator) = Base.SizeUnknown()

Base.eltype(::ProgramIterator) = RuleNode
#TODO: currently, ProgramIterator will not create `StateFixedShapedHole` yet, but this should be possible
Base.eltype(::ProgramIterator) = Union{RuleNode, StateFixedShapedHole}

"""
@programiterator
Expand Down Expand Up @@ -49,7 +50,7 @@ macro programiterator(ex)
generate_iterator(__module__, ex)
end

function generate_iterator(mod::Module, ex::Expr, mut::Bool=false)
function generate_iterator(mod::Module, ex::Expr, mut::Bool=true)
Base.remove_linenums!(ex)

@match ex begin
Expand All @@ -71,7 +72,8 @@ processdecl(mod::Module, mut::Bool, decl::Expr, super=nothing) = @match decl beg
Expr(:kw, :(max_depth::Int), typemax(Int)),
Expr(:kw, :(max_size::Int), typemax(Int)),
Expr(:kw, :(max_time::Int), typemax(Int)),
Expr(:kw, :(max_enumerations::Int), typemax(Int))
Expr(:kw, :(max_enumerations::Int), typemax(Int)),
Expr(:kw, :(solver::Union{Solver, Nothing}), nothing)
]

head = Expr(:(<:), name, isnothing(super) ? :(HerbSearch.ProgramIterator) : :($mod.$super))
Expand All @@ -82,6 +84,7 @@ processdecl(mod::Module, mut::Bool, decl::Expr, super=nothing) = @match decl beg
max_size::Int
max_time::Int
max_enumerations::Int
solver::Union{Solver, Nothing}
end)

map!(ex -> processkwarg!(kwargs, ex), extrafields, extrafields)
Expand Down
Loading