Noisy Cross Entropy Method (CEM) #219

Draft: wants to merge 12 commits into master
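For context, the noisy cross-entropy method keeps a diagonal Gaussian over policy parameters, samples a population each iteration, refits the mean and standard deviation to the best `n_top` candidates, and adds a slowly decaying amount of extra noise so the search distribution does not collapse too early. A self-contained sketch on a toy objective, reusing the hyperparameter names from this PR (`pop_size`, `n_top`, `extra_noise_std`, `noise_multiplier`); this illustrates the method, not the sb3-contrib implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
target = np.array([1.0, -2.0, 3.0])


def score(theta: np.ndarray) -> float:
    # Toy objective: higher is better, maximum at theta == target
    return -float(np.sum((theta - target) ** 2))


dim = 3
pop_size, n_top = 32, 6      # population size and number of elites, as in cem.yml
extra_noise_std = 0.1        # extra noise added to the fitted std ("noisy" CEM)
noise_multiplier = 0.999     # decay factor for the extra noise
mean, std = np.zeros(dim), np.ones(dim)

for _ in range(200):
    # Sample a population of candidate parameter vectors
    population = rng.normal(mean, std, size=(pop_size, dim))
    returns = np.array([score(theta) for theta in population])
    # Refit the search distribution to the n_top best candidates
    elites = population[np.argsort(returns)[-n_top:]]
    mean = elites.mean(axis=0)
    std = elites.std(axis=0) + extra_noise_std
    extra_noise_std *= noise_multiplier

print(np.round(mean, 2))  # close to [ 1. -2.  3.]
```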
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@
### Bug fixes
- Policies saved during optimization with distributed Optuna load on new systems (@jkterry)
- Fixed script for recording video that was not up to date with the enjoy script
- Added CEM support

### Documentation

11 changes: 6 additions & 5 deletions hyperparams/ars.yml
@@ -18,12 +18,14 @@ Pendulum-v1: &pendulum-params
policy_kwargs: "dict(net_arch=[16])"
zero_policy: False

# TO BE Tuned
# Almost Tuned
LunarLander-v2:
<<: *pendulum-params
n_delta: 6
n_top: 1
n_timesteps: !!float 2e6
n_timesteps: !!float 4e6
n_delta: 64
n_top: 4
policy_kwargs: "dict(net_arch=[])"
# delta_std: 0.08

# Tuned
LunarLanderContinuous-v2:
@@ -215,4 +217,3 @@ A1Jumping-v0:
# alive_bonus_offset: -1
normalize: "dict(norm_obs=True, norm_reward=False)"
# policy_kwargs: "dict(net_arch=[16])"
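For reference, the retuned LunarLander-v2 values above can be exercised directly against the released sb3-contrib ARS. A quick sketch, not the zoo's training pipeline (which also applies normalization and the full 4e6-step budget); LunarLander requires gym's Box2D extras:

```python
from sb3_contrib import ARS

# LunarLander-v2 values from the entry above: larger population, more elites, linear head
model = ARS(
    "MlpPolicy",
    "LunarLander-v2",
    n_delta=64,
    n_top=4,
    policy_kwargs=dict(net_arch=[]),
    verbose=1,
)
model.learn(total_timesteps=100_000)  # shortened; the file uses 4e6
```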

190 changes: 190 additions & 0 deletions hyperparams/cem.yml
@@ -0,0 +1,190 @@
# Tuned
CartPole-v1:
n_envs: 1
n_timesteps: !!float 1e5
policy: 'LinearPolicy'
pop_size: 4
n_top: 2

# Tuned
Pendulum-v1: &pendulum-params
n_envs: 1
n_timesteps: !!float 2e6
policy: 'MlpPolicy'
normalize: "dict(norm_obs=True, norm_reward=False)"
pop_size: 8
n_top: 4
policy_kwargs: "dict(net_arch=[16])"

# Tuned
LunarLander-v2:
<<: *pendulum-params
n_timesteps: !!float 1e6
pop_size: 32
n_top: 6
policy_kwargs: "dict(net_arch=[])"

# Tuned
LunarLanderContinuous-v2:
<<: *pendulum-params
n_timesteps: !!float 2e6
pop_size: 16
n_top: 4

# Tuned
Acrobot-v1:
<<: *pendulum-params
n_timesteps: !!float 5e5

# Tuned
MountainCar-v0:
<<: *pendulum-params
pop_size: 16
n_timesteps: !!float 5e5

# Tuned
MountainCarContinuous-v0:
<<: *pendulum-params
n_timesteps: !!float 5e5

# === Pybullet Envs ===
# To be tuned
HalfCheetahBulletEnv-v0: &pybullet-defaults
<<: *pendulum-params
n_timesteps: !!float 1e6
pop_size: 64
n_top: 6
extra_noise_std: 0.1

# To be tuned
AntBulletEnv-v0:
n_envs: 1
policy: 'MlpPolicy'
n_timesteps: !!float 7.5e7
learning_rate: !!float 0.02
delta_std: !!float 0.03
n_delta: 32
n_top: 32
alive_bonus_offset: 0
normalize: "dict(norm_obs=True, norm_reward=False)"
policy_kwargs: "dict(net_arch=[128, 64])"
zero_policy: False


Walker2DBulletEnv-v0:
policy: 'MlpPolicy'
n_timesteps: !!float 7.5e7
learning_rate: !!float 0.03
delta_std: !!float 0.025
n_delta: 40
n_top: 30
alive_bonus_offset: -1
normalize: "dict(norm_obs=True, norm_reward=False)"
policy_kwargs: "dict(net_arch=[64, 64])"
zero_policy: False

# Tuned
HopperBulletEnv-v0:
<<: *pendulum-params
n_timesteps: !!float 1e6
pop_size: 64
n_top: 6
extra_noise_std: 0.1
alive_bonus_offset: -1

ReacherBulletEnv-v0:
<<: *pybullet-defaults
n_timesteps: !!float 1e6

# === Mujoco Envs ===
# Tuned
Swimmer-v3:
<<: *pendulum-params
n_timesteps: !!float 1e6
n_top: 2

Hopper-v3:
<<: *pendulum-params
n_timesteps: !!float 1e6
pop_size: 64
n_top: 6
extra_noise_std: 0.1
alive_bonus_offset: -1


HalfCheetah-v3:
<<: *pendulum-params
n_timesteps: !!float 1e6
pop_size: 50
n_top: 6
extra_noise_std: 0.1

Walker2d-v3:
n_envs: 1
policy: 'LinearPolicy'
n_timesteps: !!float 7.5e7
pop_size: 40
n_top: 30
alive_bonus_offset: -1
normalize: "dict(norm_obs=True, norm_reward=False)"

# TO BE TUNED
Ant-v3:
<<: *pendulum-params
  # For hyperparameter optimization, so that alive_bonus_offset
  # is taken into account (reward_offset):
# env_wrapper:
# - utils.wrappers.DoneOnSuccessWrapper:
# reward_offset: -1.0
# - stable_baselines3.common.monitor.Monitor
n_timesteps: !!float 2e6
pop_size: 64
n_top: 6
# extra_noise_std: 0.1
# noise_multiplier: 0.999
alive_bonus_offset: -1
policy_kwargs: "dict(net_arch=[])"


Humanoid-v3:
n_envs: 1
policy: 'LinearPolicy'
n_timesteps: !!float 2.5e8
pop_size: 256
n_top: 256
alive_bonus_offset: -5
normalize: "dict(norm_obs=True, norm_reward=False)"

BipedalWalker-v3:
n_envs: 1
policy: 'MlpPolicy'
n_timesteps: !!float 1e8
pop_size: 64
n_top: 32
alive_bonus_offset: -0.1
normalize: "dict(norm_obs=True, norm_reward=False)"
policy_kwargs: "dict(net_arch=[16])"

# To be tuned
BipedalWalkerHardcore-v3:
n_envs: 1
policy: 'MlpPolicy'
n_timesteps: !!float 5e8
pop_size: 64
n_top: 32
alive_bonus_offset: -0.1
normalize: "dict(norm_obs=True, norm_reward=False)"
policy_kwargs: "dict(net_arch=[16])"

A1Walking-v0:
<<: *pendulum-params
n_timesteps: !!float 2e6

A1Jumping-v0:
policy: 'LinearPolicy'
n_timesteps: !!float 7.5e7
pop_size: 80
n_top: 30
# alive_bonus_offset: -1
normalize: "dict(norm_obs=True, norm_reward=False)"
# policy_kwargs: "dict(net_arch=[16])"
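Each entry above mixes zoo-level settings (`n_envs`, `n_timesteps`, `normalize`, `env_wrapper`) with hyperparameters forwarded to the algorithm (`pop_size`, `n_top`, `extra_noise_std`, `noise_multiplier`, `policy_kwargs`). As a rough sketch of how one entry maps onto a model, assuming the companion sb3-contrib branch that provides CEM is installed (the ExperimentManager does more than this, e.g. normalization wrappers and callbacks):

```python
import yaml

from sb3_contrib import CEM  # only available with the companion sb3-contrib branch

with open("hyperparams/cem.yml") as f:
    config = yaml.safe_load(f)["CartPole-v1"]

model = CEM(
    config["policy"],             # 'LinearPolicy'
    "CartPole-v1",
    pop_size=config["pop_size"],  # 4
    n_top=config["n_top"],        # 2
    verbose=1,
)
model.learn(total_timesteps=int(config["n_timesteps"]))
```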
1 change: 0 additions & 1 deletion utils/callbacks.py
@@ -47,7 +47,6 @@ def _on_step(self) -> bool:
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
super(TrialEvalCallback, self)._on_step()
self.eval_idx += 1
# report best or report current ?
# report num_timesteps or elapsed time ?
self.trial.report(self.last_mean_reward, self.eval_idx)
# Prune trial if needed
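The callback keeps the standard Optuna pruning pattern: report an intermediate value after each evaluation and let the pruner decide whether to stop the trial. A minimal standalone sketch of that report/prune flow (toy objective, not the zoo's callback):

```python
import optuna


def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    score = 0.0
    for step in range(10):
        score += lr  # stand-in for an evaluation of the trial's model
        trial.report(score, step)  # same call TrialEvalCallback makes with last_mean_reward
        if trial.should_prune():   # the pruner stops unpromising trials early
            raise optuna.TrialPruned()
    return score


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
```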
12 changes: 6 additions & 6 deletions utils/exp_manager.py
@@ -167,7 +167,7 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
self.create_callbacks()

# Create env to have access to action space for action noise
n_envs = 1 if self.algo == "ars" else self.n_envs
n_envs = 1 if self.algo in ["ars", "cem"] else self.n_envs
env = self.create_envs(n_envs, no_log=False)

self._hyperparams = self._preprocess_action_noise(hyperparams, saved_hyperparams, env)
@@ -200,8 +200,8 @@ def learn(self, model: BaseAlgorithm) -> None:
if len(self.callbacks) > 0:
kwargs["callback"] = self.callbacks

# Special case for ARS
if self.algo == "ars" and self.n_envs > 1:
# Special case for ARS and CEM
if self.algo in ["ars", "cem"] and self.n_envs > 1:
kwargs["async_eval"] = AsyncEval(
[lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
)
@@ -625,7 +625,7 @@ def objective(self, trial: optuna.Trial) -> float:
sampled_hyperparams = HYPERPARAMS_SAMPLER[self.algo](trial)
kwargs.update(sampled_hyperparams)

n_envs = 1 if self.algo == "ars" else self.n_envs
n_envs = 1 if self.algo in ["ars", "cem"] else self.n_envs
env = self.create_envs(n_envs, no_log=True)

# By default, do not activate verbose output to keep
@@ -667,8 +667,8 @@ def objective(self, trial: optuna.Trial) -> float:
callbacks.append(eval_callback)

learn_kwargs = {}
# Special case for ARS
if self.algo == "ars" and self.n_envs > 1:
# Special case for ARS and CEM
if self.algo in ["ars", "cem"] and self.n_envs > 1:
learn_kwargs["async_eval"] = AsyncEval(
[lambda: self.create_envs(n_envs=1, no_log=True) for _ in range(self.n_envs)], model.policy
)
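ARS and CEM evaluate whole candidate policies rather than stepping a single vectorized environment, which is why the manager forces n_envs = 1 for the training env and, when more envs are requested, hands the model an AsyncEval helper that spreads candidate evaluations over worker processes. A minimal standalone sketch of that pattern with the released ARS (per this diff, CEM is wired up the same way):

```python
from sb3_contrib import ARS
from sb3_contrib.common.vec_env import AsyncEval
from stable_baselines3.common.env_util import make_vec_env

env_id = "CartPole-v1"
n_envs = 2  # number of parallel evaluation workers

model = ARS("LinearPolicy", env_id, n_delta=2, verbose=1)
# One env factory per worker; each worker evaluates a share of the candidates
async_eval = AsyncEval([lambda: make_vec_env(env_id) for _ in range(n_envs)], model.policy)
model.learn(total_timesteps=10_000, async_eval=async_eval)
```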
41 changes: 40 additions & 1 deletion utils/hyperparams_opt.py
@@ -487,7 +487,7 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
:return:
"""
# n_eval_episodes = trial.suggest_categorical("n_eval_episodes", [1, 2])
n_delta = trial.suggest_categorical("n_delta", [4, 8, 6, 32, 64])
n_delta = trial.suggest_categorical("n_delta", [4, 8, 16, 32, 64])
# learning_rate = trial.suggest_categorical("learning_rate", [0.01, 0.02, 0.025, 0.03])
learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)
delta_std = trial.suggest_categorical("delta_std", [0.01, 0.02, 0.025, 0.03, 0.05, 0.1, 0.2, 0.3])
@@ -519,9 +519,48 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
}


def sample_cem_params(trial: optuna.Trial) -> Dict[str, Any]:
"""
Sampler for CEM hyperparams.
:param trial:
:return:
"""
# n_eval_episodes = trial.suggest_categorical("n_eval_episodes", [1, 2])
# pop_size = trial.suggest_categorical("pop_size", [4, 8, 16, 32, 64, 128])
pop_size = trial.suggest_int("pop_size", 2, 130, step=4)
extra_noise_std = trial.suggest_categorical("extra_noise_std", [0.01, 0.02, 0.025, 0.03, 0.05, 0.1, 0.2, 0.3])
noise_multiplier = trial.suggest_categorical("noise_multiplier", [0.99, 0.995, 0.998, 0.999, 0.9995, 0.9998])
top_frac_size = trial.suggest_categorical("top_frac_size", [0.1, 0.2, 0.3, 0.4, 0.5, 0.8, 0.9, 1.0])
n_top = max(int(top_frac_size * pop_size), 2)
# use_diagonal_covariance = trial.suggest_categorical("use_diagonal_covariance", [False, True])

# net_arch = trial.suggest_categorical("net_arch", ["linear", "tiny", "small"])

    # Note: remove bias to match the original linear policy
# and do not squash output
# Comment out when doing hyperparams search with linear policy only
# net_arch = {
# "linear": [],
# "tiny": [16],
# "small": [32],
# }[net_arch]

# TODO: optimize the alive_bonus_offset too

return {
# "n_eval_episodes": n_eval_episodes,
"pop_size": pop_size,
"extra_noise_std": extra_noise_std,
"noise_multiplier": noise_multiplier,
"n_top": n_top,
# "policy_kwargs": dict(net_arch=net_arch),
}


HYPERPARAMS_SAMPLER = {
"a2c": sample_a2c_params,
"ars": sample_ars_params,
"cem": sample_cem_params,
"ddpg": sample_ddpg_params,
"dqn": sample_dqn_params,
"qrdqn": sample_qrdqn_params,
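Note that in sample_cem_params `n_top` is not sampled directly: it is derived as a fraction of the sampled `pop_size` and clipped so at least two elites are always kept. A quick illustration of that derivation:

```python
# n_top derivation used in sample_cem_params: a fraction of the population, never below 2
for pop_size in (2, 6, 30, 130):
    for top_frac_size in (0.1, 0.5, 1.0):
        n_top = max(int(top_frac_size * pop_size), 2)
        print(f"pop_size={pop_size:3d}  top_frac_size={top_frac_size:.1f}  ->  n_top={n_top}")
```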
3 changes: 2 additions & 1 deletion utils/utils.py
@@ -8,7 +8,7 @@
import stable_baselines3 as sb3 # noqa: F401
import torch as th # noqa: F401
import yaml
from sb3_contrib import ARS, QRDQN, TQC, TRPO
from sb3_contrib import ARS, CEM, QRDQN, TQC, TRPO
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_util import make_vec_env
@@ -27,6 +27,7 @@
"td3": TD3,
# SB3 Contrib,
"ars": ARS,
"cem": CEM,
"qrdqn": QRDQN,
"tqc": TQC,
"trpo": TRPO,
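The new "cem" key is what the command line hooks into: the zoo resolves `--algo cem` through this mapping before building the model. A minimal sketch of that lookup (again assuming the companion sb3-contrib branch that exports CEM):

```python
from sb3_contrib import ARS, CEM  # CEM is not in a released sb3-contrib version yet

ALGOS = {"ars": ARS, "cem": CEM}  # excerpt of the mapping in utils/utils.py

model_class = ALGOS["cem"]
print(model_class.__name__)  # CEM
```

With this entry and hyperparams/cem.yml in place, training would be launched the usual way, e.g. `python train.py --algo cem --env CartPole-v1`.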