-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.py
72 lines (61 loc) · 2 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import logging
import random
import yaml
import argparse
import numpy as np
import torch
import builder
import os
from copy import deepcopy
from learning_strategies.evolution.loop import ESLoop
from moviepy.editor import ImageSequenceClip
def set_seed(seed):
    """Seed PyTorch, NumPy's global RNG and Python's ``random`` with ``seed``.

    The three generators are seeded in that order with the same value, so
    subsequent draws from each of them repeat whenever the same seed is used.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
def _parse_args():
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg-path", type=str, default="conf/ant.yaml")
    # The checkpoint is always loaded, so fail fast in argparse instead of
    # crashing later inside torch.load(None).
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--save-gif", action="store_true")
    parser.add_argument("--num-episodes", type=int, default=100)
    # Optional: when given, torch / numpy / random are seeded via set_seed().
    parser.add_argument("--seed", type=int, default=None)
    return parser.parse_args()


def _gif_dir(ckpt_path):
    """Return ``test_gif/<run>`` where <run> is the 3rd-from-last '/' component of ckpt_path."""
    parts = ckpt_path.split("/")
    if len(parts) < 3:
        raise ValueError(
            f"cannot derive a run name from checkpoint path {ckpt_path!r}: "
            "expected at least three '/'-separated components"
        )
    return os.path.join("test_gif", parts[-3])


def _run_episode(env, network, agent_ids, record):
    """Roll out one episode with a fresh copy of ``network`` for every agent id.

    Returns ``(episode_reward, ep_step, frames)``. ``frames`` holds the
    ``env.render()`` output of every step when ``record`` is true and is an
    empty list otherwise.
    """
    models = {}
    for agent_id in agent_ids:
        models[agent_id] = deepcopy(network)
        models[agent_id].eval()
        models[agent_id].reset()

    obs = env.reset()
    done = False
    # NOTE(review): r from env.step is accumulated with +=; presumably a
    # scalar (shared) reward — confirm against the env implementation.
    episode_reward = 0
    ep_step = 0
    frames = []
    while not done:
        actions = {}
        for k, model in models.items():
            # Add a leading batch axis to this agent's "state" observation.
            s = obs[k]["state"][np.newaxis, ...]
            actions[k] = model(s)
        obs, r, done, _ = env.step(actions)
        # Rendered every step as before, even when frames are not kept.
        rgb_array = env.render()
        if record:
            frames.append(rgb_array)
        episode_reward += r
        ep_step += 1
    return episode_reward, ep_step, frames


def main():
    """Evaluate a trained checkpoint.

    Builds the environment and network from the YAML config at --cfg-path,
    loads weights from --ckpt-path, and runs --num-episodes episodes
    (default 100), each with one copy of the network per agent id. Prints
    each episode's reward and step count. With --save-gif, every episode's
    rendered frames are written to ``test_gif/<run>/ep_<i>.gif``.
    """
    args = _parse_args()
    if args.seed is not None:
        set_seed(args.seed)

    # NOTE(review): FullLoader is kept for compatibility; the config is a
    # local, trusted file. yaml.safe_load would suffice for plain configs.
    with open(args.cfg_path, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    env = builder.build_env(config["env"])
    agent_ids = env.get_agent_ids()

    save_dir = None
    if args.save_gif:
        save_dir = _gif_dir(args.ckpt_path)
        # exist_ok: re-evaluating the same run must not crash on the existing dir.
        os.makedirs(save_dir, exist_ok=True)

    network = builder.build_network(config["network"])
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines;
    # load_state_dict then copies the tensors into the network's own parameters.
    network.load_state_dict(torch.load(args.ckpt_path, map_location="cpu"))

    # Inference only: no autograd graph is needed for the rollouts.
    with torch.no_grad():
        for i in range(args.num_episodes):
            episode_reward, ep_step, frames = _run_episode(
                env, network, agent_ids, args.save_gif
            )
            print("reward: ", episode_reward, "ep_step: ", ep_step)
            if args.save_gif:
                clip = ImageSequenceClip(frames, fps=30)
                clip.write_gif(os.path.join(save_dir, f"ep_{i}.gif"), fps=30)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()