Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MineRL/Probe #108

Merged
merged 6 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/HerbSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ include("genetic_functions/select_parents.jl")
include("genetic_search_iterator.jl")

include("random_iterator.jl")

include("probe/program_cache.jl")
include("probe/sum_iterator.jl")
include("probe/new_program_iterator.jl")
include("probe/select_partial_sols.jl")
include("probe/update_grammar.jl")
include("probe/guided_search_iterator.jl")
include("probe/guided_trace_search_iterator.jl")
include("probe/probe_iterator.jl")

export
Expand Down Expand Up @@ -71,9 +79,8 @@ export
VLSNSearchIterator,
SASearchIterator,

ProbeSearchIterator,
GuidedSearchIterator,
GuidedSearchIterator,
GuidedSearchTraceIterator,
probe,

mean_squared_error,
Expand Down
69 changes: 0 additions & 69 deletions src/minecraft/create_minerl_env.jl

This file was deleted.

81 changes: 14 additions & 67 deletions src/minecraft/getting_started_minerl.jl
Original file line number Diff line number Diff line change
@@ -1,82 +1,29 @@
include("create_minerl_env.jl")
using HerbGrammar, HerbSpecification
using HerbSearch
include("minerl.jl")
include("logo_print.jl")

using HerbGrammar, HerbSpecification, HerbSearch
using Logging
disable_logging(LogLevel(1))

# Grammar built with `@pcsgrammar` whose programs evaluate to a list of MineRL actions.
# The number before each `:` is the value passed to `@pcsgrammar` for that rule
# (presumably its weight/probability -- TODO confirm against HerbGrammar).
minerl_grammar = @pcsgrammar begin
# The five basic movement keys an action can activate.
1:action_name = "forward"
1:action_name = "left"
1:action_name = "right"
1:action_name = "back"
1:action_name = "jump"
# A sequence is either empty or an existing sequence with one more `action` appended.
1:sequence_actions = [sequence_actions; action]
1:sequence_actions = []
# An action is a `(TIMES, Dict)` tuple; the dict sets "camera" to [0, 0] and the chosen key to 1.
1:action = (TIMES, Dict("camera" => [0, 0], action_name => 1))
# `evaluate_trace_minerl` steps the environment `times` times for each `(times, action)` tuple.
# NOTE(review): six alternatives are listed here but the prefix is 5 -- confirm this is intended.
5:TIMES = 1 | 5 | 25 | 50 | 75 | 100
end

# Second grammar: a program is a one-element list `[ACT]`; `:SEQ` is the symbol handed to
# the search iterator further down in this script.
minerl_grammar_2 = @pcsgrammar begin
1:SEQ = [ACT]
# DIR is a 4-bit movement mask decoded by `evaluate_trace_minerl`:
# bit 0 = forward, bit 1 = back, bit 2 = left, bit 3 = right.
8:DIR = 0b0001 | 0b0010 | 0b0100 | 0b1000 | 0b0101 | 0b1001 | 0b0110 | 0b1010 # forward | back | left | right | forward-left | forward-right | back-left | back-right
# Every action carries a DIR mask under "move" and sets both "sprint" and "jump" to 1.
1:ACT = (TIMES, Dict("move" => DIR, "sprint" => 1, "jump" => 1))
# Number of consecutive environment steps the action is applied for.
6:TIMES = 5 | 10 | 25 | 50 | 75 | 100
end

