Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the search start from the best reward position #106

Merged
merged 6 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 58 additions & 4 deletions src/minecraft/create_minerl_env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,68 @@
pyimport("minerl")

gym = pyimport("gym")
# our_seed = 95812 <- hard env
# our_seed = 958122
our_seed = 958129 # initial env

"""
resetEnv()
This function resets the enviornment with the global provided seed and saves the initial X,Y,Z positions.
"""
function resetEnv()
env.seed(our_seed)
obs = env.reset()
global x_player_start, y_player_start, z_player_start = get_xyz_from_env(obs)
print(x_player_start, y_player_start, z_player_start)

Check warning on line 17 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L13-L17

Added lines #L13 - L17 were not covered by tests

action = env.action_space.noop()
action["forward"] = 1

Check warning on line 20 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L19-L20

Added lines #L19 - L20 were not covered by tests

# infinite health
env.set_next_chat_message("/effect @a minecraft:instant_health 1000000 100 true")
env.step(action)

Check warning on line 24 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L23-L24

Added lines #L23 - L24 were not covered by tests

# infinite food
env.set_next_chat_message("/effect @a minecraft:saturation 1000000 255 true")
env.step(action)

Check warning on line 28 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L27-L28

Added lines #L27 - L28 were not covered by tests

# disable mobs
env.set_next_chat_message("/gamerule doMobSpawning false")
env.step(action)
env.set_next_chat_message("/kill @e[type=!player]")
env.step(action)

Check warning on line 34 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L31-L34

Added lines #L31 - L34 were not covered by tests

