-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #106 from Herb-AI/probe-with-minerl-nick
Make the search start from the best reward position
- Loading branch information
Showing
7 changed files
with
246 additions
and
185 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,150 +1,82 @@ | ||
include("create_minerl_env.jl") | ||
using HerbGrammar, HerbSpecification | ||
using HerbSearch | ||
|
||
os = pyimport("os") | ||
os.environ["LANG"] = "en_US.UTF-8" | ||
# you might get an error here. Just type using HerbSearch in the REPL and run the whole file again. | ||
|
||
# minerl_grammar = @pcsgrammar begin | ||
# 10:action_name = "forward" | "left" | "right" | ||
# 1:action_name = "jump" | ||
# 1:action = Dict("camera" => [0, 0], action_name => 1) | ||
# 1:sequence_actions = [sequence_actions; action] | ||
# 1:sequence_actions = [] | ||
# end | ||
|
||
minerl_grammar_2 = @pcsgrammar begin | ||
100:F = 1 | ||
1:F = 0 | ||
1:L = 1 | ||
100:L = 0 | ||
1:R = 0 | ||
100:R = 1 | ||
100:J = 0 | ||
1:J = 1 | ||
100:B = 1 | ||
1:B = 0 | ||
10:sequence_actions = [sequence_actions; action] | ||
using Logging | ||
disable_logging(LogLevel(1)) | ||
|
||
minerl_grammar = @pcsgrammar begin | ||
1:action_name = "forward" | ||
1:action_name = "left" | ||
1:action_name = "right" | ||
1:action_name = "back" | ||
1:action_name = "jump" | ||
1:sequence_actions = [sequence_actions; action] | ||
1:sequence_actions = [] | ||
10:action= (TIMES, Dict("camera" => [0, 0], "forward" => F, "left" => L, "right" => R, "jump" => J, "back" => B)) | ||
1:TIMES = 25 | 50 | 75 | 100 | 125 | 150 | ||
1:action = (TIMES, Dict("camera" => [0, 0], action_name => 1)) | ||
5:TIMES = 1 | 5 | 25 | 50 | 75 | 100 | ||
end | ||
|
||
function create_random_program() | ||
return rand(RuleNode, minerl_grammar, :sequence_actions) | ||
minerl_grammar_2 = @pcsgrammar begin | ||
1:SEQ = [ACT] | ||
8:DIR = 0b0001 | 0b0010 | 0b0100 | 0b1000 | 0b0101 | 0b1001 | 0b0110 | 0b1010 # forward | back | left | right | forward-left | forward-right | back-left | back-right | ||
1:ACT = (TIMES, Dict("move" => DIR, "sprint" => 1, "jump" => 1)) | ||
6:TIMES = 5 | 10 | 25 | 50 | 75 | 100 | ||
end | ||
|
||
iterations = 0 | ||
best_reward = 0 | ||
function evaluate_trace_minerl(prog, grammar, env) | ||
function evaluate_trace_minerl(prog, grammar, env, show_moves) | ||
resetPosition() | ||
expr = rulenode2expr(prog, grammar) | ||
# println("expr is: ", expr) | ||
|
||
|
||
is_done = false | ||
is_partial_sol = false | ||
|
||
sequence_of_actions = eval(expr) | ||
if isempty(sequence_of_actions) | ||
return (0, 0, 0), false, false, 0 | ||
end | ||
|
||
sum_of_rewards = 0 | ||
is_done = false | ||
obs = nothing | ||
for saved_action ∈ sequence_of_actions | ||
times, action = saved_action | ||
|
||
for (times, action) ∈ sequence_of_actions | ||
new_action = env.action_space.noop() | ||
for key ∈ keys(action) | ||
new_action[key] = action[key] | ||
for (key, val) in action | ||
if key == "move" | ||
new_action["forward"] = val & 1 | ||
new_action["back"] = val >> 1 & 1 | ||
new_action["left"] = val >> 2 & 1 | ||
new_action["right"] = val >> 3 | ||
else | ||
new_action[key] = val | ||
end | ||
end | ||
|
||
for i in 1:times | ||
obs, reward, done, _ = env.step(new_action) | ||
env.render() | ||
if show_moves | ||
env.render() | ||
end | ||
|
||
sum_of_rewards += reward | ||
|
||
if reward > 0 | ||
is_partial_sol = true | ||
end | ||
if done | ||
println("Rewards ", sum_of_rewards) | ||
is_done = true | ||
printstyled("done\n", color=:green) | ||
printstyled("sum of rewards: $sum_of_rewards. Done\n", color=:green) | ||
break | ||
end | ||
end | ||
end | ||
println("Got reward: ", sum_of_rewards) | ||
global best_reward = max(best_reward, sum_of_rewards) | ||
printstyled("Best reward: $best_reward\n",color=:red) | ||
|
||
eval_observation = (round(obs["xpos"][1], digits=1), round(obs["ypos"][1], digits=1), round(obs["zpos"][1], digits=1)) | ||
println(eval_observation) | ||
return eval_observation, is_done, is_partial_sol, sum_of_rewards | ||
end | ||
|
||
function run_action(action) | ||
obs, reward, done, _ = env.step(action) | ||
env.render() | ||
println("reward: $reward") | ||
end | ||
|
||
function resetEnv() | ||
obs = env.reset() | ||
x_player_start, y_player_start, z_player_start = obs["xpos"], obs["ypos"], obs["zpos"] | ||
printstyled("Environment reset x: $x_player_start, y: $y_player_start, z: $z_player_start\n", color=:green) | ||
end | ||
|
||
|
||
function resetPosition() | ||
action = env.action_space.noop() | ||
env.set_next_chat_message("/tp @a $(x_player_start[1]) $(y_player_start[1]) $(z_player_start[1])") | ||
|
||
obs = env.step(action)[1] | ||
while obs["xpos"] != x_player_start || obs["ypos"] != y_player_start || obs["zpos"] != z_player_start | ||
obs = env.step(action)[1] | ||
end | ||
|
||
env.render() | ||
end | ||
|
||
function run_forward_and_random() | ||
while true | ||
if rand() < 0.8 | ||
new_action = env.action_space.noop() | ||
new_action["forward"] = 1 | ||
# new_action["jump"] = 1 | ||
new_action["spint"] = 1 | ||
run_action(new_action) | ||
else | ||
prog = create_random_program() | ||
evaluate_trace_minerl(prog, minerl_grammar_2, env) | ||
if is_done | ||
break | ||
end | ||
end | ||
|
||
println("Reward $sum_of_rewards") | ||
return get_xyz_from_env(obs), is_done, sum_of_rewards | ||
end | ||
|
||
# run_forward_and_random() | ||
# make sure the probabilities are equal | ||
@assert all(prob -> prob == minerl_grammar_2.log_probabilities[begin], minerl_grammar_2.log_probabilities) | ||
|
||
# # # overwrite the evaluate trace function | ||
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar) = evaluate_trace_minerl(prog, grammar, env) | ||
function HerbSearch.set_env_position(x, y, z) | ||
println("Setting env position: ($x, $y, $z)") | ||
set_start_xyz(x, y, z) | ||
end | ||
# overwrite the evaluate trace function | ||
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=false) = evaluate_trace_minerl(prog, grammar, env, show_moves) | ||
HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar) | ||
iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :sequence_actions) | ||
program = probe(Vector{Trace}(), iter, 400000, 100000) | ||
|
||
|
||
# state = nothing | ||
# next = iterate(iter) | ||
# print(next) | ||
# while next !== nothing | ||
# prog, state = next | ||
# if (state.level > 100) | ||
# break | ||
# end | ||
# global next = iterate(iter, state) | ||
# end | ||
|
||
# resetEnv() | ||
iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :SEQ) | ||
program = @time probe(Vector{Trace}(), iter, 3000000, 6) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.