Skip to content

Commit

Permalink
Merge changed with 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaefilat committed May 3, 2024
2 parents 27af9cd + c8cdb1c commit 67b1a2d
Show file tree
Hide file tree
Showing 13 changed files with 428 additions and 100 deletions.
7 changes: 4 additions & 3 deletions src/HerbSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using MLStyle
include("sampling_grammar.jl")

include("program_iterator.jl")
include("count_expressions.jl")
include("uniform_iterator.jl")

include("heuristics.jl")

Expand All @@ -39,7 +39,6 @@ include("genetic_search_iterator.jl")
include("random_iterator.jl")

export
count_expressions,
ProgramIterator,
@programiterator,

Expand All @@ -57,7 +56,9 @@ export
optimal_program,
suboptimal_program,

FixedShapedIterator,
FixedShapedIterator, #TODO: deprecated after the cp thesis
UniformIterator,
next_solution!,

TopDownIterator,
RandomIterator,
Expand Down
20 changes: 0 additions & 20 deletions src/count_expressions.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/fixed_shaped_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function Base.iterate(iter::FixedShapedIterator)
pq :: PriorityQueue{SolverState, Union{Real, Tuple{Vararg{Real}}}} = PriorityQueue()

solver = iter.solver
@assert !contains_variable_shaped_hole(get_tree(iter.solver)) "A FixedShapedIterator cannot iterate partial programs with Holes"
@assert !contains_nonuniform_hole(get_tree(iter.solver)) "A FixedShapedIterator cannot iterate partial programs with Holes"

if isfeasible(solver)
enqueue!(pq, get_state(solver), priority_function(iter, get_grammar(solver), get_tree(solver), 0))
Expand Down
16 changes: 16 additions & 0 deletions src/program_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@ Base.IteratorSize(::ProgramIterator) = Base.SizeUnknown()

Base.eltype(::ProgramIterator) = Union{RuleNode,StateHole}


"""
Base.length(iter::ProgramIterator)
Counts and returns the number of possible programs without storing all the programs.
!!! warning: modifies and exhausts the iterator
"""
function Base.length(iter::ProgramIterator)
l = 0
for _ iter
l += 1
end
return l
end


