# 08_ale.py
from argparse import ArgumentParser
from functools import partial

import wandb

from amago.envs.builtin.ale_retro import AtariAMAGOWrapper, AtariGame
from amago.nets.cnn import NatureishCNN, IMPALAishCNN
from amago.cli_utils import *
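# The * import from amago.cli_utils provides the helpers used below: add_common_cli,
# switch_traj_encoder, switch_tstep_encoder, switch_agent, use_config,
# create_experiment_from_cli, and switch_async_mode.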


def add_cli(parser):
    parser.add_argument("--games", nargs="+", default=None)
    parser.add_argument("--max_seq_len", type=int, default=80)
    parser.add_argument(
        "--cnn", type=str, choices=["nature", "impala"], default="impala"
    )
    return parser
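

# Default 10-game subset used when --games is not provided on the command line.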
DEFAULT_MULTIGAME_LIST = [
"Pong",
"Boxing",
"Breakout",
"Gopher",
"MsPacman",
"ChopperCommand",
"CrazyClimber",
"BattleZone",
"Qbert",
"Seaquest",
]

# 30 minutes of game time (at 60 FPS), divided by the frame skip of 5 to get agent timesteps
ATARI_TIME_LIMIT = (30 * 60 * 60) // 5
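

# make_atari_game builds a single ALE game wrapped for AMAGO. Its settings roughly follow
# the "sticky actions" evaluation protocol: no termination on life loss, 25% sticky actions,
# a frame skip of 5, unclipped rewards, and RGB (non-grayscale) observations.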
def make_atari_game(game_name):
    return AtariAMAGOWrapper(
        AtariGame(
            game=game_name,
            action_space="discrete",
            terminal_on_life_loss=False,
            version="v5",
            frame_skip=5,
            grayscale=False,
            sticky_action_prob=0.25,
            clip_rewards=False,
        ),
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    add_cli(parser)
    add_common_cli(parser)
    args = parser.parse_args()
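
    # Config overrides applied through AMAGO's gin-based configuration (dotted parameter
    # paths): scale rewards by 0.25 and give full weight to the offline actor loss term
    # when using the "multitask" agent variant.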
    config = {
        "amago.agent.Agent.reward_multiplier": 0.25,
        "amago.agent.Agent.offline_coeff": (
            1.0 if args.agent_type == "multitask" else 0.0
        ),
    }
    traj_encoder_type = switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
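    # Choose the pixel encoder: the small DQN-style "Nature" CNN or the deeper residual IMPALA CNN.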
    if args.cnn == "nature":
        cnn_type = NatureishCNN
    elif args.cnn == "impala":
        cnn_type = IMPALAishCNN
    tstep_encoder_type = switch_tstep_encoder(
        config,
        arch="cnn",
        cnn_type=cnn_type,
        channels_first=True,
        drqv2_aug=True,
    )
    agent_type = switch_agent(config, args.agent_type)
    use_config(config, args.configs)

    # Episode lengths in Atari vary widely across games, so we manually assign each actor
    # to a specific game; that way every game is always being played in parallel.
    games = args.games or DEFAULT_MULTIGAME_LIST
    assert (
        args.parallel_actors % len(games) == 0
    ), "Number of actors must be divisible by number of games."
    env_funcs = []
    for actor in range(args.parallel_actors):
        game_name = games[actor % len(games)]
        env_funcs.append(partial(make_atari_game, game_name))
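
    # Launch one experiment per trial under a shared group name. Training sequences are
    # capped at --max_seq_len timesteps, and trajectory files are saved in chunks of up
    # to 3x that length.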
    group_name = f"{args.run_name}_atari_l_{args.max_seq_len}_cnn_{args.cnn}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = create_experiment_from_cli(
            args,
            make_train_env=env_funcs,
            make_val_env=env_funcs,
            max_seq_len=args.max_seq_len,
            traj_save_len=args.max_seq_len * 3,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            val_timesteps_per_epoch=ATARI_TIME_LIMIT,
            save_trajs_as="npz-compressed",
        )
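        # Start the experiment in the requested (a)sync mode, optionally resume from a
        # checkpoint, train, run a longer final test evaluation (5x the per-game time limit),
        # and clean up the on-disk replay buffer before finishing the wandb run.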
        switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        experiment.evaluate_test(
            env_funcs, timesteps=ATARI_TIME_LIMIT * 5, render=False
        )
        experiment.delete_buffer_from_disk()
        wandb.finish()