Skip to content

Commit

Permalink
Merge pull request #106 from Herb-AI/probe-with-minerl-nick
Browse files Browse the repository at this point in the history
Make the search start from the best reward position
  • Loading branch information
eErr0Re authored May 24, 2024
2 parents bc0079a + 1f0c97e commit 677c030
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 185 deletions.
62 changes: 58 additions & 4 deletions src/minecraft/create_minerl_env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,68 @@ using PyCall
pyimport("minerl")

gym = pyimport("gym")
# our_seed = 95812 <- hard env
# our_seed = 958122
our_seed = 958129 # initial env

"""
resetEnv()
This function resets the enviornment with the global provided seed and saves the initial X,Y,Z positions.
"""
function resetEnv()
env.seed(our_seed)
obs = env.reset()
global x_player_start, y_player_start, z_player_start = get_xyz_from_env(obs)
print(x_player_start, y_player_start, z_player_start)

action = env.action_space.noop()
action["forward"] = 1

# infinite health
env.set_next_chat_message("/effect @a minecraft:instant_health 1000000 100 true")
env.step(action)

# infinite food
env.set_next_chat_message("/effect @a minecraft:saturation 1000000 255 true")
env.step(action)

# disable mobs
env.set_next_chat_message("/gamerule doMobSpawning false")
env.step(action)
env.set_next_chat_message("/kill @e[type=!player]")
env.step(action)