"""
@programiterator
Expand Down
23 changes: 7 additions & 16 deletions src/stochastic_functions/propose.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,28 @@
"""
The propose functions return the fully constructed proposed programs.
The propose functions return the fully constructed proposed programs given a path to a location to fill in.
"""

"""
random_fill_propose(current_program::RuleNode, neighbourhood_node_loc::NodeLoc, grammar::AbstractGrammar, max_depth::Int, dmap::AbstractVector{Int}, dict::Union{Nothing,Dict{String,Any}})
random_fill_propose(solver::Solver, path::Vector{Int}, dict::Union{Nothing,Dict{String,Any}}, nr_random=5)
Returns a list with only one proposed, completely random, subprogram.
# Arguments
- `current_program::RuleNode`: the current program.
- `neighbourhood_node_loc::NodeLoc`: the location of the program to replace.
- `grammar::AbstractGrammar`: the grammar used to create programs.
- `max_depth::Int`: the maximum depth of the resulting programs.
- `dmap::AbstractVector{Int} : the minimum possible depth to reach for each rule`
- `solver::solver`: solver
- `path::Vector{Int}`: path to the location to be filled.
- `dict::Dict{String, Any}`: the dictionary with additional arguments; not used.
- `nr_random`=1 : the number of random subprograms to be generated.
"""
#TODO: Update documentation with correct function signature
function random_fill_propose(solver::Solver, path::Vector{Int}, dict::Union{Nothing,Dict{String,Any}})
return Iterators.take(RandomSearchIterator(solver, path),5)
function random_fill_propose(solver::Solver, path::Vector{Int}, dict::Union{Nothing,Dict{String,Any}}, nr_random=1)
return Iterators.take(RandomSearchIterator(solver, path), nr_random)
end

"""
enumerate_neighbours_propose(enumeration_depth::Int64)
The return function is a function that produces a list with all the subprograms with depth at most `enumeration_depth`.
# Arguments
- `enumeration_depth::Int64`: the maximum enumeration depth.
"""
# TODO: Refactor to not return functions
# TODO: Update documentation with correct function signature
function enumerate_neighbours_propose(enumeration_depth::Int64)
return (solver::Solver, path::Vector{Int}, dict::Union{Nothing,Dict{String,Any}}) -> begin
#TODO use the rule subset from the dict variable
#BFSIterator(solver, allowed_rules = dict[:rule_subset])
return BFSIterator(solver)
end
end
Expand Down
1 change: 0 additions & 1 deletion src/stochastic_iterator.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using Random

#TODO: Update documentation with correct function signatures!
"""
abstract type StochasticSearchIterator <: ProgramIterator
Expand Down
65 changes: 34 additions & 31 deletions src/top_down_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,13 @@ function priority_function(
end

"""
function derivation_heuristic(::TopDownIterator)
function derivation_heuristic(::TopDownIterator, indices::Vector{Int})
Returns a sorted sublist of the `indices`, based on which rules are most promising to fill a hole.
By default, this is the identity function.
"""
function derivation_heuristic(::TopDownIterator)
return function (indices)
return indices;
end
function derivation_heuristic(::TopDownIterator, indices::Vector{Int})
return indices;
end

"""
Expand Down Expand Up @@ -76,14 +74,12 @@ function priority_function(
end

"""
function derivation_heuristic(::RandomIterator)
function derivation_heuristic(::RandomIterator, indices::Vector{Int})
Randomly shuffles the rules.
"""
function derivation_heuristic(::RandomIterator)
return function (indices)
return Random.shuffle!(indices);
end
function derivation_heuristic(::RandomIterator, indices::Vector{Int})
return Random.shuffle!(indices);
end


Expand Down Expand Up @@ -174,14 +170,32 @@ Currently, there are two possible causes of the expansion failing:
"""
@enum ExpandFailureReason limit_reached=1 already_complete=2


"""
function Base.collect(iter::TopDownIterator)
Return an array of all programs in the TopDownIterator.
!!! warning
This requires deepcopying programs from type StateHole to type RuleNode.
If it is not needed to save all programs, iterate over the iterator manually.
"""
function Base.collect(iter::TopDownIterator)
@warn "Collecting all programs of a TopDownIterator requires freeze_state"
programs = Vector{RuleNode}()
for program iter
push!(programs, freeze_state(program))
end
return programs
end

"""
Base.iterate(iter::TopDownIterator)
Describes the iteration for a given [`TopDownIterator`](@ref) over the grammar. The iteration constructs a [`PriorityQueue`](@ref) first and then prunes it propagating the active constraints. Recursively returns the result for the priority queue.
"""
function Base.iterate(iter::TopDownIterator)
# Priority queue with `SolverState`s (for variable shaped trees) and `UniformSolver`s (for fixed shaped trees)
pq :: PriorityQueue{Union{SolverState, UniformSolver}, Union{Real, Tuple{Vararg{Real}}}} = PriorityQueue()
# Priority queue with `SolverState`s (for variable shaped trees) and `UniformIterator`s (for fixed shaped trees)
pq :: PriorityQueue{Union{SolverState, UniformIterator}, Union{Real, Tuple{Vararg{Real}}}} = PriorityQueue()

solver = iter.solver

Expand All @@ -191,18 +205,6 @@ function Base.iterate(iter::TopDownIterator)
return _find_next_complete_tree(iter.solver, pq, iter)
end


# """
# Base.iterate(iter::TopDownIterator, pq::DataStructures.PriorityQueue)

# Describes the iteration for a given [`TopDownIterator`](@ref) and a [`PriorityQueue`](@ref) over the grammar without enqueueing new items to the priority queue. Recursively returns the result for the priority queue.
# """
# function Base.iterate(iter::TopDownIterator, pq::DataStructures.PriorityQueue)
# solver, max_depth, max_size = iter.solver, iter.max_depth, iter.max_size

# return _find_next_complete_tree(solver, max_depth, max_size, pq, iter)
# end

"""
Base.iterate(iter::TopDownIterator, pq::DataStructures.PriorityQueue)
Expand Down Expand Up @@ -237,12 +239,12 @@ function _find_next_complete_tree(
)#::Union{Tuple{RuleNode, Tuple{Vector{AbstractRuleNode}, PriorityQueue}}, Nothing} #@TODO Fix this comment
while length(pq) 0
(item, priority_value) = dequeue_pair!(pq)
if item isa UniformSolver
if item isa UniformIterator
#the item is a fixed shaped solver, we should get the next solution and re-enqueue it with a new priority value
fixed_shaped_solver = item
solution = next_solution!(fixed_shaped_solver)
uniform_iterator = item
solution = next_solution!(uniform_iterator)
if !isnothing(solution)
enqueue!(pq, fixed_shaped_solver, priority_function(iter, get_grammar(solver), solution, priority_value, true))
enqueue!(pq, uniform_iterator, priority_function(iter, get_grammar(solver), solution, priority_value, true))
return (solution, pq)
end
elseif item isa SolverState
Expand All @@ -255,10 +257,11 @@ function _find_next_complete_tree(
track!(solver.statistics, "#FixedShapedTrees")
if solver.use_uniformsolver
#TODO: use_uniformsolver should be the default case
fixed_shaped_solver = UniformSolver(get_grammar(solver), get_tree(solver), with_statistics=solver.statistics, derivation_heuristic=derivation_heuristic(iter))
solution = next_solution!(fixed_shaped_solver)
uniform_solver = UniformSolver(get_grammar(solver), get_tree(solver), with_statistics=solver.statistics)
uniform_iterator = UniformIterator(uniform_solver, iter)
solution = next_solution!(uniform_iterator)
if !isnothing(solution)
enqueue!(pq, fixed_shaped_solver, priority_function(iter, get_grammar(solver), solution, priority_value, true))
enqueue!(pq, uniform_iterator, priority_function(iter, get_grammar(solver), solution, priority_value, true))
return (solution, pq)
end
else
Expand Down
Loading

0 comments on commit 67b1a2d

Please sign in to comment.