Skip to content

Commit

Permalink
Avoid evaluating twice
Browse files Browse the repository at this point in the history
  • Loading branch information
eErr0Re committed May 22, 2024
1 parent b0a5796 commit 2ae3dd9
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 29 deletions.
5 changes: 4 additions & 1 deletion src/minecraft/getting_started_minerl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ function evaluate_trace_minerl(prog, grammar, env, show_moves)
break
end
end
if is_done
break

Check warning on line 64 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
end
end
println("Reward $sum_of_rewards")
if sum_of_rewards <= 0.2
Expand All @@ -78,7 +81,7 @@ function HerbSearch.set_env_position(x, y, z)
set_start_xyz(x, y, z)

Check warning on line 81 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L79-L81

Added lines #L79 - L81 were not covered by tests
end
# overwrite the evaluate trace function
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves = false) = evaluate_trace_minerl(prog, grammar, env, show_moves)
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=false) = evaluate_trace_minerl(prog, grammar, env, show_moves)

Check warning on line 84 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L84

Added line #L84 was not covered by tests
HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar)

# resetEnv()
Expand Down
11 changes: 6 additions & 5 deletions src/probe/guided_search_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Base.@kwdef mutable struct GuidedSearchState
bank::Vector{Vector{RuleNode}}
eval_cache::Set
iter::NewProgramsIterator
next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing}
next_iter::Union{Tuple{RuleNode,NewProgramsState},Nothing}
end

function Base.iterate(iter::GuidedSearchIterator)
Expand All @@ -21,7 +21,7 @@ function Base.iterate(iter::GuidedSearchIterator)
))
end

function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing}
function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{Tuple{RuleNode,Vector{Any}},GuidedSearchState},Nothing}
grammar = get_grammar(iter.solver)
start_symbol = get_starting_symbol(iter.solver)
# wrap in while true to optimize for tail call
Expand All @@ -30,7 +30,7 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni
state.level += 1
push!(state.bank, [])

state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.next_iter = iterate(state.iter)
if state.level > 0
@info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs")
Expand All @@ -56,10 +56,11 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni
if eval_observation in state.eval_cache # program already cached
continue
end

push!(state.eval_cache, eval_observation) # add result to cache
push!(state.bank[state.level+1], prog) # add program to bank
return (prog, state) # return program

return ((prog, eval_observation), state) # return program
end

push!(state.bank[state.level+1], prog) # add program to bank
Expand Down
9 changes: 5 additions & 4 deletions src/probe/guided_trace_search_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function Base.iterate(iter::GuidedTraceSearchIterator)
))
end

function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing}
function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)

Check warning on line 14 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L14

Added line #L14 was not covered by tests
grammar = get_grammar(iter.solver)
start_symbol = get_starting_symbol(iter.solver)
# wrap in while true to optimize for tail call
Expand All @@ -20,7 +20,7 @@ function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)
state.level += 1
push!(state.bank, [])

state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.iter = NewProgramsIterator(state.level, state.bank, grammar)

Check warning on line 23 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L23

Added line #L23 was not covered by tests
state.next_iter = iterate(state.iter)
if state.level > 0
@info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs")
Expand All @@ -41,10 +41,11 @@ function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)
# print("Skipping this.")
continue
end

push!(state.eval_cache, eval_observation) # add result to cache
push!(state.bank[state.level+1], prog) # add program to bank
return (prog, state) # return program

return ((prog, (eval_observation, is_done, final_reward)), state) # return program

Check warning on line 48 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L48

Added line #L48 was not covered by tests
end

push!(state.bank[state.level+1], prog) # add program to bank
Expand Down
43 changes: 29 additions & 14 deletions src/probe/probe_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ include("update_grammar.jl")
select_partial_solution(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache}) = HerbSearch.selectpsol_largest_subset(partial_sols, all_selected_psols)
update_grammar!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample}) = update_grammar(grammar, PSols_with_eval_cache, examples)

get_prog_eval(::ProgramIterator, prog::RuleNode) = (prog, [])

Check warning on line 39 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L39

Added line #L39 was not covered by tests

