-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate_policy.py
87 lines (70 loc) · 2.57 KB
/
evaluate_policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import json
import logging
import os

from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

from common import Connection
from common import GridView
from envs import TowerfallBlankEnv, GridObservation, PlayerObservation, FollowCloseTargetCurriculum
class NoLevelFormatter(logging.Formatter):
    """Formatter that emits only the interpolated log message.

    Level name, logger name and timestamp are all dropped, so every record is
    rendered as a bare line of text.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Return ``record``'s message with its ``%``-style args applied, nothing else."""
        message = record.getMessage()
        return message
# Log INFO and above through the root logger's first handler, printed as bare messages.
logging.basicConfig(level=logging.INFO)
logging.root.handlers[0].setFormatter(NoLevelFormatter())

# Host/port handed to Connection in evaluate(); presumably where the game instance
# listens — confirm against the game-side configuration.
_HOST, _PORT = '127.0.0.1', 12024
def input_blocking_callback(locals, globals):
    """Stop in an interactive debugger; meant as ``evaluate_policy``'s per-step callback.

    stable-baselines3's ``evaluate_policy`` calls ``callback(locals(), globals())``
    after every environment step, so the arguments are the evaluation loop's
    namespaces (inspect them from the debugger prompt). The parameter names are kept
    for compatibility even though they shadow the builtins inside this function.

    Uses ``ipdb`` when it is installed and falls back to the stdlib ``pdb`` otherwise,
    so the script no longer hard-requires the optional ``ipdb`` package.
    """
    try:
        import ipdb as debugger
    except ImportError:  # ipdb is an optional dev dependency
        import pdb as debugger
    debugger.set_trace()
def create_env(configs, connection: Connection) -> TowerfallBlankEnv:
    """Build and validate the Towerfall environment described by an experiment config.

    Args:
        configs: Parsed ``hparams.json``. It must provide ``'objective_params'`` and
            ``'grid_params'`` mappings. They are forwarded as keyword arguments to
            ``FollowCloseTargetCurriculum`` and ``GridObservation``, respectively.
        connection: Connection handed to the environment.

    Returns:
        The environment, after it has been run through stable-baselines3's ``check_env``.
    """
    # One GridView instance is shared by the objective and the grid observation.
    grid_view = GridView(grid_factor=5)
    objective = FollowCloseTargetCurriculum(grid_view, **configs['objective_params'])
    # Use the logging setup configured at module level rather than bare print().
    logging.info('Creating env')
    env = TowerfallBlankEnv(
        connection=connection,
        observations=[
            GridObservation(grid_view, **configs['grid_params']),
            PlayerObservation(),
        ],
        objective=objective,
    )
    logging.info('Env created')
    # Fails (or warns) if the env does not follow the Gym API that SB3 expects.
    check_env(env)
    return env
def _find_last_model(models_dir: str) -> str:
    """Return the path of the checkpoint to evaluate inside ``models_dir``.

    ``model.zip`` wins when present. Otherwise the ``<step>.zip`` file with the
    largest integer step is chosen. Files matching neither pattern are ignored
    instead of crashing ``int()``.

    Raises:
        FileNotFoundError: if ``models_dir`` is missing or holds no usable checkpoint.
    """
    names = os.listdir(models_dir)
    if 'model.zip' in names:
        return os.path.join(models_dir, 'model.zip')
    best_step = -1
    best_name = None
    for name in names:
        if not name.endswith('.zip'):
            continue
        try:
            step = int(name[:-len('.zip')])
        except ValueError:
            continue  # not a '<step>.zip' checkpoint, e.g. best_model.zip
        if step > best_step:
            best_step, best_name = step, name
    if best_name is None:
        raise FileNotFoundError(f'No model checkpoint found in {models_dir}')
    return os.path.join(models_dir, best_name)


def evaluate(load_from: str, n_eval_episodes: int = 50, deterministic: bool = False,
             callback=input_blocking_callback):
    """Evaluate the latest PPO checkpoint of an experiment against a live game.

    Args:
        load_from: Experiment directory containing ``hparams.json`` and a
            ``models/`` sub-directory of saved checkpoints.
        n_eval_episodes: Number of episodes passed to ``evaluate_policy``.
        deterministic: Whether ``evaluate_policy`` uses deterministic actions.
        callback: Per-step callback for ``evaluate_policy``. It defaults to
            ``input_blocking_callback``, which stops in a debugger after every step.
            Pass ``None`` to run unattended.

    Returns:
        ``(mean_reward, std_reward)`` as returned by ``evaluate_policy``.

    Raises:
        FileNotFoundError: if ``hparams.json`` or a usable checkpoint is missing.
    """
    logging.info(f'Loading experiment from {load_from}')
    with open(os.path.join(load_from, 'hparams.json'), 'r', encoding='utf-8') as file:
        configs = json.load(file)

    # Resolve and load the checkpoint before connecting to the game, so a bad
    # experiment directory fails fast without touching the connection.
    last_model = _find_last_model(os.path.join(load_from, 'models'))
    logging.info(f'Loading model from {last_model}')
    model = PPO.load(last_model)

    connection = Connection(_HOST, _PORT)
    env = Monitor(create_env(configs, connection))
    try:
        logging.info(f'Running evaluation for {last_model}')
        logging.info(f'Deterministic={deterministic}')
        mean_reward, std_reward = evaluate_policy(
            model,
            env=env,
            n_eval_episodes=n_eval_episodes,
            render=False,
            deterministic=deterministic,
            callback=callback,
        )
        logging.info(f'Finished evaluation for {last_model}: '
                     f'mean_reward={mean_reward:.2f} +/- {std_reward:.2f}')
        return mean_reward, std_reward
    finally:
        # Monitor.close() forwards to the wrapped env's close(). Whether that also
        # releases `connection` depends on TowerfallBlankEnv — TODO confirm.
        env.close()
if __name__ == '__main__':
    # Command-line entry point: evaluate the experiment directory given by --load-from.
    cli = argparse.ArgumentParser()
    cli.add_argument('--load-from', type=str, required=True)
    evaluate(cli.parse_args().load_from)