printstyled("Environment created. x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green)

Check warning on line 36 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L36

Added line #L36 was not covered by tests
end

function set_start_xyz(x, y, z)
global x_player_start = x
global y_player_start = y
global z_player_start = z
println("New x: $x_player_start y: $y_player_start z: $z_player_start pos")

Check warning on line 43 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L39-L43

Added lines #L39 - L43 were not covered by tests
end
function get_xyz_from_env(obs)
return obs["xpos"][1], obs["ypos"][1], obs["zpos"][1]

Check warning on line 46 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L45-L46

Added lines #L45 - L46 were not covered by tests
end

function resetPosition()
action = env.action_space.noop()
env.set_next_chat_message("/tp @a $(x_player_start) $(y_player_start) $(z_player_start)")

Check warning on line 51 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L49-L51

Added lines #L49 - L51 were not covered by tests

obs = env.step(action)[1]
obsx, obsy, obsz = get_xyz_from_env(obs)
while obsx != x_player_start || obsy != y_player_start || obsz != z_player_start
obs = env.step(action)[1]
obsx, obsy, obsz = get_xyz_from_env(obs)
end
println((obsx, obsy, obsz))

Check warning on line 59 in src/minecraft/create_minerl_env.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/create_minerl_env.jl#L53-L59

Added lines #L53 - L59 were not covered by tests
end

if !@isdefined env
printstyled("Creating environment\n", color=:yellow)
env = gym.make("MineRLNavigateDenseProgSynth-v0")
env.seed(958129)
obs = env.reset()
x_player_start, y_player_start, z_player_start = obs["xpos"], obs["ypos"], obs["zpos"]
printstyled("Environment created x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green)
resetEnv()
printstyled("Environment created. x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green)
else
printstyled("Environment already created\n", color=:green)
end
166 changes: 49 additions & 117 deletions src/minecraft/getting_started_minerl.jl
Original file line number Diff line number Diff line change
@@ -1,150 +1,82 @@
include("create_minerl_env.jl")
using HerbGrammar, HerbSpecification
using HerbSearch

os = pyimport("os")
os.environ["LANG"] = "en_US.UTF-8"
# you might get an error here. Just type using HerbSearch in the REPL and run the whole file again.

# minerl_grammar = @pcsgrammar begin
# 10:action_name = "forward" | "left" | "right"
# 1:action_name = "jump"
# 1:action = Dict("camera" => [0, 0], action_name => 1)
# 1:sequence_actions = [sequence_actions; action]
# 1:sequence_actions = []
# end

minerl_grammar_2 = @pcsgrammar begin
100:F = 1
1:F = 0
1:L = 1
100:L = 0
1:R = 0
100:R = 1
100:J = 0
1:J = 1
100:B = 1
1:B = 0
10:sequence_actions = [sequence_actions; action]
using Logging
disable_logging(LogLevel(1))

minerl_grammar = @pcsgrammar begin
1:action_name = "forward"
1:action_name = "left"
1:action_name = "right"
1:action_name = "back"
1:action_name = "jump"
1:sequence_actions = [sequence_actions; action]

Check warning on line 13 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L8-L13

Added lines #L8 - L13 were not covered by tests
1:sequence_actions = []
10:action= (TIMES, Dict("camera" => [0, 0], "forward" => F, "left" => L, "right" => R, "jump" => J, "back" => B))
1:TIMES = 25 | 50 | 75 | 100 | 125 | 150
1:action = (TIMES, Dict("camera" => [0, 0], action_name => 1))
5:TIMES = 1 | 5 | 25 | 50 | 75 | 100

Check warning on line 16 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L15-L16

Added lines #L15 - L16 were not covered by tests
end

function create_random_program()
return rand(RuleNode, minerl_grammar, :sequence_actions)
minerl_grammar_2 = @pcsgrammar begin
1:SEQ = [ACT]
8:DIR = 0b0001 | 0b0010 | 0b0100 | 0b1000 | 0b0101 | 0b1001 | 0b0110 | 0b1010 # forward | back | left | right | forward-left | forward-right | back-left | back-right
1:ACT = (TIMES, Dict("move" => DIR, "sprint" => 1, "jump" => 1))
6:TIMES = 5 | 10 | 25 | 50 | 75 | 100

Check warning on line 23 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L20-L23

Added lines #L20 - L23 were not covered by tests
end

iterations = 0
best_reward = 0
function evaluate_trace_minerl(prog, grammar, env)
function evaluate_trace_minerl(prog, grammar, env, show_moves)

Check warning on line 26 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L26

Added line #L26 was not covered by tests
resetPosition()
expr = rulenode2expr(prog, grammar)
# println("expr is: ", expr)


is_done = false
is_partial_sol = false

sequence_of_actions = eval(expr)
if isempty(sequence_of_actions)
return (0, 0, 0), false, false, 0
end

sum_of_rewards = 0
is_done = false

Check warning on line 33 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L33

Added line #L33 was not covered by tests
obs = nothing
for saved_action ∈ sequence_of_actions
times, action = saved_action

for (times, action) ∈ sequence_of_actions

Check warning on line 35 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L35

Added line #L35 was not covered by tests
new_action = env.action_space.noop()
for key ∈ keys(action)
new_action[key] = action[key]
for (key, val) in action
if key == "move"
new_action["forward"] = val & 1
new_action["back"] = val >> 1 & 1
new_action["left"] = val >> 2 & 1
new_action["right"] = val >> 3

Check warning on line 42 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L37-L42

Added lines #L37 - L42 were not covered by tests
else
new_action[key] = val

Check warning on line 44 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L44

Added line #L44 was not covered by tests
end
end

for i in 1:times
obs, reward, done, _ = env.step(new_action)
env.render()
if show_moves
env.render()

Check warning on line 51 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L50-L51

Added lines #L50 - L51 were not covered by tests
end

sum_of_rewards += reward

if reward > 0
is_partial_sol = true
end
if done
println("Rewards ", sum_of_rewards)
is_done = true
printstyled("done\n", color=:green)
printstyled("sum of rewards: $sum_of_rewards. Done\n", color=:green)

Check warning on line 57 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L57

Added line #L57 was not covered by tests
break
end
end
end
println("Got reward: ", sum_of_rewards)
global best_reward = max(best_reward, sum_of_rewards)
printstyled("Best reward: $best_reward\n",color=:red)

eval_observation = (round(obs["xpos"][1], digits=1), round(obs["ypos"][1], digits=1), round(obs["zpos"][1], digits=1))
println(eval_observation)
return eval_observation, is_done, is_partial_sol, sum_of_rewards
end

function run_action(action)
obs, reward, done, _ = env.step(action)
env.render()
println("reward: $reward")
end

function resetEnv()
obs = env.reset()
x_player_start, y_player_start, z_player_start = obs["xpos"], obs["ypos"], obs["zpos"]
printstyled("Environment reset x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green)
end


function resetPosition()
action = env.action_space.noop()
env.set_next_chat_message("/tp @a $(x_player_start[1]) $(y_player_start[1]) $(z_player_start[1])")

obs = env.step(action)[1]
while obs["xpos"] != x_player_start || obs["ypos"] != y_player_start || obs["zpos"] != z_player_start
obs = env.step(action)[1]
end

env.render()
end

function run_forward_and_random()
while true
if rand() < 0.8
new_action = env.action_space.noop()
new_action["forward"] = 1
# new_action["jump"] = 1
new_action["spint"] = 1
run_action(new_action)
else
prog = create_random_program()
evaluate_trace_minerl(prog, minerl_grammar_2, env)
if is_done
break

Check warning on line 62 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L61-L62

Added lines #L61 - L62 were not covered by tests
end
end

println("Reward $sum_of_rewards")
return get_xyz_from_env(obs), is_done, sum_of_rewards

Check warning on line 66 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L65-L66

Added lines #L65 - L66 were not covered by tests
end

# run_forward_and_random()
# make sure the probabilities are equal
@assert all(prob -> prob == minerl_grammar_2.log_probabilities[begin], minerl_grammar_2.log_probabilities)

Check warning on line 70 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L70

Added line #L70 was not covered by tests

# # # overwrite the evaluate trace function
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar) = evaluate_trace_minerl(prog, grammar, env)
function HerbSearch.set_env_position(x, y, z)
println("Setting env position: ($x, $y, $z)")
set_start_xyz(x, y, z)

Check warning on line 74 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L72-L74

Added lines #L72 - L74 were not covered by tests
end
# overwrite the evaluate trace function
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=false) = evaluate_trace_minerl(prog, grammar, env, show_moves)

Check warning on line 77 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L77

Added line #L77 was not covered by tests
HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar)
iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :sequence_actions)
program = probe(Vector{Trace}(), iter, 400000, 100000)