function evaluate_trace_minerl(prog, grammar, env, show_moves)
    # Evaluate `prog` in the MineRL environment `env`.
    #
    # The player position is reset first, then `prog` is turned into an expression and
    # evaluated. The result is iterated as `(times, action)` pairs; each action is applied
    # for `times` consecutive steps, rendering after every step when `show_moves` is true.
    # Evaluation stops as soon as the environment reports that the episode is done.
    #
    # Returns `(position, is_done, sum_of_rewards)`, where `position` is
    # `get_xyz_from_env` applied to the last observation.
    resetPosition()
    sequence_of_actions = eval(rulenode2expr(prog, grammar))

    sum_of_rewards = 0
    finished = false
    last_obs = nothing
    for (repeats, action_spec) in sequence_of_actions
        step_action = env.action_space.noop()
        for (name, value) in action_spec
            if name != "move"
                # Any key other than "move" is copied into the MineRL action as is.
                step_action[name] = value
                continue
            end
            # "move" is a bitmask: bit 0 forward, bit 1 back, bit 2 left, bit 3 right.
            step_action["forward"] = value & 1
            step_action["back"] = (value >> 1) & 1
            step_action["left"] = (value >> 2) & 1
            step_action["right"] = value >> 3
        end

        for _ in 1:repeats
            last_obs, reward, done, _ = env.step(step_action)
            show_moves && env.render()

            sum_of_rewards += reward
            if done
                finished = true
                printstyled("sum of rewards: $sum_of_rewards. Done\n", color=:green)
                break
            end
        end
        # Propagate the inner `break` so no further actions are executed.
        finished && break
    end
    println("Reward $sum_of_rewards")
    return get_xyz_from_env(last_obs), finished, sum_of_rewards
end

# make sure the probabilities are equal
@assert all(prob -> prob == minerl_grammar_2.log_probabilities[begin], minerl_grammar_2.log_probabilities)
@assert all(prob -> prob == minerl_grammar.log_probabilities[begin], minerl_grammar.log_probabilities)

Check warning on line 16 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L16

Added line #L16 was not covered by tests

function HerbSearch.set_env_position(x, y, z)
    # Announce the requested position on stdout, then hand it to `set_start_xyz`.
    message = "Setting env position: ($x, $y, $z)"
    println(message)
    return set_start_xyz(x, y, z)
end
# overwrite the evaluate trace function
# Both definitions below have the same positional signature and forward to
# `evaluate_trace_minerl`. They differ in the default of `show_moves` (false vs. true)
# and in the global that is passed as the environment (`env` vs. `environment`).
# Keyword arguments do not take part in dispatch, so if both lines are evaluated the
# second definition replaces the first.
# NOTE(review): these look like the before/after lines of one change -- confirm only one is meant to stay.
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=false) = evaluate_trace_minerl(prog, grammar, env, show_moves)
HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show_moves=true) = evaluate_trace_minerl(prog, grammar, environment, show_moves)

Check warning on line 19 in src/minecraft/getting_started_minerl.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/getting_started_minerl.jl#L19

Added line #L19 was not covered by tests
# Define the rule cost for this script by forwarding to `HerbSearch.calculate_rule_cost_prob`
# with the same `rule_index` and `grammar`.
HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar)

# Driver section of the script: build the MineRL environment and run `probe`.
# NOTE(review): the two `iter = ...` / `program = ...` pairs below look like the old and new
# versions of the same statements; if both are evaluated the later pair overwrites the earlier one.
# resetEnv()
# Variant that uses `minerl_grammar_2` and passes the two numeric `probe` arguments positionally.
iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :SEQ)
program = @time probe(Vector{Trace}(), iter, 3000000, 6)
# World seed forwarded to `create_env` below.
SEED = 958129
# Create the environment only when `environment` is not defined yet, so re-running the
# script in the same session reuses the existing one.
if !(@isdefined environment)
environment = create_env("MineRLNavigateDenseProgSynth-v0"; seed=SEED, inf_health=true, inf_food=true, disable_mobs=true)
end
print_logo()
# Variant that uses `minerl_grammar` and passes `max_time` / `cycle_length` as keywords.
iter = HerbSearch.GuidedSearchTraceIterator(minerl_grammar, :SEQ)
# Start from an empty vector of traces; `@time` reports how long the `probe` call takes.
program = @time probe(Vector{Trace}(), iter, max_time=3000000, cycle_length=6)

17 changes: 17 additions & 0 deletions src/minecraft/logo_print.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
print_logo()

Prints a stylized ascii art of the word probe.
"""
function print_logo()
# Print the ASCII-art logo in bold magenta. The art is a raw string literal, so the
# backslashes in it are written to the terminal unchanged.
printstyled(raw""" _

Check warning on line 7 in src/minecraft/logo_print.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/logo_print.jl#L6-L7

