Skip to content

Commit

Permalink
Make GuidedSearchIterator work with multiple nonterminals
Browse files Browse the repository at this point in the history
  • Loading branch information
eErr0Re committed May 19, 2024
1 parent 0cb2514 commit 77438f7
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/getting_started.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ examples = [
# IOExample(Dict(:arg => "<Change> <string> to <a> number"), "Change string to a number")
]

iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, SymbolTable(grammar))
iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, SymbolTable(grammar), :S)
@profview program = @time probe(examples, iter, 40, 10)
# program = @time probe(examples, iter, 3600, 10000)

Expand Down
30 changes: 17 additions & 13 deletions src/probe/guided_search_iterator.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@

@programiterator GuidedSearchIterator(
spec::Vector{<:IOExample},
symboltable::SymbolTable

symboltable::SymbolTable,
start::Symbol
)
Base.@kwdef mutable struct GuidedSearchState
level::Int64
Expand Down Expand Up @@ -44,21 +44,25 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni
# move in advance
state.next_iter = iterate(state.iter, next_state)

# evaluate program
eval_observation = []
expr = rulenode2expr(prog, grammar)
for example iter.spec
output = execute_on_input(iter.symboltable, expr, example.in)
push!(eval_observation, output)
end
# evaluate program if starting symbol
if return_type(grammar, prog.ind) == iter.start
eval_observation = []
expr = rulenode2expr(prog, grammar)
for example iter.spec
output = execute_on_input(iter.symboltable, expr, example.in)
push!(eval_observation, output)
end

if eval_observation in state.eval_cache # program already cached
continue
if eval_observation in state.eval_cache # program already cached
continue
end

push!(state.eval_cache, eval_observation) # add result to cache
push!(state.bank[state.level+1], prog) # add program to bank
return (prog, state) # return program
end

push!(state.bank[state.level+1], prog) # add program to bank
push!(state.eval_cache, eval_observation) # add result to cache
return (prog, state) # return program
end
end
end
19 changes: 17 additions & 2 deletions src/probe/new_program_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,26 @@ function Base.iterate(iter::NewProgramsIterator, state::NewProgramsState)
else
# save current values
children, _ = state.cartesian_iter_state
rulenode = RuleNode(state.rule_index, collect(children))
children = collect(children)

# move to next cartesian
_, next_state = state.cartesian_iter_state
state.cartesian_iter_state = iterate(state.cartesian_iter, next_state)
return rulenode, state

# check if selected programs have the correct type
types = child_types(iter.grammar, state.rule_index)
same_types = true
for i in 1:length(types)
if return_type(iter.grammar, children[i]) != types[i]
same_types = false
break
end
end

if same_types
rulenode = RuleNode(state.rule_index, children)
return rulenode, state
end
end
end
state.rule_index += 1
Expand Down
25 changes: 25 additions & 0 deletions test/test_newprograms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,29 @@ end
[2, 1, 1, 1]
]
end
end
@testset "Test new programs iterator" begin
@testset "Multiple nonterminals" begin
HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_size(rule_index, grammar)
grammar = @pcsgrammar begin
1 : S = A * B
1 : A = "a"
1 : B = "b"
end

bank = [[],[],[],[]]
for i in 1:3
iter = HerbSearch.NewProgramsIterator(i, bank, grammar)
for prog in iter
push!(bank[i+1], prog)
end
end

@test bank == [
[],
[RuleNode(2), RuleNode(3)],
[],
[RuleNode(1, [RuleNode(2), RuleNode(3)])]
]
end
end
37 changes: 34 additions & 3 deletions test/test_probe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ end

@testset "Running using size-based enumeration" begin
HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_size(rule_index, grammar)
iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, symboltable)
iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, symboltable, :S)

max_level = 10
state = nothing
Expand All @@ -173,7 +173,7 @@ end
]

HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar)
iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, symboltable)
iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, symboltable, :S)

max_level = 20
state = nothing
Expand All @@ -188,6 +188,37 @@ end
sizes = [length(level) for level in state.bank]
@test sizes == [0, 0, 4, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0, 122, 0, 0, 0, 0, 0, 1305, 0, 0, 0, 0, 0, 1]
end

@testset "Multiple nonterminals" begin
grammar = @pcsgrammar begin
1 : A = 1
1 : A = A - B
1 : B = 2
1 : C = A + B
end

examples = [
IOExample(Dict(), 5)
]

HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_size(rule_index, grammar)
iter = HerbSearch.GuidedSearchIterator(grammar, :A, examples, SymbolTable(grammar), :A)

progs = []

state = nothing
next = iterate(iter)
while next !== nothing
prog, state = next
push!(progs, prog)
if (state.level > 1)
break
end
next = iterate(iter, state)
end

@test progs == [RuleNode(1), RuleNode(2, [RuleNode(1), RuleNode(3)])]
end
end

@testset "Running probe" begin
Expand Down Expand Up @@ -222,7 +253,7 @@ end
HerbSearch.select_partial_solution(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache}) = select_func(partial_sols, all_selected_psols)

deep_copy_grammar = deepcopy(grammar_to_use)
iter = HerbSearch.GuidedSearchIterator(deep_copy_grammar, :S, examples, symboltable)
iter = HerbSearch.GuidedSearchIterator(deep_copy_grammar, :S, examples, symboltable, :S)
max_time = 5
runtime = @timed program = probe(examples, iter, max_time, 100)
expression = rulenode2expr(program, grammar_to_use)
Expand Down

0 comments on commit 77438f7

Please sign in to comment.