# state = nothing
# next = iterate(iter)
# print(next)
# while next !== nothing
# prog, state = next
# if (state.level > 100)
# break
# end
# global next = iterate(iter, state)
# end

# resetEnv()
iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :SEQ)
program = @time probe(Vector{Trace}(), iter, 3000000, 6)
11 changes: 6 additions & 5 deletions src/probe/guided_search_iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Base.@kwdef mutable struct GuidedSearchState
bank::Vector{Vector{RuleNode}}
eval_cache::Set
iter::NewProgramsIterator
next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing}
next_iter::Union{Tuple{RuleNode,NewProgramsState},Nothing}
end

function Base.iterate(iter::GuidedSearchIterator)
Expand All @@ -21,7 +21,7 @@ function Base.iterate(iter::GuidedSearchIterator)
))
end

function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing}
function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{Tuple{RuleNode,Vector{Any}},GuidedSearchState},Nothing}
grammar = get_grammar(iter.solver)
start_symbol = get_starting_symbol(iter.solver)
# wrap in while true to optimize for tail call
Expand All @@ -30,7 +30,7 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni
state.level += 1
push!(state.bank, [])

state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.next_iter = iterate(state.iter)
if state.level > 0
@info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs")
Expand All @@ -56,10 +56,11 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni
if eval_observation in state.eval_cache # program already cached
continue
end

push!(state.eval_cache, eval_observation) # add result to cache
push!(state.bank[state.level+1], prog) # add program to bank
return (prog, state) # return program

return ((prog, eval_observation), state) # return program
end

push!(state.bank[state.level+1], prog) # add program to bank
Expand Down
24 changes: 10 additions & 14 deletions src/probe/guided_trace_search_iterator.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@

@programiterator GuidedTraceSearchIterator()
Base.@kwdef mutable struct GuidedSearchState
level::Int64
bank::Vector{Vector{RuleNode}}
eval_cache::Set
iter::NewProgramsIterator
next_iter::Union{Tuple{RuleNode, NewProgramsState}, Nothing}
end

function Base.iterate(iter::GuidedTraceSearchIterator)
iterate(iter, GuidedSearchState(
Expand All @@ -18,7 +11,7 @@
))
end

function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing}
function Base.iterate(iter::GuidedTraceSearchIterator, state::GuidedSearchState)

Check warning on line 14 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L14

Added line #L14 was not covered by tests
grammar = get_grammar(iter.solver)
start_symbol = get_starting_symbol(iter.solver)
# wrap in while true to optimize for tail call
Expand All @@ -27,7 +20,7 @@
state.level += 1
push!(state.bank, [])

state.iter = NewProgramsIterator(state.level, state.bank, grammar)
state.iter = NewProgramsIterator(state.level, state.bank, grammar)

Check warning on line 23 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L23

Added line #L23 was not covered by tests
state.next_iter = iterate(state.iter)
if state.level > 0
@info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs")
Expand All @@ -43,15 +36,18 @@

# evaluate program if starting symbol
if return_type(grammar, prog.ind) == start_symbol
eval_observation, is_done, is_partial_sol, final_reward = evaluate_trace(prog, grammar)
if eval_observation in state.eval_cache # program already cached
eval_observation, is_done, final_reward = evaluate_trace(prog, grammar)
eval_observation_rounded = round.(eval_observation, digits=1)
if eval_observation_rounded in state.eval_cache # program already cached

Check warning on line 41 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L39-L41

Added lines #L39 - L41 were not covered by tests
# print("Skipping this.")
@info "Skipping program"

Check warning on line 43 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L43

Added line #L43 was not covered by tests
continue
end
push!(state.eval_cache, eval_observation) # add result to cache

push!(state.eval_cache, eval_observation_rounded) # add result to cache

Check warning on line 47 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L47

Added line #L47 was not covered by tests
push!(state.bank[state.level+1], prog) # add program to bank
return (prog, state) # return program

return ((prog, (eval_observation, is_done, final_reward)), state) # return program

Check warning on line 50 in src/probe/guided_trace_search_iterator.jl

View check run for this annotation

Codecov / codecov/patch

src/probe/guided_trace_search_iterator.jl#L50

Added line #L50 was not covered by tests
end

push!(state.bank[state.level+1], prog) # add program to bank
Expand Down
Loading
Loading