diff --git a/src/minecraft/create_minerl_env.jl b/src/minecraft/create_minerl_env.jl index e0bce55..087f14c 100644 --- a/src/minecraft/create_minerl_env.jl +++ b/src/minecraft/create_minerl_env.jl @@ -2,14 +2,68 @@ using PyCall pyimport("minerl") gym = pyimport("gym") +# our_seed = 95812 <- hard env +# our_seed = 958122 +our_seed = 958129 # initial env + +""" + resetEnv() +This function resets the enviornment with the global provided seed and saves the initial X,Y,Z positions. +""" +function resetEnv() + env.seed(our_seed) + obs = env.reset() + global x_player_start, y_player_start, z_player_start = get_xyz_from_env(obs) + print(x_player_start, y_player_start, z_player_start) + + action = env.action_space.noop() + action["forward"] = 1 + + # infinite health + env.set_next_chat_message("/effect @a minecraft:instant_health 1000000 100 true") + env.step(action) + + # infinite food + env.set_next_chat_message("/effect @a minecraft:saturation 1000000 255 true") + env.step(action) + + # disable mobs + env.set_next_chat_message("/gamerule doMobSpawning false") + env.step(action) + env.set_next_chat_message("/kill @e[type=!player]") + env.step(action) + + printstyled("Environment created. x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green) +end + +function set_start_xyz(x, y, z) + global x_player_start = x + global y_player_start = y + global z_player_start = z + println("New x: $x_player_start y: $y_player_start z: $z_player_start pos") +end +function get_xyz_from_env(obs) + return obs["xpos"][1], obs["ypos"][1], obs["zpos"][1] +end + +function resetPosition() + action = env.action_space.noop() + env.set_next_chat_message("/tp @a $(x_player_start) $(y_player_start) $(z_player_start)") + + obs = env.step(action)[1] + obsx, obsy, obsz = get_xyz_from_env(obs) + while obsx != x_player_start || obsy != y_player_start || obsz != z_player_start + obs = env.step(action)[1] + obsx, obsy, obsz = get_xyz_from_env(obs) + end + println((obsx, obsy, obsz)) +end if !@isdefined env printstyled("Creating environment\n", color=:yellow) env = gym.make("MineRLNavigateDenseProgSynth-v0") - env.seed(958129) - obs = env.reset() - x_player_start, y_player_start, z_player_start = obs["xpos"], obs["ypos"], obs["zpos"] - printstyled("Environment created x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green) + resetEnv() + printstyled("Environment created. x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green) else printstyled("Environment already created\n", color=:green) end \ No newline at end of file diff --git a/src/minecraft/getting_started_minerl.jl b/src/minecraft/getting_started_minerl.jl index b8151d3..4136163 100644 --- a/src/minecraft/getting_started_minerl.jl +++ b/src/minecraft/getting_started_minerl.jl @@ -1,150 +1,82 @@ include("create_minerl_env.jl") using HerbGrammar, HerbSpecification using HerbSearch - -os = pyimport("os") -os.environ["LANG"] = "en_US.UTF-8" -# you might get an error here. Just type using HerbSearch in the REPL and run the whole file again. - -# minerl_grammar = @pcsgrammar begin -# 10:action_name = "forward" | "left" | "right" -# 1:action_name = "jump" -# 1:action = Dict("camera" => [0, 0], action_name => 1) -# 1:sequence_actions = [sequence_actions; action] -# 1:sequence_actions = [] -# end - -minerl_grammar_2 = @pcsgrammar begin - 100:F = 1 - 1:F = 0 - 1:L = 1 - 100:L = 0 - 1:R = 0 - 100:R = 1 - 100:J = 0 - 1:J = 1 - 100:B = 1 - 1:B = 0 - 10:sequence_actions = [sequence_actions; action] +using Logging +disable_logging(LogLevel(1)) + +minerl_grammar = @pcsgrammar begin + 1:action_name = "forward" + 1:action_name = "left" + 1:action_name = "right" + 1:action_name = "back" + 1:action_name = "jump" + 1:sequence_actions = [sequence_actions; action] 1:sequence_actions = [] - 10:action= (TIMES, Dict("camera" => [0, 0], "forward" => F, "left" => L, "right" => R, "jump" => J, "back" => B)) - 1:TIMES = 25 | 50 | 75 | 100 | 125 | 150 + 1:action = (TIMES, Dict("camera" => [0, 0], action_name => 1)) + 5:TIMES = 1 | 5 | 25 | 50 | 75 | 100 end -function create_random_program() - return rand(RuleNode, minerl_grammar, :sequence_actions) +minerl_grammar_2 = @pcsgrammar begin + 1:SEQ = [ACT] + 8:DIR = 0b0001 | 0b0010 | 0b0100 | 0b1000 | 0b0101 | 0b1001 | 0b0110 | 0b1010 # forward | back | left | right | forward-left | forward-right | back-left | back-right + 1:ACT = (TIMES, Dict("move" => DIR, "sprint" => 1, "jump" => 1)) + 6:TIMES = 5 | 10 | 25 | 50 | 75 | 100 end -iterations = 0 -best_reward = 0 -function evaluate_trace_minerl(prog, grammar, env) +function evaluate_trace_minerl(prog, grammar, env, show_moves) resetPosition() expr = rulenode2expr(prog, grammar) - # println("expr is: ", expr) - - - is_done = false - is_partial_sol = false sequence_of_actions = eval(expr) - if isempty(sequence_of_actions) - return (0, 0, 0), false, false, 0 - end sum_of_rewards = 0 + is_done = false obs = nothing - for saved_action ∈ sequence_of_actions - times, action = saved_action - + for (times, action) ∈ sequence_of_actions new_action = env.action_space.noop() - for key ∈ keys(action) - new_action[key] = action[key] + for (key, val) in action + if key == "move" + new_action["forward"] = val & 1 + new_action["back"] = val >> 1 & 1 + new_action["left"] = val >> 2 & 1 + new_action["right"] = val >> 3 + else + new_action[key] = val + end end - + for i in 1:times obs, reward, done, _ = env.step(new_action) - env.render() + if show_moves + env.render() + end sum_of_rewards += reward - - if reward > 0 - is_partial_sol = true - end if done - println("Rewards ", sum_of_rewards) is_done = true - printstyled("done\n", color=:green) + printstyled("sum of rewards: $sum_of_rewards. Done\n", color=:green) break end end - end - println("Got reward: ", sum_of_rewards) - global best_reward = max(best_reward, sum_of_rewards) - printstyled("Best reward: $best_reward\n",color=:red) - - eval_observation = (round(obs["xpos"][1], digits=1), round(obs["ypos"][1], digits=1), round(obs["zpos"][1], digits=1)) - println(eval_observation) - return eval_observation, is_done, is_partial_sol, sum_of_rewards -end - -function run_action(action) - obs, reward, done, _ = env.step(action) - env.render() - println("reward: $reward") -end - -function resetEnv() - obs = env.reset() - x_player_start, y_player_start, z_player_start = obs["xpos"], obs["ypos"], obs["zpos"] - printstyled("Environment reset x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green) -end - - -function resetPosition() - action = env.action_space.noop() - env.set_next_chat_message("/tp @a $(x_player_start[1]) $(y_player_start[1]) $(z_player_start[1])") - - obs = env.step(action)[1] - while obs["xpos"] != x_player_start || obs["ypos"] != y_player_start || obs["zpos"] != z_player_start - obs = env.step(action)[1] - end - - env.render() -end - -function run_forward_and_random() - while true - if rand() < 0.8 - new_action = env.action_space.noop() - new_action["forward"] = 1 - # new_action["jump"] = 1 - new_action["spint"] = 1 - run_action(new_action) - else - prog = create_random_program() - evaluate_trace_minerl(prog, minerl_grammar_2, env) + if is_done + break end end - + println("Reward $sum_of_rewards") + return get_xyz_from_env(obs), is_done, sum_of_rewards end -# run_forward_and_random() +# make sure the probabilities are equal +@assert all(prob -> prob == minerl_grammar_2.log_probabilities[begin], minerl_grammar_2.log_probabilities) -# # # overwrite the evaluate trace function -HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar) = evaluate_trace_minerl(prog, grammar, env) +function HerbSearch.set_env_position(x, y, z) + println("Setting env position: ($x, $y, $z)") + set_start_xyz(x, y, z) +end +# overwrite the evaluate trace function +HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=false) = evaluate_trace_minerl(prog, grammar, env, show_moves) HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar) -iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :sequence_actions) -program = probe(Vector{Trace}(), iter, 400000, 100000) - - -# state = nothing -# next = iterate(iter) -# print(next) -# while next !== nothing -# prog, state = next -# if (state.level > 100) -# break -# end -# global next = iterate(iter, state) -# end +# resetEnv() +iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :SEQ) +program = @time probe(Vector{Trace}(), iter, 3000000, 6) diff --git a/src/probe/guided_search_iterator.jl b/src/probe/guided_search_iterator.jl index a4fa31b..759ea55 100644 --- a/src/probe/guided_search_iterator.jl +++ b/src/probe/guided_search_iterator.jl @@ -8,7 +8,7 @@ Base.@kwdef mutable struct GuidedSearchState bank::Vector{Vector{RuleNode}} eval_cache::Set iter::NewProgramsIterator - next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing} + next_iter::Union{Tuple{RuleNode,NewProgramsState},Nothing} end function Base.iterate(iter::GuidedSearchIterator) @@ -21,7 +21,7 @@ function Base.iterate(iter::GuidedSearchIterator) )) end -function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing} +function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{Tuple{RuleNode,Vector{Any}},GuidedSearchState},Nothing} grammar = get_grammar(iter.solver) start_symbol = get_starting_symbol(iter.solver) # wrap in while true to optimize for tail call @@ -30,7 +30,7 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni state.level += 1 push!(state.bank, []) - state.iter = NewProgramsIterator(state.level, state.bank, grammar) + state.iter = NewProgramsIterator(state.level, state.bank, grammar) state.next_iter = iterate(state.iter) if state.level > 0 @info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs") @@ -56,10 +56,11 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni if eval_observation in state.eval_cache # program already cached continue end - + push!(state.eval_cache, eval_observation) # add result to cache push!(state.bank[state.level+1], prog) # add program to bank - return (prog, state) # return program + + return ((prog, eval_observation), state) # return program end push!(state.bank[state.level+1], prog) # add program to bank diff --git a/src/probe/guided_trace_search_iterator.jl b/src/probe/guided_trace_search_iterator.jl index 2f00257..6da2674 100644 --- a/src/probe/guided_trace_search_iterator.jl +++ b/src/probe/guided_trace_search_iterator.jl @@ -1,12 +1,5 @@ @programiterator GuidedTraceSearchIterator() -Base.@kwdef mutable struct GuidedSearchState - level::Int64 - bank::Vector{Vector{RuleNode}} - eval_cache::Set - iter::NewProgramsIterator - next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing} -end function Base.iterate(iter::GuidedTraceSearchIterator) iterate(iter, GuidedSearchState( @@ -18,7 +11,7 @@ function Base.iterate(iter::GuidedTraceSearchIterator) )) end -function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing} +function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState) grammar = get_grammar(iter.solver) start_symbol = get_starting_symbol(iter.solver) # wrap in while true to optimize for tail call @@ -27,7 +20,7 @@ function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState) state.level += 1 push!(state.bank, []) - state.iter = NewProgramsIterator(state.level, state.bank, grammar) + state.iter = NewProgramsIterator(state.level, state.bank, grammar) state.next_iter = iterate(state.iter) if state.level > 0 @info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs") @@ -43,15 +36,18 @@ function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState) # evaluate program if starting symbol if return_type(grammar, prog.ind) == start_symbol - eval_observation, is_done, is_partial_sol, final_reward = evaluate_trace(prog, grammar) - if eval_observation in state.eval_cache # program already cached + eval_observation, is_done, final_reward = evaluate_trace(prog, grammar) + eval_observation_rounded = round.(eval_observation, digits=1) + if eval_observation_rounded in state.eval_cache # program already cached # print("Skipping this.") + @info "Skipping program" continue end - - push!(state.eval_cache, eval_observation) # add result to cache + + push!(state.eval_cache, eval_observation_rounded) # add result to cache push!(state.bank[state.level+1], prog) # add program to bank - return (prog, state) # return program + + return ((prog, (eval_observation, is_done, final_reward)), state) # return program end push!(state.bank[state.level+1], prog) # add program to bank diff --git a/src/probe/probe_iterator.jl b/src/probe/probe_iterator.jl index dc591eb..ea47b87 100644 --- a/src/probe/probe_iterator.jl +++ b/src/probe/probe_iterator.jl @@ -14,6 +14,17 @@ function Base.:(==)(a::ProgramCache, b::ProgramCache) end Base.hash(a::ProgramCache) = hash(a.program) +mutable struct ProgramCacheTrace + program::RuleNode + cost::Int + reward::Float64 +end + +function Base.:(==)(a::ProgramCacheTrace, b::ProgramCacheTrace) + return a.program == b.program +end +Base.hash(a::ProgramCacheTrace) = hash(a.program) + include("sum_iterator.jl") include("new_program_iterator.jl") include("guided_search_iterator.jl") @@ -25,6 +36,12 @@ include("update_grammar.jl") select_partial_solution(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache}) = HerbSearch.selectpsol_largest_subset(partial_sols, all_selected_psols) update_grammar!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample}) = update_grammar(grammar, PSols_with_eval_cache, examples) +get_prog_eval(::ProgramIterator, prog::RuleNode) = (prog, []) + +get_prog_eval(::GuidedSearchIterator, prog::Tuple{RuleNode,Vector{Any}}) = prog + +get_prog_eval(::GuidedTraceSearchIterator, prog::Tuple{RuleNode,Tuple{Any,Bool,Number}}) = prog + """ probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_time::Int, iteration_size::Int) @@ -49,15 +66,23 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_tim program, state = next # evaluate program - eval_observation = [] + program, eval_observation = get_prog_eval(iterator, program) correct_examples = Vector{Int}() - expr = rulenode2expr(program, grammar) - for (example_index, example) ∈ enumerate(examples) - output = execute_on_input(symboltable, expr, example.in) - push!(eval_observation, output) - - if output == example.out - push!(correct_examples, example_index) + if isempty(eval_observation) + expr = rulenode2expr(program, grammar) + for (example_index, example) ∈ enumerate(examples) + output = execute_on_input(symboltable, expr, example.in) + push!(eval_observation, output) + + if output == example.out + push!(correct_examples, example_index) + end + end + else + for i in 1:length(eval_observation) + if eval_observation[i] == examples[i].out + push!(correct_examples, i) + end end end @@ -105,17 +130,26 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_tim end -evaluate_trace(program::RuleNode, grammar::ContextSensitiveGrammar) = error("Evaluate trace method should be overwritten") +evaluate_trace(program::RuleNode, grammar::ContextSensitiveGrammar) = error("Evaluate trace method should be overwritten") +# this is here just to be overwritten in getting_started_minerl.jl +set_env_position(x, y, z) = error("Set env position method should be overwritten") -mutable struct ProgramCacheTrace - program::RuleNode - reward::Float64 +function select_partial_solution(partial_sols::Vector{ProgramCacheTrace}, all_selected_psols::Set{ProgramCacheTrace}) + if isempty(partial_sols) + return Vector{ProgramCache}() + end + push!(partial_sols, all_selected_psols...) + # sort partial solutions by reward + sort!(partial_sols, by=x -> x.reward, rev=true) + to_select = 5 + return partial_sols[1:min(to_select, length(partial_sols))] end + """ Probe for a solution using the given `iterator` and `examples` with a time limit of `max_time` and `iteration_size`. """ -function probe(traces::Vector{Trace}, iterator::ProgramIterator, max_time::Int, iteration_size::Int) where A +function probe(traces::Vector{Trace}, iterator::ProgramIterator, max_time::Int, iteration_size::Int) start_time = time() # store a set of all the results of evaluation programs eval_cache = Set() @@ -135,48 +169,60 @@ function probe(traces::Vector{Trace}, iterator::ProgramIterator, max_time::Int, while next !== nothing && i < iteration_size # run one iteration program, state = next - # evaluate - eval_observation, is_done, is_partial_sol, final_reward = evaluate_trace(program, grammar) - - if is_done + # evaluate + program, evaluation = get_prog_eval(iterator, program) + eval_observation, is_done, reward = isempty(evaluation) ? evaluate_trace(program, grammar, show_moves=true) : evaluation + eval_observation_rounded = round.(eval_observation, digits=1) + is_partial_sol = false + if reward > best_reward + 0.2 + best_reward = reward + printstyled("Best reward: $best_reward\n", color=:red) + is_partial_sol = true + end + if is_done @info "Last level: $(length(state.bank[state.level + 1])) programs" return program - elseif eval_observation in eval_cache # result already in cache + elseif eval_observation_rounded in eval_cache # result already in cache next = iterate(iterator, state) continue elseif is_partial_sol # partial solution - reward = calculate_program_cost(program, grammar) - push!(psol_with_eval_cache, ProgramCacheTrace(program, reward)) + cost = calculate_program_cost(program, grammar) + push!(psol_with_eval_cache, ProgramCacheTrace(program, cost, reward)) + # if length(psol_with_eval_cache) >= 2 # play with this threshold + # break + # end end - push!(eval_cache, eval_observation) + push!(eval_cache, eval_observation_rounded) - next = iterate(iterator, state) i += 1 + if i < iteration_size + next = iterate(iterator, state) + end end # check if program iterator is exhausted if next === nothing return nothing end - # TODO: Implement select_partial_solution and update for traces - partial_sols = filter(x -> x ∉ all_selected_psols, psol_with_eval_cache) - - # if !isempty(partial_sols) - # push!(all_selected_psols, partial_sols...) - # # update probabilites if any promising partial solutions - # update_grammar!(grammar, partial_sols, examples) # update probabilites - # # restart iterator - # eval_cache = Set() - # state = nothing - - # #for loop to update all_selected_psols with new costs - # for prog_with_cache ∈ all_selected_psols - # program = prog_with_cache.program - # new_cost = calculate_program_cost(program, grammar) - # # prog_with_cache.cost = new_cost - # end - # end + + partial_sols = select_partial_solution(psol_with_eval_cache, all_selected_psols) + if !isempty(partial_sols) + printstyled("Restarting!\n", color=:magenta) + + update_grammar!(grammar, partial_sols) # update probabilites + + # restart iterator + eval_cache = Set() + state = nothing + + #for loop to update all_selected_psols with new costs + # for prog_with_cache ∈ all_selected_psols + # program = prog_with_cache.program + # new_cost = calculate_program_cost(program, grammar) + # prog_with_cache.cost = new_cost + # end + end end return nothing diff --git a/src/probe/update_grammar.jl b/src/probe/update_grammar.jl index 1943f55..18ed953 100644 --- a/src/probe/update_grammar.jl +++ b/src/probe/update_grammar.jl @@ -38,6 +38,38 @@ function update_grammar(grammar::ContextSensitiveGrammar, PSols_with_eval_cache: @assert abs(total_sum - 1) <= 1e-4 "Total sum is $(total_sum) " end +function update_grammar!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCacheTrace}) + sum = 0 + for rule_index in eachindex(grammar.rules) + best_reward = 0 + for psol in PSols_with_eval_cache + program = psol.program.children[end] + reward = psol.reward + # check if the program tree has rule_index somewhere inside it using a recursive function + if contains_rule(program, rule_index) && reward > best_reward + best_reward = reward + end + end + # fitness higher is better + # TODO: think about better thing here + fitness = min(best_reward / 100, 1) + + p_current = 2^(grammar.log_probabilities[rule_index]) + + sum += p_current^(1 - fitness) + log_prob = ((1 - fitness) * log(2, p_current)) + grammar.log_probabilities[rule_index] = log_prob + end + total_sum = 0 + for rule_index in eachindex(grammar.rules) + grammar.log_probabilities[rule_index] = grammar.log_probabilities[rule_index] - log(2, sum) + total_sum += 2^(grammar.log_probabilities[rule_index]) + end + expr = rulenode2expr(PSols_with_eval_cache[begin].program, grammar) + grammar.rules[1] = :([$expr; ACT]) + @assert abs(total_sum - 1) <= 1e-4 "Total sum is $(total_sum) " +end + """ contains_rule(program::RuleNode, rule_index::Int) diff --git a/test/test_probe.jl b/test/test_probe.jl index 15bb953..10124a5 100644 --- a/test/test_probe.jl +++ b/test/test_probe.jl @@ -191,10 +191,10 @@ end @testset "Multiple nonterminals" begin grammar = @pcsgrammar begin - 1 : A = 1 - 1 : A = A - B - 1 : B = 2 - 1 : C = A + B + 1:A = 1 + 1:A = A - B + 1:B = 2 + 1:C = A + B end examples = [ @@ -209,7 +209,7 @@ end next = iterate(iter) while next !== nothing prog, state = next - push!(progs, prog) + push!(progs, prog[1]) if (state.level > 1) break end