get_prog_eval(::GuidedSearchIterator, prog::Tuple{RuleNode,Vector{Any}}) = prog

get_prog_eval(::GuidedTraceSearchIterator, prog::Tuple{RuleNode,Tuple{Any,Bool,Number}}) = prog

Check warning on line 43 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L43

Added line #L43 was not covered by tests

"""
probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_time::Int, iteration_size::Int)
Expand All @@ -60,15 +66,23 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_tim
program, state = next

# evaluate program
eval_observation = []
program, eval_observation = get_prog_eval(iterator, program)
correct_examples = Vector{Int}()
expr = rulenode2expr(program, grammar)
for (example_index, example) enumerate(examples)
output = execute_on_input(symboltable, expr, example.in)
push!(eval_observation, output)

if output == example.out
push!(correct_examples, example_index)
if isempty(eval_observation)
expr = rulenode2expr(program, grammar)
for (example_index, example) enumerate(examples)
output = execute_on_input(symboltable, expr, example.in)
push!(eval_observation, output)

Check warning on line 75 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L72-L75

Added lines #L72 - L75 were not covered by tests

if output == example.out
push!(correct_examples, example_index)

Check warning on line 78 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L77-L78

Added lines #L77 - L78 were not covered by tests
end
end

Check warning on line 80 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L80

Added line #L80 was not covered by tests
else
for i in 1:length(eval_observation)
if eval_observation[i] == examples[i].out
push!(correct_examples, i)
end
end
end

Expand Down Expand Up @@ -128,7 +142,7 @@ function select_partial_solution(partial_sols::Vector{ProgramCacheTrace}, all_se
# sort partial solutions by reward
sort!(partial_sols, by=x -> x.reward, rev=true)
to_select = 5
return partial_sols[1 : min(to_select, length(partial_sols))]
return partial_sols[1:min(to_select, length(partial_sols))]

Check warning on line 145 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L143-L145

Added lines #L143 - L145 were not covered by tests
end

"""
Expand Down Expand Up @@ -156,11 +170,12 @@ function probe(traces::Vector{Trace}, iterator::ProgramIterator, max_time::Int,
while next !== nothing && i < iteration_size # run one iteration
program, state = next

# evaluate
eval_observation, is_done, reward = evaluate_trace(program, grammar, show_moves = true)
# evaluate
program, evaluation = get_prog_eval(iterator, program)
eval_observation, is_done, reward = isempty(evaluation) ? evaluate_trace(program, grammar, show_moves=true) : evaluation
is_partial_sol = false
if reward > best_reward
best_reward = reward
if reward > best_reward
best_reward = reward
best_eval_obs = eval_observation
printstyled("Best reward: $best_reward\n", color=:red)
is_partial_sol = true

Check warning on line 181 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L174-L181

Added lines #L174 - L181 were not covered by tests
Expand Down Expand Up @@ -193,7 +208,7 @@ function probe(traces::Vector{Trace}, iterator::ProgramIterator, max_time::Int,
partial_sols = filter(x -> x all_selected_psols, select_partial_solution(psol_with_eval_cache, all_selected_psols))
if !isempty(partial_sols)
printstyled("Restarting!\n", color=:magenta)

Check warning on line 210 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L208-L210

Added lines #L208 - L210 were not covered by tests

# set the player position to the best position so far
set_env_position(best_eval_obs[1], best_eval_obs[2], best_eval_obs[3])

Check warning on line 213 in src/probe/probe_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/probe_iterator.jl#L213

Added line #L213 was not covered by tests

Expand Down
10 changes: 5 additions & 5 deletions test/test_probe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ end

@testset "Multiple nonterminals" begin
grammar = @pcsgrammar begin
1 : A = 1
1 : A = A - B
1 : B = 2
1 : C = A + B
1:A = 1
1:A = A - B
1:B = 2
1:C = A + B
end

examples = [
Expand All @@ -209,7 +209,7 @@ end
next = iterate(iter)
while next !== nothing
prog, state = next
push!(progs, prog)
push!(progs, prog[1])
if (state.level > 1)
break
end
Expand Down

0 comments on commit 2ae3dd9

Please sign in to comment.