test_digging_cnn.py
import os

import gym
import numpy as np
import torch as th
import torch.nn as nn
import matplotlib.pyplot as plt  # only used by the commented-out debug lines below

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import (
    CallbackList,
    CheckpointCallback,
    EvalCallback,
)
from stable_baselines3.common.env_checker import check_env  # only used by the commented-out debug lines below
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from custom_callbacks import VideoRecorderCallback
from extractors import SimpleExtractorDict
from helpers import make_env, parallel_worlds

import heightgrid  # imported for its side effect: registering the HeightGrid envs

# Status notes:
# - checkpointing (model and, for off-policy algorithms, replay buffer): done
# - personalized logging: not yet
# - runs on the custom HeightGrid environment
# - TensorBoard integration: done
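
# Illustrative sketch (hypothetical name, not used by the run below): the script relies
# on the project-local SimpleExtractorDict from extractors.py, whose code is not part of
# this file. The class below only shows the general shape of an SB3 features extractor
# for a Dict observation space, assuming every entry of the Dict is a Box; the real
# extractor presumably runs a small CNN over the height map instead of flatten-and-project.
class ExampleDictExtractor(BaseFeaturesExtractor):
    """Flatten every entry of a Dict observation and project it to `features_dim`."""

    def __init__(self, observation_space: gym.spaces.Dict, features_dim: int = 128):
        super().__init__(observation_space, features_dim)
        # total number of scalars across all Dict entries (Box spaces assumed)
        flat_size = sum(
            int(np.prod(space.shape)) for space in observation_space.spaces.values()
        )
        self.project = nn.Sequential(nn.Linear(flat_size, features_dim), nn.ReLU())

    def forward(self, observations) -> th.Tensor:
        # observations is a dict of batched tensors, one per Dict key
        flat = th.cat(
            [th.flatten(obs, start_dim=1) for obs in observations.values()], dim=1
        )
        return self.project(flat)
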
if __name__ == "__main__":
    log_dir = "./logs/heightgrid/ppo/16x16_cnn"
    os.makedirs(log_dir, exist_ok=True)
    env_id = "HeightGrid-Empty-Random-16x16-v0"
    # training env is vectorized (here num_envs=1); the eval env is a single,
    # separately seeded instance
    env = parallel_worlds(env_id, log_dir=log_dir, flat_obs=False, num_envs=1)
    eval_env = make_env(env_id, log_dir=log_dir, seed=24, flat_obs=False)()
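    # Sanity check: with flat_obs=False the observation is expected to be a Dict
    # space, which is what SimpleExtractorDict is built for; printing the spaces
    # makes a mismatch obvious early.
    print("observation space:", eval_env.observation_space)
    print("action space:", eval_env.action_space)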
    # figure, ax = eval_env.render()
    # plt.plot(figure)
    policy_kwargs = dict(
        features_extractor_class=SimpleExtractorDict,
        net_arch=[dict(pi=[256], vf=[256])],
    )
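    # net_arch adds one 256-unit layer for the policy head (pi) and one for the
    # value head (vf) on top of the features produced by SimpleExtractorDict.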
    model = PPO(
        ActorCriticPolicy,
        env,
        gamma=1,
        batch_size=64,
        n_steps=512,  # rollout length per environment before each update
        n_epochs=4,
        ent_coef=0.001,
        policy_kwargs=policy_kwargs,
        verbose=1,
        create_eval_env=True,
        tensorboard_log=log_dir,
    )
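    # With n_steps=512 and a single environment, each update optimizes over 512
    # transitions, split into 512 / 64 = 8 minibatches per epoch for n_epochs=4 passes.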
    # with steps 2058 * num_envs
    checkpoint_callback = CheckpointCallback(
        save_freq=200000, save_path=log_dir, name_prefix="ppo_goal_target_16"
    )
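    # CheckpointCallback writes zip files into log_dir named after name_prefix and the
    # elapsed timesteps (e.g. ppo_goal_target_16_200000_steps.zip for the first save).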
    # Separate evaluation env
    # eval_env.render('human')
    # check_env(eval_env)
    # print("Created eval env")
    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path=log_dir + "/best_model",
        log_path=log_dir + "/results",
        eval_freq=20000,
    )
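    # EvalCallback keeps its defaults otherwise: 5 evaluation episodes per evaluation,
    # the best-scoring model saved as best_model.zip under best_model_save_path, and
    # the raw evaluation results written to evaluations.npz under log_path.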
    # not saving: the video recorder is created but not added to the callback list below
    video_recorder = VideoRecorderCallback(eval_env, render_freq=20000)
    # Create the callback list: evaluate the model every 20000 steps on 5 test episodes
    # and save the evaluation results to the log folder
    callbacks = CallbackList([eval_callback, checkpoint_callback])
    # model.load('logs/box_world/ppo_split_model_4800000_steps', env=env)
    model.learn(
        10000000,
        callback=callbacks,
        tb_log_name="ppo_goal_16x16",
        reset_num_timesteps=False,
    )
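    # reset_num_timesteps=False keeps the global timestep counter, so TensorBoard curves
    # continue seamlessly if training is resumed from a previously loaded model.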
    # save the final model
    model.save("ppo_goal_16x16")
    # PPO has no replay buffer; for an off-policy algorithm (e.g. SAC) it could be saved too
    # model.save_replay_buffer("sac_replay_buffer")
    # # load the model
    # loaded_model = PPO.load("ppo_boxworld", env=env)
    # # load the replay buffer into the loaded_model
    # # loaded_model.load_replay_buffer("sac_replay_buffer")
    # loaded_model.learn(
    #     total_timesteps=7000, tb_log_name="second", reset_num_timesteps=False
    # )
    # # Retrieve the environment
    # env = model.get_env()
    # Evaluate the policy
    mean_reward, std_reward = evaluate_policy(
        model.policy, eval_env, n_eval_episodes=10, deterministic=True
    )
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
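    # evaluate_policy returns the mean and standard deviation of the episodic return
    # over n_eval_episodes; deterministic=True picks the most likely action instead
    # of sampling from the policy distribution.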
    # Roll out the trained policy and render it for visual inspection
    obs = eval_env.reset()
    for i in range(1000):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = eval_env.step(action)
        eval_env.render("human")
        if done:
            obs = eval_env.reset()
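    # A later run could reload either the periodic checkpoints or the best model saved
    # by EvalCallback; a minimal sketch (paths depend on the actual run) would be:
    #
    #   loaded_model = PPO.load(log_dir + "/best_model/best_model", env=env)
    #   loaded_model.learn(1000000, callback=callbacks, reset_num_timesteps=False)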