printstyled("Environment created. x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green)
end

function set_start_xyz(x, y, z)
global x_player_start = x
global y_player_start = y
global z_player_start = z
println("New x: $x_player_start y: $y_player_start z: $z_player_start pos")
end
function get_xyz_from_env(obs)
return obs["xpos"][1], obs["ypos"][1], obs["zpos"][1]
end

function resetPosition()
action = env.action_space.noop()
env.set_next_chat_message("/tp @a $(x_player_start) $(y_player_start) $(z_player_start)")

obs = env.step(action)[1]
obsx, obsy, obsz = get_xyz_from_env(obs)
while obsx != x_player_start || obsy != y_player_start || obsz != z_player_start
obs = env.step(action)[1]
obsx, obsy, obsz = get_xyz_from_env(obs)
end
println((obsx, obsy, obsz))
end

if !@isdefined env
printstyled("Creating environment\n", color=:yellow)
env = gym.make("MineRLNavigateDenseProgSynth-v0")
env.seed(958129)
obs = env.reset()
x_player_start, y_player_start, z_player_start = obs["xpos"], obs["ypos"], obs["zpos"]
printstyled("Environment created x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green)
resetEnv()
printstyled("Environment created. x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green)
else
printstyled("Environment already created\n", color=:green)
end
166 changes: 49 additions & 117 deletions src/minecraft/getting_started_minerl.jl
Original file line number Diff line number Diff line change
@@ -1,150 +1,82 @@
include("create_minerl_env.jl")
using HerbGrammar, HerbSpecification
using HerbSearch

os = pyimport("os")
os.environ["LANG"] = "en_US.UTF-8"
# you might get an error here. Just type using HerbSearch in the REPL and run the whole file again.

# minerl_grammar = @pcsgrammar begin
# 10:action_name = "forward" | "left" | "right"
# 1:action_name = "jump"
# 1:action = Dict("camera" => [0, 0], action_name => 1)
# 1:sequence_actions = [sequence_actions; action]
# 1:sequence_actions = []
# end

minerl_grammar_2 = @pcsgrammar begin
100:F = 1
1:F = 0
1:L = 1
100:L = 0
1:R = 0
100:R = 1
100:J = 0
1:J = 1
100:B = 1
1:B = 0
10:sequence_actions = [sequence_actions; action]
using Logging
disable_logging(LogLevel(1))

minerl_grammar = @pcsgrammar begin
1:action_name = "forward"
1:action_name = "left"
1:action_name = "right"
1:action_name = "back"
1:action_name = "jump"
1:sequence_actions = [sequence_actions; action]
1:sequence_actions = []
10:action= (TIMES, Dict("camera" => [0, 0], "forward" => F, "left" => L, "right" => R, "jump" => J, "back" => B))
1:TIMES = 25 | 50 | 75 | 100 | 125 | 150
1:action = (TIMES, Dict("camera" => [0, 0], action_name => 1))
5:TIMES = 1 | 5 | 25 | 50 | 75 | 100
end

function create_random_program()
return rand(RuleNode, minerl_grammar, :sequence_actions)
minerl_grammar_2 = @pcsgrammar begin
1:SEQ = [ACT]
8:DIR = 0b0001 | 0b0010 | 0b0100 | 0b1000 | 0b0101 | 0b1001 | 0b0110 | 0b1010 # forward | back | left | right | forward-left | forward-right | back-left | back-right
1:ACT = (TIMES, Dict("move" => DIR, "sprint" => 1, "jump" => 1))
6:TIMES = 5 | 10 | 25 | 50 | 75 | 100
end

iterations = 0
best_reward = 0
function evaluate_trace_minerl(prog, grammar, env)
function evaluate_trace_minerl(prog, grammar, env, show_moves)
resetPosition()
expr = rulenode2expr(prog, grammar)
# println("expr is: ", expr)


is_done = false
is_partial_sol = false

sequence_of_actions = eval(expr)
if isempty(sequence_of_actions)
return (0, 0, 0), false, false, 0
end

sum_of_rewards = 0
is_done = false
obs = nothing
for saved_action sequence_of_actions
times, action = saved_action

for (times, action) sequence_of_actions
new_action = env.action_space.noop()
for key keys(action)
new_action[key] = action[key]
for (key, val) in action
if key == "move"
new_action["forward"] = val & 1
new_action["back"] = val >> 1 & 1
new_action["left"] = val >> 2 & 1
new_action["right"] = val >> 3
else
new_action[key] = val
end
end

for i in 1:times
obs, reward, done, _ = env.step(new_action)
env.render()
if show_moves
env.render()
end

sum_of_rewards += reward

if reward > 0
is_partial_sol = true
end
if done
println("Rewards ", sum_of_rewards)
is_done = true
printstyled("done\n", color=:green)
printstyled("sum of rewards: $sum_of_rewards. Done\n", color=:green)
break
end
end
end
println("Got reward: ", sum_of_rewards)
global best_reward = max(best_reward, sum_of_rewards)
printstyled("Best reward: $best_reward\n",color=:red)

eval_observation = (round(obs["xpos"][1], digits=1), round(obs["ypos"][1], digits=1), round(obs["zpos"][1], digits=1))
println(eval_observation)
return eval_observation, is_done, is_partial_sol, sum_of_rewards
end

function run_action(action)
obs, reward, done, _ = env.step(action)
env.render()
println("reward: $reward")
end

function resetEnv()
obs = env.reset()
x_player_start, y_player_start, z_player_start = obs["xpos"], obs["ypos"], obs["zpos"]
printstyled("Environment reset x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green)
end


function resetPosition()
action = env.action_space.noop()
env.set_next_chat_message("/tp @a $(x_player_start[1]) $(y_player_start[1]) $(z_player_start[1])")

obs = env.step(action)[1]
while obs["xpos"] != x_player_start || obs["ypos"] != y_player_start || obs["zpos"] != z_player_start
obs = env.step(action)[1]
end

env.render()
end

function run_forward_and_random()
while true
if rand() < 0.8
new_action = env.action_space.noop()
new_action["forward"] = 1
# new_action["jump"] = 1
new_action["spint"] = 1
run_action(new_action)
else
prog = create_random_program()
evaluate_trace_minerl(prog, minerl_grammar_2, env)
if is_done
break
end
end

println("Reward $sum_of_rewards")
return get_xyz_from_env(obs), is_done, sum_of_rewards
end

# run_forward_and_random()
# make sure the probabilities are equal
@assert all(prob -> prob == minerl_grammar_2.log_probabilities[begin], minerl_grammar_2.log_probabilities)

# # # overwrite the evaluate trace function
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar) = evaluate_trace_minerl(prog, grammar, env)
function HerbSearch.set_env_position(x, y, z)
println("Setting env position: ($x, $y, $z)")
set_start_xyz(x, y, z)
end
# overwrite the evaluate trace function
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=false) = evaluate_trace_minerl(prog, grammar, env, show_moves)
HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar)
iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :sequence_actions)
program = probe(Vector{Trace}(), iter, 400000, 100000)


