Misc/veridream #181
Draft

araffin wants to merge 87 commits into master from misc/veridream
Changes from all commits (87 commits)
c963be2
Remove HER
araffin Apr 5, 2021
40ce4c5
Add basic support for refactored HER
araffin Apr 5, 2021
7fad104
Add TQC
araffin Apr 7, 2021
b4d26c9
Add space engineer env
araffin Apr 27, 2021
76b40d6
Merge branch 'master' into refactor/her
araffin Apr 29, 2021
e7e9ee6
Update hyperparams
araffin Apr 29, 2021
ae706eb
Fix hyperparam
araffin Apr 29, 2021
d079811
Merge branch 'master' into refactor/her
araffin May 3, 2021
73bea6e
Removed unused callback
araffin May 4, 2021
7e96e10
Update CI
araffin May 4, 2021
7c4f1bc
Add partial support for parallel training
araffin May 4, 2021
b635157
Cleanup
araffin May 5, 2021
2966689
Avoid modify by reference + add sleep time
araffin May 5, 2021
aa6d934
Take learning starts into account
araffin May 6, 2021
76ecb24
Merge branch 'master' into misc/veridream
araffin May 10, 2021
54d5596
Merge branch 'feat/parallel-train' into misc/veridream
araffin May 10, 2021
563d853
Update hyperparams
araffin May 11, 2021
779c4c6
Add dict obs support
araffin May 11, 2021
43188ad
Update test env
araffin May 12, 2021
20c5dd8
Merge branch 'master' into refactor/her
araffin May 12, 2021
b461048
Version bump
araffin May 12, 2021
15c8fa1
Merge branch 'refactor/her' into misc/veridream
araffin May 12, 2021
4857bca
Merge branch 'master' into misc/veridream
araffin May 12, 2021
f73a65e
Update for symmetric control + catch zmq error
araffin May 14, 2021
da441b8
Save best model
araffin May 15, 2021
bf05e92
Fix parallel save (maybe issue with optimizer)
araffin May 16, 2021
a9d63db
Update hyperparams
araffin May 16, 2021
d395120
Update best params
araffin May 17, 2021
5634128
Update hyperparams
araffin May 19, 2021
c5b7f55
Prepare big network experiment
araffin May 19, 2021
976833d
Revert to normal net
araffin May 19, 2021
c6b2dce
Add exception for windows
araffin May 19, 2021
6703390
Update plot script: allow multiple envs
araffin May 19, 2021
a771249
Add bert params
araffin May 21, 2021
c966371
Save multirobot hyperparams
araffin May 26, 2021
f6f3dd0
Merge branch 'master' into misc/veridream
araffin May 27, 2021
a99f691
Add POC for VecEnvWrapper
araffin May 30, 2021
c7bd1b9
Add support for SubProc
araffin May 30, 2021
f9f97d9
Update params
araffin Jun 2, 2021
616ebda
Merge branch 'master' into misc/veridream
araffin Jun 3, 2021
00d7cd9
Add Phase Feature
araffin Jun 4, 2021
04f7cc8
Test with phase only
araffin Jun 4, 2021
87cfd39
Add stop on reward threshold callback
araffin Jun 4, 2021
497e669
Update params
araffin Jun 4, 2021
651344d
Merge branch 'master' into misc/veridream
araffin Jun 6, 2021
16f93e4
Hack for zmq + early termination
araffin Jun 6, 2021
ba98f7a
Merge branch 'misc/veridream' of github.com:DLR-RM/rl-baselines3-zoo …
araffin Jun 6, 2021
96f3f75
Add human teleop
araffin Jun 7, 2021
861dc58
Merge branch 'misc/veridream' of github.com:DLR-RM/rl-baselines3-zoo …
araffin Jun 7, 2021
4b43731
Bug fixes for teleop
araffin Jun 7, 2021
265f7ce
One controller per movement
araffin Jun 7, 2021
7b070ef
Update hyperparams
araffin Jun 9, 2021
87c80b7
Merge branch 'master' into misc/veridream
araffin Jun 9, 2021
b841d9f
Update name
araffin Jun 9, 2021
4f8941b
Merge branch 'master' into misc/veridream
araffin Jun 9, 2021
2143e74
Reformat
araffin Jun 9, 2021
8bae4c5
Update env name + add defaults
araffin Jun 10, 2021
eb9580b
Prepare hyperparam optim
araffin Jun 11, 2021
77308d7
Fixes for hyperparam optim
araffin Jun 11, 2021
78c064f
Test with PPO
araffin Jun 14, 2021
6d24d70
Better disable train
araffin Jun 14, 2021
ba79d71
Merge branch 'misc/veridream' of github.com:DLR-RM/rl-baselines3-zoo …
araffin Jun 14, 2021
859fcec
Merge branch 'master' into misc/veridream
araffin Jun 21, 2021
c2c53b2
Merge branch 'master' into misc/veridream
araffin Jun 21, 2021
9c6dc41
Merge branch 'master' into misc/veridream
araffin Jul 13, 2021
1d6b67e
Add mixture of experts policy
araffin Jul 16, 2021
c3473ed
Bug fixes
araffin Jul 16, 2021
2bd21b6
Stop grad + additional experts
araffin Jul 19, 2021
79c0deb
Learn from scratch
araffin Jul 20, 2021
dcd2607
Reformat
araffin Jul 20, 2021
2ea1bb4
Add multi task controller
araffin Jul 20, 2021
49d9391
One policy to rule them all
araffin Jul 29, 2021
1e426d0
Merge branch 'master' into misc/veridream
araffin Jul 29, 2021
6e7cc95
Add forward left
araffin Aug 13, 2021
1659381
Merge remote-tracking branch 'origin/master' into misc/veridream
araffin Sep 11, 2021
67706f8
Merge branch 'master' into misc/veridream
araffin Sep 27, 2021
559bd44
Rename task
araffin Oct 4, 2021
40df557
Merge branch 'master' into misc/veridream
araffin Oct 20, 2021
7c082db
Merge branch 'misc/veridream' of github.com:DLR-RM/rl-baselines3-zoo …
araffin Oct 20, 2021
99758dc
Merge branch 'master' into misc/veridream
araffin Oct 27, 2021
34c8676
Remove unused requirements
araffin Nov 2, 2021
7efd1a0
Update env name
araffin Nov 3, 2021
0c8192c
Add new env
araffin Nov 3, 2021
ce70bfd
Fix predict
araffin Nov 4, 2021
87001ed
Change the task only when needed
araffin Nov 15, 2021
73a87b4
Merge branch 'master' into misc/veridream
araffin Jan 10, 2022
a652aa6
Sync vec normalize for parallel training
araffin Jan 7, 2022
7 changes: 7 additions & 0 deletions hyperparams/human.yml
@@ -0,0 +1,7 @@
# Space Engineers envs
SE-WalkingTest-v1:
  env_wrapper:
    - utils.wrappers.HistoryWrapper:
        horizon: 2
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
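
Note: utils.wrappers.HistoryWrapper stacks the most recent observations and actions (controlled by the horizon setting above) into a single flat observation, so the policy sees a short history of the interaction. Below is a minimal sketch of that idea, assuming a continuous-control env with Box spaces and the pre-0.20 gym API pinned by this branch; the real wrapper lives in utils/wrappers.py and may differ in detail.

import gym
import numpy as np


class HistoryWrapperSketch(gym.Wrapper):
    """Hypothetical sketch: concatenate the last `horizon` observations and actions."""

    def __init__(self, env: gym.Env, horizon: int = 2):
        super().__init__(env)
        self.horizon = horizon
        self.obs_dim = env.observation_space.shape[0]
        self.act_dim = env.action_space.shape[0]
        low = np.concatenate([np.tile(env.observation_space.low, horizon), np.tile(env.action_space.low, horizon)])
        high = np.concatenate([np.tile(env.observation_space.high, horizon), np.tile(env.action_space.high, horizon)])
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)
        self.obs_history = np.zeros(horizon * self.obs_dim, dtype=np.float32)
        self.act_history = np.zeros(horizon * self.act_dim, dtype=np.float32)

    def _stacked_obs(self) -> np.ndarray:
        return np.concatenate([self.obs_history, self.act_history]).astype(np.float32)

    def reset(self):
        self.obs_history[:] = 0.0
        self.act_history[:] = 0.0
        self.obs_history[-self.obs_dim:] = self.env.reset()
        return self._stacked_obs()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        # Shift both buffers by one step and append the newest obs/action.
        self.obs_history = np.roll(self.obs_history, -self.obs_dim)
        self.obs_history[-self.obs_dim:] = obs
        self.act_history = np.roll(self.act_history, -self.act_dim)
        self.act_history[-self.act_dim:] = action
        return self._stacked_obs(), reward, done, info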
94 changes: 94 additions & 0 deletions hyperparams/tqc.yml
@@ -254,3 +254,97 @@ parking-v0:
n_sampled_goal=4,
max_episode_length=100
)"

# Space Engineers envs
SE-Forward-v1: &defaults
  env_wrapper:
    - utils.wrappers.HistoryWrapper:
        horizon: 2
  vec_env_wrapper:
    - utils.wrappers.VecForceResetWrapper
  callback:
    - utils.callbacks.ParallelTrainCallback:
        gradient_steps: 400
    # - utils.callbacks.StopTrainingOnMeanRewardThreshold:
    #     reward_threshold: 250
    #     verbose: 1
  n_timesteps: !!float 5e6
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 100000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.05
  # train_freq: [1, "episode"]
  train_freq: 100
  n_envs: 4
  gradient_steps: -1
  learning_starts: 800
  use_sde: False
  top_quantiles_to_drop_per_net: 2
  policy_kwargs: "dict(net_arch=[256, 256], n_critics=2)"

SE-Symmetric-v1:
  <<: *defaults

SE-Corrections-v1:
  <<: *defaults

SE-Generic-v1:
  <<: *defaults
  callback:
    - utils.callbacks.ParallelTrainCallback:
        gradient_steps: 400
    # - utils.callbacks.StopTrainingOnMeanRewardThreshold:
    #     reward_threshold: 250
    #     verbose: 1

SE-TurnLeft-v1:
  <<: *defaults
  callback:
    - utils.callbacks.ParallelTrainCallback:
        gradient_steps: 400
    - utils.callbacks.StopTrainingOnMeanRewardThreshold:
        reward_threshold: 250
        verbose: 1

SE-MultiTask-v1:
  <<: *defaults
  # policy: 'MixtureMlpPolicy'
  learning_rate: !!float 7.3e-4
  # gamma: 0.99
  # tau: 0.005
  buffer_size: 200000
  callback:
    - utils.callbacks.ParallelTrainCallback:
        gradient_steps: 400
  # policy_kwargs: "dict(net_arch=[400, 300], n_critics=2, n_additional_experts=2)"
  policy_kwargs: "dict(net_arch=[256, 256], n_critics=5)"


# ======== Real Robot envs ============

WalkingBertSim-v1:
  env_wrapper:
    - utils.wrappers.HistoryWrapper:
        horizon: 2
  callback:
    - utils.callbacks.ParallelTrainCallback:
        gradient_steps: 400
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: [1, "episode"]
  gradient_steps: -1
  learning_starts: 1200
  use_sde_at_warmup: True
  use_sde: True
  sde_sample_freq: 4
  top_quantiles_to_drop_per_net: 2
  policy_kwargs: "dict(log_std_init=-3, net_arch=[256, 256], n_critics=2)"
16 changes: 9 additions & 7 deletions requirements.txt
@@ -1,8 +1,8 @@
gym>=0.17,<0.20
stable-baselines3[extra,tests,docs]>=1.3.1a8
sb3-contrib>=1.3.1a7
box2d-py==2.3.8
pybullet
# stable-baselines3[extra,tests,docs]>=1.3.0
sb3-contrib>=1.3.0
# box2d-py==2.3.8
# pybullet
gym-minigrid
scikit-optimize
optuna
@@ -11,7 +11,9 @@ seaborn
pyyaml>=5.1
cloudpickle>=1.5.0
# tmp fix: ROM missing in newest release
atari-py==0.2.6
# atari-py==0.2.6
plotly
panda-gym==1.1.1 # tmp fix: until compatibility with panda-gym v2
rliable>=1.0.5
pygame
# panda-gym>=1.1.1
# rliable requires python 3.7+
# rliable>=1.0.5
1 change: 1 addition & 0 deletions setup.cfg
@@ -20,6 +20,7 @@ per-file-ignores =
./scripts/all_plots.py:E501
./scripts/plot_train.py:E501
./scripts/plot_training_success.py:E501
./utils/teleop.py:F405

exclude =
# No need to traverse our git directory
47 changes: 46 additions & 1 deletion utils/callbacks.py
@@ -1,4 +1,5 @@
import os
import pickle
import tempfile
import time
from copy import deepcopy
@@ -10,7 +11,8 @@
from sb3_contrib import TQC
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.vec_env import VecEnv, sync_envs_normalization


class TrialEvalCallback(EvalCallback):
@@ -129,6 +131,12 @@ def _init_callback(self) -> None:

self.model.save(temp_file)

if self.model.get_vec_normalize_env() is not None:
temp_file_norm = os.path.join("logs", "vec_normalize.pkl")

with open(temp_file_norm, "wb") as file_handler:
pickle.dump(self.model.get_vec_normalize_env(), file_handler)

# TODO: add support for other algorithms
for model_class in [SAC, TQC]:
if isinstance(self.model, model_class):
@@ -138,6 +146,11 @@ def _init_callback(self) -> None:
assert self.model_class is not None, f"{self.model} is not supported for parallel training"
self._model = self.model_class.load(temp_file)

if self.model.get_vec_normalize_env() is not None:
with open(temp_file_norm, "rb") as file_handler:
self._model._vec_normalize_env = pickle.load(file_handler)
self._model._vec_normalize_env.training = False

self.batch_size = self._model.batch_size

# Disable train method
@@ -182,6 +195,10 @@ def _on_rollout_end(self) -> None:
self._model.replay_buffer = deepcopy(self.model.replay_buffer)
self.model.set_parameters(deepcopy(self._model.get_parameters()))
self.model.actor = self.model.policy.actor
# Sync VecNormalize
if self.model.get_vec_normalize_env() is not None:
sync_envs_normalization(self.model.get_vec_normalize_env(), self._model._vec_normalize_env)

if self.num_timesteps >= self._model.learning_starts:
self.train()
# Do not wait for the training loop to finish
@@ -193,3 +210,31 @@ def _on_training_end(self) -> None:
if self.verbose > 0:
print("Waiting for training thread to terminate")
self.process.join()
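
The VecNormalize handling added to this callback (pickling the statistics at startup and calling sync_envs_normalization after each parameter swap) keeps the training copy of the model and the data-collection model normalizing observations the same way. A minimal standalone illustration of the helper, using Pendulum-v0 as a stand-in env:

import gym
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, sync_envs_normalization

collect_env = VecNormalize(DummyVecEnv([lambda: gym.make("Pendulum-v0")]))
train_copy_env = VecNormalize(DummyVecEnv([lambda: gym.make("Pendulum-v0")]), training=False)

# Copy the running observation/return statistics from collect_env into train_copy_env,
# mirroring what ParallelTrainCallback does at the end of each rollout.
sync_envs_normalization(collect_env, train_copy_env)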


class StopTrainingOnMeanRewardThreshold(BaseCallback):
    """
    Stop the training once a threshold in mean episodic reward
    has been reached (i.e. when the model is good enough).

    :param reward_threshold: Minimum mean episodic reward required to stop training.
    :param verbose: Verbosity level (print a message when training stops if > 0).
    """

    def __init__(self, reward_threshold: float, verbose: int = 0):
        super().__init__(verbose=verbose)
        self.reward_threshold = reward_threshold

    def _on_step(self) -> bool:
        continue_training = True
        if len(self.model.ep_info_buffer) > 0 and len(self.model.ep_info_buffer[0]) > 0:
            mean_reward = safe_mean([ep_info["r"] for ep_info in self.model.ep_info_buffer])
            # Convert np.bool_ to bool, otherwise `callback() is False` won't work
            continue_training = bool(mean_reward < self.reward_threshold)
            if self.verbose > 0 and not continue_training:
                print(
                    f"Stopping training because the mean reward {mean_reward:.2f} "
                    f"is above the threshold {self.reward_threshold}"
                )
        return continue_training
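
A hedged usage sketch for the new callback, outside the zoo's YAML machinery (Pendulum-v0 and the -200 threshold are placeholders, not values from this PR):

from stable_baselines3 import SAC

from utils.callbacks import StopTrainingOnMeanRewardThreshold

# Stop training once the rolling mean episodic reward (model.ep_info_buffer) reaches -200.
callback = StopTrainingOnMeanRewardThreshold(reward_threshold=-200, verbose=1)
model = SAC("MlpPolicy", "Pendulum-v0", verbose=1)
model.learn(total_timesteps=50_000, callback=callback)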
45 changes: 41 additions & 4 deletions utils/exp_manager.py
@@ -11,6 +11,7 @@
import numpy as np
import optuna
import yaml
import zmq
from optuna.integration.skopt import SkoptSampler
from optuna.pruners import BasePruner, MedianPruner, SuccessiveHalvingPruner
from optuna.samplers import BaseSampler, RandomSampler, TPESampler
@@ -29,6 +30,7 @@
DummyVecEnv,
SubprocVecEnv,
VecEnv,
VecEnvWrapper,
VecFrameStack,
VecNormalize,
VecTransposeImage,
@@ -99,6 +101,7 @@ def __init__(
self.env_wrapper = None
self.frame_stack = None
self.seed = seed
self.vec_env_wrapper = None
self.optimization_log_path = optimization_log_path

self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
@@ -160,7 +163,7 @@ def setup_experiment(self) -> Optional[BaseAlgorithm]:
:return: the initialized RL model
"""
hyperparams, saved_hyperparams = self.read_hyperparameters()
hyperparams, self.env_wrapper, self.callbacks = self._preprocess_hyperparams(hyperparams)
hyperparams, self.env_wrapper, self.callbacks, self.vec_env_wrapper = self._preprocess_hyperparams(hyperparams)

self.create_log_folder()
self.create_callbacks()
@@ -200,12 +203,18 @@ def learn(self, model: BaseAlgorithm) -> None:

try:
model.learn(self.n_timesteps, **kwargs)
except KeyboardInterrupt:
except (KeyboardInterrupt, zmq.error.ZMQError):
# this allows the model to be saved when training is interrupted
pass
finally:
# Release resources
try:
# Hack for zmq on Windows to allow early termination
env_tmp = model.env
while isinstance(env_tmp, VecEnvWrapper):
env_tmp = env_tmp.venv
env_tmp.waiting = False

model.env.close()
except EOFError:
pass
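
The unwrap-and-clear-waiting workaround above (repeated twice more in objective() further down) walks through any VecEnvWrapper layers to the underlying VecEnv and resets its waiting flag, so that close() does not block on a pending remote receive. A hypothetical helper that factors it out, not part of this diff:

from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper


def clear_waiting_flag(env: VecEnv) -> None:
    # Walk down through VecEnvWrapper layers to the base VecEnv and clear its
    # `waiting` flag (the zmq-on-Windows early-termination hack used in this PR).
    env_tmp = env
    while isinstance(env_tmp, VecEnvWrapper):
        env_tmp = env_tmp.venv
    env_tmp.waiting = False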
@@ -310,7 +319,7 @@ def _preprocess_normalization(self, hyperparams: Dict[str, Any]) -> Dict[str, An

def _preprocess_hyperparams(
self, hyperparams: Dict[str, Any]
) -> Tuple[Dict[str, Any], Optional[Callable], List[BaseCallback]]:
) -> Tuple[Dict[str, Any], Optional[Callable], List[BaseCallback], Optional[Callable]]:
self.n_envs = hyperparams.get("n_envs", 1)

if self.verbose > 0:
@@ -354,12 +363,16 @@ def _preprocess_hyperparams(
if "env_wrapper" in hyperparams.keys():
del hyperparams["env_wrapper"]

vec_env_wrapper = get_wrapper_class(hyperparams, "vec_env_wrapper")
if "vec_env_wrapper" in hyperparams.keys():
del hyperparams["vec_env_wrapper"]

callbacks = get_callback_list(hyperparams)
if "callback" in hyperparams.keys():
self.specified_callbacks = hyperparams["callback"]
del hyperparams["callback"]

return hyperparams, env_wrapper, callbacks
return hyperparams, env_wrapper, callbacks, vec_env_wrapper

def _preprocess_action_noise(
self, hyperparams: Dict[str, Any], saved_hyperparams: Dict[str, Any], env: VecEnv
@@ -517,6 +530,9 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
monitor_kwargs=monitor_kwargs,
)

if self.vec_env_wrapper is not None:
env = self.vec_env_wrapper(env)

# Wrap the env into a VecNormalize wrapper if needed
# and load saved statistics when present
env = self._maybe_normalize(env, eval_env)
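
The new vec_env_wrapper hyperparameter is applied to the whole vectorized env right after creation (see create_envs above), in contrast to env_wrapper, which wraps each individual env. Any class implementing SB3's VecEnvWrapper interface should work; below is a minimal pass-through sketch (the actual utils.wrappers.VecForceResetWrapper is defined elsewhere in the repo and not shown in this diff):

from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper


class VecPassThroughWrapper(VecEnvWrapper):
    """Hypothetical minimal VecEnvWrapper: forwards reset() and step_wait() unchanged."""

    def __init__(self, venv: VecEnv):
        super().__init__(venv)

    def reset(self):
        return self.venv.reset()

    def step_wait(self):
        return self.venv.step_wait()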
@@ -653,9 +669,30 @@ def objective(self, trial: optuna.Trial) -> float:
try:
model.learn(self.n_timesteps, callback=callbacks)
# Free memory
env_tmp = model.env
while isinstance(env_tmp, VecEnvWrapper):
env_tmp = env_tmp.venv
env_tmp.waiting = False

env_tmp = eval_env
while isinstance(env_tmp, VecEnvWrapper):
env_tmp = env_tmp.venv
env_tmp.waiting = False

model.env.close()
eval_env.close()
except (AssertionError, ValueError) as e:
# Hack for zmq on Windows to allow early termination
env_tmp = model.env
while isinstance(env_tmp, VecEnvWrapper):
env_tmp = env_tmp.venv
env_tmp.waiting = False

env_tmp = eval_env
while isinstance(env_tmp, VecEnvWrapper):
env_tmp = env_tmp.venv
env_tmp.waiting = False

# Sometimes, random hyperparams can generate NaN
# Free memory
model.env.close()