Added lines #L6 - L7 were not covered by tests
| |
_ __ _ __ ___ | |__ ___
| '_ \| '__/ _ \| '_ \ / _ \
| |_) | | | (_) | |_) | __/
| .__/|_| \___/|_.__/ \___|
| |
|_| """, color=:magenta, bold=true)
# Finish the logo's last line, then print a rule of 80 `=` characters followed by an empty line.
println()
println(repeat("=", 80) * "\n")

Check warning on line 16 in src/minecraft/logo_print.jl

View check run for this annotation

Codecov / codecov/patch

src/minecraft/logo_print.jl#L15-L16

Added lines #L15 - L16 were not covered by tests
end
163 changes: 163 additions & 0 deletions src/minecraft/minerl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
using PyCall
pyimport("minerl")

gym = pyimport("gym")

# WARNING: !!! NEVER MOVE THIS. It should ALWAYS be after the `pyimport`. I spent hours debugging this !!!
using HerbGrammar
"""
    Environment

Mutable wrapper around a MineRL Gym environment.

# Fields
- `env::PyObject`: the Python environment object (created with `gym.make` in `create_env`).
- `settings::Dict{Symbol,Integer}`: the keyword arguments given to `create_env`;
  `reset_env` reads the keys `:seed`, `:inf_health`, `:inf_food` and `:disable_mobs`.
- `start_pos::Tuple{Float64,Float64,Float64}`: player `(x, y, z)` recorded by `reset_env`;
  `soft_reset_env` teleports the player back to it.