# state = nothing
# next = iterate(iter)
# print(next)
# while next !== nothing
# prog, state = next
# if (state.level > 100)
# break
# end
# global next = iterate(iter, state)
# end

# resetEnv()
iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :SEQ)
program = @time probe(Vector{Trace}(), iter, 3000000, 6)
11 changes: 6 additions & 5 deletions src/probe/guided_search_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Base.@kwdef mutable struct GuidedSearchState
bank::Vector{Vector{RuleNode}}
eval_cache::Set
iter::NewProgramsIterator
next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing}
next_iter::Union{Tuple{RuleNode,NewProgramsState},Nothing}
end

function Base.iterate(iter::GuidedSearchIterator)
Expand All @@ -21,7 +21,7 @@ function Base.iterate(iter::GuidedSearchIterator)
))
end

function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing}
function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{Tuple{RuleNode,Vector{Any}},GuidedSearchState},Nothing}
grammar = get_grammar(iter.solver)
start_symbol = get_starting_symbol(iter.solver)
# wrap in while true to optimize for tail call
Expand All @@ -30,7 +30,7 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni
state.level += 1
push!(state.bank, [])

state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.next_iter = iterate(state.iter)
if state.level > 0
@info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs")
Expand All @@ -56,10 +56,11 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni
if eval_observation in state.eval_cache # program already cached
continue
end

push!(state.eval_cache, eval_observation) # add result to cache
push!(state.bank[state.level+1], prog) # add program to bank
return (prog, state) # return program

return ((prog, eval_observation), state) # return program
end

push!(state.bank[state.level+1], prog) # add program to bank
Expand Down
24 changes: 10 additions & 14 deletions src/probe/guided_trace_search_iterator.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@

@programiterator GuidedTraceSearchIterator()
Base.@kwdef mutable struct GuidedSearchState
level::Int64
bank::Vector{Vector{RuleNode}}
eval_cache::Set
iter::NewProgramsIterator
next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing}
end

function Base.iterate(iter::GuidedTraceSearchIterator)
iterate(iter, GuidedSearchState(
Expand All @@ -18,7 +11,7 @@ function Base.iterate(iter::GuidedTraceSearchIterator)
))
end

function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing}
function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)
grammar = get_grammar(iter.solver)
start_symbol = get_starting_symbol(iter.solver)
# wrap in while true to optimize for tail call
Expand All @@ -27,7 +20,7 @@ function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)
state.level += 1
push!(state.bank, [])

state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.next_iter = iterate(state.iter)
if state.level > 0
@info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs")
Expand All @@ -43,15 +36,18 @@ function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)

# evaluate program if starting symbol
if return_type(grammar, prog.ind) == start_symbol
eval_observation, is_done, is_partial_sol, final_reward = evaluate_trace(prog, grammar)
if eval_observation in state.eval_cache # program already cached
eval_observation, is_done, final_reward = evaluate_trace(prog, grammar)
eval_observation_rounded = round.(eval_observation, digits=1)
if eval_observation_rounded in state.eval_cache # program already cached
# print("Skipping this.")
@info "Skipping program"
continue
end
push!(state.eval_cache, eval_observation) # add result to cache

push!(state.eval_cache, eval_observation_rounded) # add result to cache
push!(state.bank[state.level+1], prog) # add program to bank
return (prog, state) # return program

return ((prog, (eval_observation, is_done, final_reward)), state) # return program
end

push!(state.bank[state.level+1], prog) # add program to bank
Expand Down
Loading

0 comments on commit 677c030

Please sign in to comment.