-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #108 from Herb-AI/refactor-minerl
Refactor MineRL/Probe
- Loading branch information
Showing
14 changed files
with
312 additions
and
286 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,82 +1,29 @@ | ||
include("create_minerl_env.jl") | ||
using HerbGrammar, HerbSpecification | ||
using HerbSearch | ||
include("minerl.jl") | ||
include("logo_print.jl") | ||
|
||
using HerbGrammar, HerbSpecification, HerbSearch | ||
using Logging | ||
disable_logging(LogLevel(1)) | ||
|
||
minerl_grammar = @pcsgrammar begin | ||
1:action_name = "forward" | ||
1:action_name = "left" | ||
1:action_name = "right" | ||
1:action_name = "back" | ||
1:action_name = "jump" | ||
1:sequence_actions = [sequence_actions; action] | ||
1:sequence_actions = [] | ||
1:action = (TIMES, Dict("camera" => [0, 0], action_name => 1)) | ||
5:TIMES = 1 | 5 | 25 | 50 | 75 | 100 | ||
end | ||
|
||
minerl_grammar_2 = @pcsgrammar begin | ||
1:SEQ = [ACT] | ||
8:DIR = 0b0001 | 0b0010 | 0b0100 | 0b1000 | 0b0101 | 0b1001 | 0b0110 | 0b1010 # forward | back | left | right | forward-left | forward-right | back-left | back-right | ||
1:ACT = (TIMES, Dict("move" => DIR, "sprint" => 1, "jump" => 1)) | ||
6:TIMES = 5 | 10 | 25 | 50 | 75 | 100 | ||
end | ||
|
||
function evaluate_trace_minerl(prog, grammar, env, show_moves) | ||
resetPosition() | ||
expr = rulenode2expr(prog, grammar) | ||
|
||
sequence_of_actions = eval(expr) | ||
|
||
sum_of_rewards = 0 | ||
is_done = false | ||
obs = nothing | ||
for (times, action) ∈ sequence_of_actions | ||
new_action = env.action_space.noop() | ||
for (key, val) in action | ||
if key == "move" | ||
new_action["forward"] = val & 1 | ||
new_action["back"] = val >> 1 & 1 | ||
new_action["left"] = val >> 2 & 1 | ||
new_action["right"] = val >> 3 | ||
else | ||
new_action[key] = val | ||
end | ||
end | ||
|
||
for i in 1:times | ||
obs, reward, done, _ = env.step(new_action) | ||
if show_moves | ||
env.render() | ||
end | ||
|
||
sum_of_rewards += reward | ||
if done | ||
is_done = true | ||
printstyled("sum of rewards: $sum_of_rewards. Done\n", color=:green) | ||
break | ||
end | ||
end | ||
if is_done | ||
break | ||
end | ||
end | ||
println("Reward $sum_of_rewards") | ||
return get_xyz_from_env(obs), is_done, sum_of_rewards | ||
end | ||
|
||
# make sure the probabilities are equal | ||
@assert all(prob -> prob == minerl_grammar_2.log_probabilities[begin], minerl_grammar_2.log_probabilities) | ||
@assert all(prob -> prob == minerl_grammar.log_probabilities[begin], minerl_grammar.log_probabilities) | ||
|
||
function HerbSearch.set_env_position(x, y, z) | ||
println("Setting env position: ($x, $y, $z)") | ||
set_start_xyz(x, y, z) | ||
end | ||
# overwrite the evaluate trace function | ||
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=false) = evaluate_trace_minerl(prog, grammar, env, show_moves) | ||
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=true) = evaluate_trace_minerl(prog, grammar, environment, show_moves) | ||
HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar) | ||
|
||
# resetEnv() | ||
iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :SEQ) | ||
program = @time probe(Vector{Trace}(), iter, 3000000, 6) | ||
SEED = 958129 | ||
if !(@isdefined environment) | ||
environment = create_env("MineRLNavigateDenseProgSynth-v0"; seed=SEED, inf_health=true, inf_food=true, disable_mobs=true) | ||
end | ||
print_logo() | ||
iter = HerbSearch.GuidedSearchTraceIterator(minerl_grammar, :SEQ) | ||
program = @time probe(Vector{Trace}(), iter, max_time=3000000, cycle_length=6) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
""" | ||
print_logo() | ||
Prints a stylized ascii art of the word probe. | ||
""" | ||
function print_logo() | ||
printstyled(raw"""______ _ __ ___ ____ ______ _ | ||
| ___ \ | | / / | \/ (_) | ___ \ | | ||
| |_/ / __ ___ | |__ ___ / / | . . |_ _ __ ___| |_/ / | | ||
| __/ '__/ _ \| '_ \ / _ \ / / | |\/| | | '_ \ / _ \ /| | | ||
| | | | | (_) | |_) | __/ / / | | | | | | | | __/ |\ \| |____ | ||
\_| |_| \___/|_.__/ \___| /_/ \_| |_/_|_| |_|\___\_| \_\_____/ | ||
""", color=:magenta, bold=true) | ||
println() | ||
println(repeat("=", 80) * "\n") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
using PyCall | ||
pyimport("minerl") | ||
|
||
gym = pyimport("gym") | ||
|
||
# WARNING: !!! NEVER MOVE THIS. It should ALWAYS be after the `pyimport`. I spent hours debugging this !!! | ||
using HerbGrammar | ||
mutable struct Environment | ||
env::PyObject | ||
settings::Dict{Symbol,Integer} | ||
start_pos::Tuple{Float64,Float64,Float64} | ||
end | ||
|
||
""" | ||
create_env(name::String; <keyword arguments>) | ||
Create environment. | ||
# Arguments | ||
- `seed::Int`: the world seed. | ||
- `inf_health::Bool`: enable infinite health. | ||
- `inf_food::Bool`: enable infinite food. | ||
- `disable_mobs`: disable mobs. | ||
""" | ||
function create_env(name::String; kwargs...) | ||
environment = Environment(gym.make(name), Dict(kwargs), (0, 0, 0)) | ||
reset_env(environment) | ||
return environment | ||
end | ||
|
||
""" | ||
close_env(environment::Environment) | ||
Close `environment`. | ||
""" | ||
function close_env(environment::Environment) | ||
environment.env.close() | ||
end | ||
|
||
""" | ||
reset_env(environment::Environment) | ||
Hard reset `environment`. | ||
""" | ||
function reset_env(environment::Environment) | ||
env = environment.env | ||
settings = environment.settings | ||
|
||
# set seed | ||
if haskey(settings, :seed) | ||
env.seed(settings[:seed]) | ||
end | ||
obs = env.reset() | ||
|
||
# set start position | ||
environment.start_pos = get_xyz_from_obs(obs) | ||
print(environment.start_pos) #TODO: remove/change print | ||
|
||
# weird bug fix | ||
action = env.action_space.noop() | ||
action["forward"] = 1 | ||
env.step(action) | ||
|
||
# infinite health | ||
if get(settings, :inf_health, false) | ||
env.set_next_chat_message("/effect @a minecraft:instant_health 1000000 100 true") | ||
env.step(action) | ||
end | ||
|
||
# infinite food | ||
if get(settings, :inf_food, false) | ||
env.set_next_chat_message("/effect @a minecraft:saturation 1000000 255 true") | ||
env.step(action) | ||
end | ||
|
||
# disable mobs | ||
if get(settings, :disable_mobs, false) | ||
env.set_next_chat_message("/gamerule doMobSpawning false") | ||
env.step(action) | ||
env.set_next_chat_message("/kill @e[type=!player]") | ||
env.step(action) | ||
end | ||
|
||
printstyled("Environment created. x: $(environment.start_pos[1]), y: $(environment.start_pos[2]), z: $(environment.start_pos[3])\n", color=:green) #TODO: remove/change print | ||
end | ||
|
||
""" | ||
get_xyz_from_obs(obs)::Tuple{Float64, Float64, Float64} | ||
Get player coordinates from `obs`. | ||
""" | ||
function get_xyz_from_obs(obs)::Tuple{Float64,Float64,Float64} | ||
return obs["xpos"][1], obs["ypos"][1], obs["zpos"][1] | ||
end | ||
|
||
""" | ||
soft_reset_env(environment::Environment) | ||
Reset player position to `environment.start_pos`. | ||
""" | ||
function soft_reset_env(environment::Environment) | ||
env = environment.env | ||
action = env.action_space.noop() | ||
x_player_start, y_player_start, z_player_start = environment.start_pos | ||
env.set_next_chat_message("/tp @a $(x_player_start) $(y_player_start) $(z_player_start)") | ||
|
||
obs = env.step(action)[1] | ||
obsx, obsy, obsz = get_xyz_from_obs(obs) | ||
while obsx != x_player_start || obsy != y_player_start || obsz != z_player_start | ||
obs = env.step(action)[1] | ||
obsx, obsy, obsz = get_xyz_from_obs(obs) | ||
end | ||
println((obsx, obsy, obsz)) #TODO: remove/change print | ||
end | ||
|
||
""" | ||
evaluate_trace_minerl(prog::AbstractRuleNode, grammar::ContextSensitiveGrammar, environment::Environment, show_moves::Bool) | ||
Evaluate in MineRL `environment`. | ||
""" | ||
function evaluate_trace_minerl(prog::AbstractRuleNode, grammar::ContextSensitiveGrammar, environment::Environment, show_moves::Bool) | ||
soft_reset_env(environment) | ||
|
||
expr = rulenode2expr(prog, grammar) | ||
sequence_of_actions = eval(expr) | ||
|
||
sum_of_rewards = 0 | ||
is_done = false | ||
obs = nothing | ||
env = environment.env | ||
for (times, action) ∈ sequence_of_actions | ||
new_action = env.action_space.noop() | ||
for (key, val) in action | ||
if key == "move" | ||
new_action["forward"] = val & 1 | ||
new_action["back"] = val >> 1 & 1 | ||
new_action["left"] = val >> 2 & 1 | ||
new_action["right"] = val >> 3 | ||
else | ||
new_action[key] = val | ||
end | ||
end | ||
|
||
for i in 1:times | ||
obs, reward, done, _ = env.step(new_action) | ||
if show_moves | ||
env.render() | ||
end | ||
|
||
sum_of_rewards += reward | ||
if done | ||
is_done = true | ||
printstyled("sum of rewards: $sum_of_rewards. Done\n", color=:green) #TODO: remove/change print | ||
break | ||
end | ||
end | ||
if is_done | ||
break | ||
end | ||
end | ||
println("Reward $sum_of_rewards") #TODO: remove/change print | ||
return get_xyz_from_obs(obs), is_done, sum_of_rewards | ||
end |
Oops, something went wrong.