From 1f0c97eb46e541a6e39cc3d2788c0fbfab6860b8 Mon Sep 17 00:00:00 2001 From: Nils Marten Mikk Date: Fri, 24 May 2024 10:07:27 +0200 Subject: [PATCH] Make MineRL evaluation more generic --- src/minecraft/getting_started_minerl.jl | 26 +++++++++++++------------ src/probe/update_grammar.jl | 6 +++--- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/minecraft/getting_started_minerl.jl b/src/minecraft/getting_started_minerl.jl index 86b75eca..41361636 100644 --- a/src/minecraft/getting_started_minerl.jl +++ b/src/minecraft/getting_started_minerl.jl @@ -17,9 +17,9 @@ minerl_grammar = @pcsgrammar begin end minerl_grammar_2 = @pcsgrammar begin + 1:SEQ = [ACT] 8:DIR = 0b0001 | 0b0010 | 0b0100 | 0b1000 | 0b0101 | 0b1001 | 0b0110 | 0b1010 # forward | back | left | right | forward-left | forward-right | back-left | back-right - 1:sequence_actions = [action] - 1:action = (TIMES, DIR) + 1:ACT = (TIMES, Dict("move" => DIR, "sprint" => 1, "jump" => 1)) 6:TIMES = 5 | 10 | 25 | 50 | 75 | 100 end @@ -32,16 +32,18 @@ function evaluate_trace_minerl(prog, grammar, env, show_moves) sum_of_rewards = 0 is_done = false obs = nothing - for saved_action ∈ sequence_of_actions - times, dir = saved_action - + for (times, action) ∈ sequence_of_actions new_action = env.action_space.noop() - new_action["forward"] = dir & 1 - new_action["back"] = dir >> 1 & 1 - new_action["left"] = dir >> 2 & 1 - new_action["right"] = dir >> 3 - new_action["sprint"] = 1 - new_action["jump"] = 1 + for (key, val) in action + if key == "move" + new_action["forward"] = val & 1 + new_action["back"] = val >> 1 & 1 + new_action["left"] = val >> 2 & 1 + new_action["right"] = val >> 3 + else + new_action[key] = val + end + end for i in 1:times obs, reward, done, _ = env.step(new_action) @@ -76,5 +78,5 @@ HerbSearch.evaluate_trace(prog::RuleNode, grammar::ContextSensitiveGrammar; show HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar) # resetEnv() -iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :sequence_actions) +iter = HerbSearch.GuidedTraceSearchIterator(minerl_grammar_2, :SEQ) program = @time probe(Vector{Trace}(), iter, 3000000, 6) diff --git a/src/probe/update_grammar.jl b/src/probe/update_grammar.jl index b248b303..18ed953f 100644 --- a/src/probe/update_grammar.jl +++ b/src/probe/update_grammar.jl @@ -54,9 +54,9 @@ function update_grammar!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache # TODO: think about better thing here fitness = min(best_reward / 100, 1) - p_current = 2 ^ (grammar.log_probabilities[rule_index]) + p_current = 2^(grammar.log_probabilities[rule_index]) - sum += p_current^(1 - fitness) + sum += p_current^(1 - fitness) log_prob = ((1 - fitness) * log(2, p_current)) grammar.log_probabilities[rule_index] = log_prob end @@ -66,7 +66,7 @@ function update_grammar!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache total_sum += 2^(grammar.log_probabilities[rule_index]) end expr = rulenode2expr(PSols_with_eval_cache[begin].program, grammar) - grammar.rules[9] = :([$expr; action]) + grammar.rules[1] = :([$expr; ACT]) @assert abs(total_sum - 1) <= 1e-4 "Total sum is $(total_sum) " end