-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain.py
58 lines (45 loc) · 1.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
This file is an example train and test loop for the different environments.
Selecting different environments is done through setting the 'env_name' variable.
TODO:
* add rgb_array rendering for the different environments to allow saving videos
"""
import gymnasium as gym
from stable_baselines3 import PPO, SAC, TD3, DDPG
import numpy as np
import bluesky_gym
import bluesky_gym.envs
from bluesky_gym.utils import logger

# Make the bluesky_gym environments discoverable by gym.make().
bluesky_gym.register_envs()

# Environment and algorithm selection; swap these to run other combinations
# (PPO, SAC, TD3, DDPG are all imported above).
env_name = 'StaticObstacleEnv-v0'
algorithm = SAC

# Initialize logger: per-env log directory, CSV named after env + algorithm.
log_dir = f'./logs/{env_name}/'
file_name = f'{env_name}_{str(algorithm.__name__)}.csv'
csv_logger_callback = logger.CSVLoggerCallback(log_dir, file_name)

# Set TRAIN to True to (re)train before evaluating; False loads a saved model.
TRAIN = False
# Number of evaluation episodes to roll out after (or instead of) training.
EVAL_EPISODES = 10
if __name__ == "__main__":
env = gym.make(env_name, render_mode=None)
obs, info = env.reset()
model = algorithm("MultiInputPolicy", env, verbose=1,learning_rate=3e-4)
if TRAIN:
model.learn(total_timesteps=2e6, callback=csv_logger_callback)
model.save(f"models/{env_name}/{env_name}_{str(algorithm.__name__)}/model")
del model
env.close()
# Test the trained model
model = algorithm.load(f"models/{env_name}/{env_name}_{str(algorithm.__name__)}/model", env=env)
env = gym.make(env_name, render_mode="human")
for i in range(EVAL_EPISODES):
done = truncated = False
obs, info = env.reset()
tot_rew = 0
while not (done or truncated):
# action = np.array(np.random.randint(-100,100,size=(2))/100)
# action = np.array([0,-1])
action, _states = model.predict(obs, deterministic=True)
obs, reward, done, truncated, info = env.step(action[()])
tot_rew += reward
print(tot_rew)
env.close()