"""
mutable struct Environment
env::PyObject
settings::Dict{Symbol,Integer}
start_pos::Tuple{Float64,Float64,Float64}
end

"""
create_env(name::String; <keyword arguments>)
Create environment.
# Arguments
- `seed::Int`: the world seed.
- `inf_health::Bool`: enable infinite health.
- `inf_food::Bool`: enable infinite food.
- `disable_mobs`: disable mobs.
"""
function create_env(name::String; kwargs...)
    # Instantiate the Gym environment `name`, keep the keyword arguments as its settings
    # and hard-reset it once before handing it back.
    gym_env = gym.make(name)
    settings = Dict(kwargs)
    # (0, 0, 0) is only a placeholder: `reset_env` overwrites `start_pos` with the
    # position taken from the first observation.
    created = Environment(gym_env, settings, (0, 0, 0))
    reset_env(created)
    return created
end

"""
close_env(environment::Environment)
Close `environment`.
"""
function close_env(environment::Environment)
    # Call `close` on the wrapped Python environment and return whatever it returns.
    gym_env = environment.env
    return gym_env.close()
end

"""
reset_env(environment::Environment)
Hard reset `environment`.
"""
function reset_env(environment::Environment)
    # Hard reset: apply the seed (when configured), reset the Gym environment, record the
    # position from the first observation as `start_pos`, then apply the optional
    # settings by sending chat commands.
    gym_env = environment.env
    options = environment.settings

    # set seed
    haskey(options, :seed) && gym_env.seed(options[:seed])
    first_obs = gym_env.reset()

    # set start position
    environment.start_pos = get_xyz_from_obs(first_obs)
    print(environment.start_pos) #TODO: remove/change print

    # weird bug fix: take one forward step before any chat command is sent
    action = gym_env.action_space.noop()
    action["forward"] = 1
    gym_env.step(action)

    # Queue `command` as the next chat message, then step once with the forward action.
    function run_command(command)
        gym_env.set_next_chat_message(command)
        gym_env.step(action)
    end

    # infinite health
    get(options, :inf_health, false) &&
        run_command("/effect @a minecraft:instant_health 1000000 100 true")

    # infinite food
    get(options, :inf_food, false) &&
        run_command("/effect @a minecraft:saturation 1000000 255 true")

    # disable mobs: stop new spawns first, then remove every non-player entity
    if get(options, :disable_mobs, false)
        run_command("/gamerule doMobSpawning false")
        run_command("/kill @e[type=!player]")
    end

    printstyled("Environment created. x: $(environment.start_pos[1]), y: $(environment.start_pos[2]), z: $(environment.start_pos[3])\n", color=:green) #TODO: remove/change print
end

"""
get_xyz_from_obs(obs)::Tuple{Float64, Float64, Float64}
Get player coordinates from `obs`.
"""
function get_xyz_from_obs(obs)::Tuple{Float64,Float64,Float64}
    # Take the first entry stored under each coordinate key, in x, y, z order. The
    # declared return type converts the three values to Float64.
    return map(axis -> obs[axis][1], ("xpos", "ypos", "zpos"))
end

"""
soft_reset_env(environment::Environment)
Reset player position to `environment.start_pos`.
"""
function soft_reset_env(environment::Environment)
    # Teleport the player back to `environment.start_pos` and keep stepping with a no-op
    # action until the observed position equals that start position exactly.
    gym_env = environment.env
    idle = gym_env.action_space.noop()
    x_player_start, y_player_start, z_player_start = environment.start_pos
    gym_env.set_next_chat_message("/tp @a $(x_player_start) $(y_player_start) $(z_player_start)")

    target = (x_player_start, y_player_start, z_player_start)
    # The first element of the value returned by `step` is the observation.
    current = get_xyz_from_obs(gym_env.step(idle)[1])
    while current != target
        current = get_xyz_from_obs(gym_env.step(idle)[1])
    end
    println(current) #TODO: remove/change print
end

"""
evaluate_trace_minerl(prog::AbstractRuleNode, grammar::ContextSensitiveGrammar, environment::Environment, show_moves::Bool)
Evaluate in MineRL `environment`.
"""
function evaluate_trace_minerl(prog::AbstractRuleNode, grammar::ContextSensitiveGrammar, environment::Environment, show_moves::Bool)
    # Run the program `prog` in the MineRL `environment`.
    #
    # Arguments
    # - `prog`: program tree; converted with `rulenode2expr(prog, grammar)` and evaluated.
    #   The result is iterated as `(times, action)` pairs, and each `action` is iterated
    #   as `key => value` pairs.
    # - `grammar`: grammar used to turn `prog` into an expression.
    # - `environment`: wrapper whose `env` field is stepped; the player is moved back to
    #   `environment.start_pos` before anything is executed.
    # - `show_moves`: when true, the environment is rendered after every step.
    #
    # Returns `(position, is_done, sum_of_rewards)`: the player position after the last
    # executed step, whether the environment reported `done`, and the accumulated reward.
    soft_reset_env(environment)

    expr = rulenode2expr(prog, grammar)
    # NOTE(review): `eval` executes the synthesized expression in global scope; this is
    # only safe while the grammar is trusted.
    sequence_of_actions = eval(expr)

    sum_of_rewards = 0
    is_done = false
    obs = nothing
    env = environment.env
    for (times, action) in sequence_of_actions
        new_action = env.action_space.noop()
        for (key, val) in action
            if key == "move"
                # "move" is a bitmask: bit 0 forward, bit 1 back, bit 2 left, bit 3 right.
                # Every bit is masked so each movement key is always set to 0 or 1.
                new_action["forward"] = val & 1
                new_action["back"] = (val >> 1) & 1
                new_action["left"] = (val >> 2) & 1
                new_action["right"] = (val >> 3) & 1
            else
                # Any other key is copied into the MineRL action unchanged.
                new_action[key] = val
            end
        end

        # Apply the same action for `times` consecutive steps.
        for _ in 1:times
            obs, reward, done, _ = env.step(new_action)
            if show_moves
                env.render()
            end

            sum_of_rewards += reward
            if done
                is_done = true
                printstyled("sum of rewards: $sum_of_rewards. Done\n", color=:green) #TODO: remove/change print
                break
            end
        end
        # Propagate the inner `break` so no further actions are executed.
        if is_done
            break
        end
    end
    println("Reward $sum_of_rewards") #TODO: remove/change print

    # When no step was executed (empty action sequence) there is no observation to read.
    # `soft_reset_env` only returns once the player is at `start_pos`, so that is the
    # final position in this case.
    final_position = obs === nothing ? environment.start_pos : get_xyz_from_obs(obs)
    return final_position, is_done, sum_of_rewards
end
Loading
Loading