From 38fa2957369a8c352e25a8cbc863bc3d6470dc60 Mon Sep 17 00:00:00 2001 From: lenarttreven Date: Wed, 21 Feb 2024 11:19:11 +0100 Subject: [PATCH] update base optimizer; add official brax as a dependency; add sac and ppo for brax envs --- experiments/train_inverted_pendulum/exp.py | 100 ++++ .../train_inverted_pendulum/exp_ppo.py | 88 ++++ mbpo/optimizers/base_optimizer.py | 2 +- .../policy_optimizers/ppo/ppo_brax_env.py | 340 +++++++++++++ .../policy_optimizers/sac/sac_brax_env.py | 481 ++++++++++++++++++ setup.py | 4 +- 6 files changed, 1012 insertions(+), 3 deletions(-) create mode 100644 experiments/train_inverted_pendulum/exp.py create mode 100644 experiments/train_inverted_pendulum/exp_ppo.py create mode 100644 mbpo/optimizers/policy_optimizers/ppo/ppo_brax_env.py create mode 100644 mbpo/optimizers/policy_optimizers/sac/sac_brax_env.py diff --git a/experiments/train_inverted_pendulum/exp.py b/experiments/train_inverted_pendulum/exp.py new file mode 100644 index 0000000..b50fd1a --- /dev/null +++ b/experiments/train_inverted_pendulum/exp.py @@ -0,0 +1,100 @@ +from datetime import datetime + +import flax.linen as nn + +import jax +import jax.random as jr +import matplotlib.pyplot as plt +from brax import envs +from jax.nn import squareplus, swish + +from mbpo.optimizers.policy_optimizers.sac.sac_brax_env import SAC + +env_name = 'inverted_pendulum' # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', +# 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d'] +backend = 'positional' # @param ['generalized', 'positional', 'spring'] + +env = envs.get_environment(env_name=env_name, + backend=backend) +state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0)) + +optimizer = SAC( + environment=env, + num_timesteps=20_000, + episode_length=1000, + action_repeat=1, + num_env_steps_between_updates=10, + num_envs=4, + num_eval_envs=32, + lr_alpha=3e-4, + lr_policy=3e-4, + lr_q=3e-4, + wd_alpha=0., + wd_policy=0., + wd_q=0., + max_grad_norm=1e5, + discounting=0.99, + batch_size=32, + num_evals=20, + normalize_observations=True, + reward_scaling=1., + tau=0.005, + min_replay_size=10 ** 2, + max_replay_size=10 ** 5, + grad_updates_per_step=10 * 32, + deterministic_eval=True, + init_log_alpha=0., + policy_hidden_layer_sizes=(64, 64), + policy_activation=swish, + critic_hidden_layer_sizes=(64, 64), + critic_activation=swish, + wandb_logging=False, + return_best_model=False, +) + +xdata, ydata = [], [] +times = [datetime.now()] + + +def progress(num_steps, metrics): + times.append(datetime.now()) + xdata.append(num_steps) + ydata.append(metrics['eval/episode_reward']) + plt.xlabel('# environment steps') + plt.ylabel('reward per episode') + plt.plot(xdata, ydata) + plt.show() + + +optimizer.run_training(key=jr.PRNGKey(0), progress_fn=progress) + +# wandb.finish() + +# train_fn = { +# 'inverted_pendulum': functools.partial(ppo.train, +# num_timesteps=2_000_000, +# num_evals=20, +# reward_scaling=10, +# episode_length=1000, +# normalize_observations=True, +# action_repeat=1, +# unroll_length=5, +# num_minibatches=32, +# num_updates_per_batch=4, +# discounting=0.97, +# learning_rate=3e-4, +# entropy_cost=1e-2, +# num_envs=2048, +# batch_size=1024, +# seed=1), +# }[env_name] +# + + +# +# +# make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress) +# + +print(f'time to jit: {times[1] - times[0]}') +print(f'time to train: {times[-1] - times[1]}') diff --git a/experiments/train_inverted_pendulum/exp_ppo.py 
b/experiments/train_inverted_pendulum/exp_ppo.py new file mode 100644 index 0000000..0d09dc0 --- /dev/null +++ b/experiments/train_inverted_pendulum/exp_ppo.py @@ -0,0 +1,88 @@ +from datetime import datetime + +import jax +import jax.random as jr +import matplotlib.pyplot as plt +from brax import envs +from jax.nn import swish + +from mbpo.optimizers.policy_optimizers.ppo.ppo_brax_env import PPO + +env_name = 'inverted_pendulum' # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', +# 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d'] +backend = 'positional' # @param ['generalized', 'positional', 'spring'] + +env = envs.get_environment(env_name=env_name, + backend=backend) +state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0)) + +optimizer = PPO( + environment=env, + num_timesteps=2_000_000, + num_evals=20, + reward_scaling=10, + episode_length=1000, + normalize_observations=True, + action_repeat=1, + unroll_length=5, + num_minibatches=32, + num_updates_per_batch=4, + discounting=0.97, + lr=3e-4, + entropy_cost=1e-2, + num_envs=2048, + batch_size=1024, + seed=1, + policy_hidden_layer_sizes=(32,) * 4, + policy_activation=swish, + critic_hidden_layer_sizes=(256,) * 5, + critic_activation=swish, + wandb_logging=False, +) + +xdata, ydata = [], [] +times = [datetime.now()] + + +def progress(num_steps, metrics): + times.append(datetime.now()) + xdata.append(num_steps) + ydata.append(metrics['eval/episode_reward']) + plt.xlabel('# environment steps') + plt.ylabel('reward per episode') + plt.plot(xdata, ydata) + plt.show() + + +optimizer.run_training(key=jr.PRNGKey(0), progress_fn=progress) + +# wandb.finish() + +# train_fn = { +# 'inverted_pendulum': functools.partial(ppo.train, +# num_timesteps=2_000_000, +# num_evals=20, +# reward_scaling=10, +# episode_length=1000, +# normalize_observations=True, +# action_repeat=1, +# unroll_length=5, +# num_minibatches=32, +# num_updates_per_batch=4, +# discounting=0.97, +# learning_rate=3e-4, +# entropy_cost=1e-2, +# num_envs=2048, +# batch_size=1024, +# seed=1), +# }[env_name] +# + + +# +# +# make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress) +# + +print(f'time to jit: {times[1] - times[0]}') +print(f'time to train: {times[-1] - times[1]}') diff --git a/mbpo/optimizers/base_optimizer.py b/mbpo/optimizers/base_optimizer.py index d53e9c0..fd7642d 100644 --- a/mbpo/optimizers/base_optimizer.py +++ b/mbpo/optimizers/base_optimizer.py @@ -12,7 +12,7 @@ class BaseOptimizer(ABC, Generic[RewardParams, DynamicsParams]): - def __init__(self, system: System | None = None, key: jr.PRNGKeyArray = jr.PRNGKey(0)): + def __init__(self, system: System | None = None, key: jr.PRNGKey = jr.PRNGKey(0)): self.system = system self.key = key pass diff --git a/mbpo/optimizers/policy_optimizers/ppo/ppo_brax_env.py b/mbpo/optimizers/policy_optimizers/ppo/ppo_brax_env.py new file mode 100644 index 0000000..b5bde7f --- /dev/null +++ b/mbpo/optimizers/policy_optimizers/ppo/ppo_brax_env.py @@ -0,0 +1,340 @@ +import math +import time +from functools import partial +from typing import Any, Tuple, Sequence, Callable + +import chex +import flax.linen as nn +import jax.numpy as jnp +import jax.random as jr +import jax.tree_util as jtu +import optax +import wandb +from brax import envs +from brax.training import acting +from brax.training import networks +from brax.training import types +from brax.training.acme import running_statistics +from brax.training.acme import specs +from brax.training.types import 
Params +from jax.lax import scan +from jaxtyping import PyTree + +from brax.envs.wrappers.training import wrap as wrap_for_training + +from mbpo.optimizers.policy_optimizers.ppo.losses import PPOLoss, PPONetworkParams +from mbpo.optimizers.policy_optimizers.ppo.ppo_network import PPONetworksModel, make_inference_fn +from mbpo.optimizers.policy_optimizers.sac.utils import gradient_update_fn, metrics_to_float + +Metrics = types.Metrics +Transition = types.Transition +InferenceParams = Tuple[running_statistics.NestedMeanStd, Params] + +ReplayBufferState = Any + + +@chex.dataclass +class TrainingState: + """Contains training state for the learner.""" + optimizer_state: optax.OptState + params: PPONetworkParams + normalizer_params: running_statistics.RunningStatisticsState + env_steps: jnp.ndarray + + def get_policy_params(self): + return self.normalizer_params, self.params.policy + + +class PPO: + def __init__(self, + environment: envs.Env, + num_timesteps: int, + episode_length: int, + action_repeat: int = 1, + num_envs: int = 1, + num_eval_envs: int = 128, + lr: float = 1e-4, + wd: float = 1e-5, + entropy_cost: float = 1e-4, + discounting: float = 0.9, + seed: int = 0, + unroll_length: int = 10, + batch_size: int = 32, + num_minibatches: int = 16, + num_updates_per_batch: int = 2, + num_evals: int = 1, + normalize_observations: bool = False, + reward_scaling: float = 1., + clipping_epsilon: float = .3, + gae_lambda: float = .95, + deterministic_eval: bool = False, + normalize_advantage: bool = True, + policy_hidden_layer_sizes: Sequence[int] = (64, 64, 64), + policy_activation: networks.ActivationFn = nn.swish, + critic_hidden_layer_sizes: Sequence[int] = (64, 64, 64), + critic_activation: networks.ActivationFn = nn.swish, + wandb_logging: bool = False, + ): + self.wandb_logging = wandb_logging + self.episode_length = episode_length + self.action_repeat = action_repeat + self.num_timesteps = num_timesteps + self.deterministic_eval = deterministic_eval + self.normalize_advantage = normalize_advantage + self.gae_lambda = gae_lambda + self.clipping_epsilon = clipping_epsilon + self.reward_scaling = reward_scaling + self.normalize_observations = normalize_observations + self.num_evals = num_evals + self.num_updates_per_batch = num_updates_per_batch + self.num_minibatches = num_minibatches + self.batch_size = batch_size + self.unroll_length = unroll_length + self.discounting = discounting + self.entropy_cost = entropy_cost + self.num_eval_envs = num_eval_envs + self.num_envs = num_envs + # Set this to None unless you want to use parallelism across multiple devices. 
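+ # The axis name is forwarded to running_statistics.update and to the gradient
+ # update below; with None, no cross-device aggregation of gradients or
+ # normalizer statistics takes place.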
+ self._PMAP_AXIS_NAME = None + + assert batch_size * num_minibatches % num_envs == 0 + self.env_step_per_training_step = batch_size * unroll_length * num_minibatches * action_repeat + num_evals_after_init = max(num_evals - 1, 1) + self.num_evals_after_init = num_evals_after_init + num_training_steps_per_epoch = math.ceil( + num_timesteps / (num_evals_after_init * self.env_step_per_training_step)) + self.num_training_steps_per_epoch = num_training_steps_per_epoch + self.key = jr.PRNGKey(seed) + self.env = wrap_for_training(environment, episode_length=episode_length, action_repeat=action_repeat) + + self.x_dim = self.env.observation_size + self.u_dim = self.env.action_size + + def normalize_fn(batch: PyTree, _: PyTree) -> PyTree: + return batch + + if normalize_observations: + normalize_fn = running_statistics.normalize + self.normalize_fn = normalize_fn + + self.ppo_networks_model = PPONetworksModel( + x_dim=self.x_dim, u_dim=self.u_dim, + preprocess_observations_fn=normalize_fn, + policy_hidden_layer_sizes=policy_hidden_layer_sizes, + policy_activation=policy_activation, + value_hidden_layer_sizes=critic_hidden_layer_sizes, + value_activation=critic_activation) + + self.make_policy = make_inference_fn(self.ppo_networks_model.get_ppo_networks()) + self.optimizer = optax.adamw(learning_rate=lr, weight_decay=wd) + + self.ppo_loss = PPOLoss(ppo_network=self.ppo_networks_model.get_ppo_networks(), + entropy_cost=self.entropy_cost, + discounting=self.discounting, + reward_scaling=self.reward_scaling, + gae_lambda=self.gae_lambda, + clipping_epsilon=self.clipping_epsilon, + normalize_advantage=self.normalize_advantage, + ) + + self.ppo_update = gradient_update_fn(self.ppo_loss.loss, self.optimizer, pmap_axis_name=self._PMAP_AXIS_NAME, + has_aux=True) + + def minibatch_step(self, + carry, + data: types.Transition, + normalizer_params: running_statistics.RunningStatisticsState): + # Performs one ppo update + optimizer_state, params, key = carry + key, key_loss = jr.split(key) + (_, metrics), params, optimizer_state = self.ppo_update( + params, + normalizer_params, + data, + key_loss, + optimizer_state=optimizer_state) + + return (optimizer_state, params, key), metrics + + def sgd_step(self, + carry, + unused_t, + data: types.Transition, + normalizer_params: running_statistics.RunningStatisticsState): + optimizer_state, params, key = carry + key, key_perm, key_grad = jr.split(key, 3) + + def convert_data(x: chex.Array): + x = jr.permutation(key_perm, x) + x = jnp.reshape(x, (self.num_minibatches, -1) + x.shape[1:]) + return x + + shuffled_data = jtu.tree_map(convert_data, data) + (optimizer_state, params, _), metrics = scan( + partial(self.minibatch_step, normalizer_params=normalizer_params), + (optimizer_state, params, key_grad), + shuffled_data, + length=self.num_minibatches) + return (optimizer_state, params, key), metrics + + def training_step( + self, + carry: Tuple[TrainingState, envs.State, chex.PRNGKey], + unused_t) -> Tuple[Tuple[TrainingState, envs.State, chex.PRNGKey], Metrics]: + """ + 1. Performs a rollout of length `self.unroll_length` using the current policy. + Since there is self.num_envs environments, this will result in a self.num_envs * self.unroll_length + collected transitions. 
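+ 2. Updates the observation-normalizer statistics on the collected data.
+ 3. Runs self.num_updates_per_batch calls of self.sgd_step; each call shuffles the
+ data into self.num_minibatches minibatches and applies one PPO update per minibatch.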
+ """ + training_state, state, key = carry + key_sgd, key_generate_unroll, new_key = jr.split(key, 3) + + # Deterministic is set by default to False + policy = self.make_policy((training_state.normalizer_params, training_state.params.policy)) + + def f(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jr.split(current_key) + next_state, data = acting.generate_unroll( + self.env, + current_state, + policy, + current_key, + self.unroll_length, + extra_fields=('truncation',)) + return (next_state, next_key), data + + (state, _), data = scan( + f, (state, key_generate_unroll), (), + length=self.batch_size * self.num_minibatches // self.num_envs) + # Have leading dimensions (batch_size * num_minibatches, unroll_length) + data = jtu.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + data = jtu.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), + data) + assert data.discount.shape[1:] == (self.unroll_length,) + + # Update normalization params and normalize observations. + normalizer_params = running_statistics.update( + training_state.normalizer_params, + data.observation, + pmap_axis_name=self._PMAP_AXIS_NAME) + + # Perform self.num_updates_per_batch calls of self.sgd_step + (optimizer_state, params, _), metrics = scan( + partial( + self.sgd_step, data=data, normalizer_params=normalizer_params), + (training_state.optimizer_state, training_state.params, key_sgd), (), + length=self.num_updates_per_batch) + + new_training_state = TrainingState( + optimizer_state=optimizer_state, + params=params, + normalizer_params=normalizer_params, + env_steps=training_state.env_steps + self.env_step_per_training_step) + return (new_training_state, state, new_key), metrics + + def training_epoch( + self, + training_state: TrainingState, + state: envs.State, + key: chex.PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]: + """ + Performs self.num_training_steps_per_epoch calls of self.training_step functions. 
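+ Loss metrics from all training steps are averaged with jnp.mean before being returned.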
+ """ + (training_state, state, _), loss_metrics = scan( + self.training_step, (training_state, state, key), (), + length=self.num_training_steps_per_epoch) + loss_metrics = jtu.tree_map(jnp.mean, loss_metrics) + return training_state, state, loss_metrics + + def training_epoch_with_timing( + self, + training_state: TrainingState, + env_state: envs.State, + key: chex.PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]: + t = time.time() + training_state, env_state, metrics = self.training_epoch(training_state, env_state, key) + + metrics = jtu.tree_map(jnp.mean, metrics) + epoch_training_time = time.time() - t + sps = (self.num_training_steps_per_epoch * self.env_step_per_training_step) / epoch_training_time + metrics = { + 'training/sps': jnp.array(sps), + **{f'training/{name}': jnp.array(value) for name, value in metrics.items()} + } + return training_state, env_state, metrics + + def init_training_state(self, key: chex.PRNGKey) -> TrainingState: + keys = jr.split(key) + init_params = PPONetworkParams( + policy=self.ppo_networks_model.get_policy_network().init(keys[0]), + value=self.ppo_networks_model.get_value_network().init(keys[1])) + training_state = TrainingState( + optimizer_state=self.optimizer.init(init_params), + params=init_params, + normalizer_params=running_statistics.init_state( + specs.Array((self.x_dim,), jnp.float32)), + env_steps=0) + return training_state + + def run_training(self, key: chex.PRNGKey, progress_fn: Callable[[int, Metrics], None] = lambda *args: None): + key, subkey = jr.split(key) + training_state = self.init_training_state(subkey) + + key, rb_key, env_key, eval_key = jr.split(key, 4) + + # Initialize initial env state + env_keys = jr.split(env_key, self.num_envs) + env_state = self.env.reset(env_keys) + + evaluator = acting.Evaluator( + self.env, partial(self.make_policy, deterministic=self.deterministic_eval), + num_eval_envs=self.num_eval_envs, episode_length=self.episode_length, action_repeat=self.action_repeat, + key=eval_key) + + # Run initial eval + all_metrics = [] + metrics = {} + if self.num_evals > 1: + metrics = evaluator.run_evaluation((training_state.normalizer_params, training_state.params.policy), + training_metrics={}) + + if self.wandb_logging: + metrics = metrics_to_float(metrics) + wandb.log(metrics) + all_metrics.append(metrics) + progress_fn(0, metrics) + + # Create and initialize the replay buffer. + key, prefill_key = jr.split(key) + + current_step = 0 + for _ in range(self.num_evals_after_init): + if self.wandb_logging: + wandb.log(metrics_to_float({'step': current_step})) + + # Optimization + key, epoch_key = jr.split(key) + training_state, env_state, training_metrics = self.training_epoch_with_timing(training_state, env_state, + epoch_key) + current_step = training_state.env_steps + + # Eval and logging + # Run evals. + metrics = evaluator.run_evaluation((training_state.normalizer_params, training_state.params.policy), + training_metrics) + if self.wandb_logging: + metrics = metrics_to_float(metrics) + wandb.log(metrics) + all_metrics.append(metrics) + progress_fn(current_step, metrics) + + total_steps = current_step + # assert total_steps >= self.num_timesteps + params = (training_state.normalizer_params, training_state.params.policy) + + # If there were no mistakes the training_state should still be identical on all + # devices. 
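+ # Since _PMAP_AXIS_NAME is None, only a single device copy exists here anyway.
+ # Hedged usage sketch for the returned params, assuming make_policy follows the
+ # brax convention of mapping an observation and a PRNG key to (action, extras),
+ # given some observation `obs`:
+ #   inference_fn = self.make_policy(params, deterministic=True)
+ #   action, _ = inference_fn(obs, jr.PRNGKey(0))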
+ if self.wandb_logging: + wandb.log(metrics_to_float({'total steps': total_steps})) + return params, all_metrics diff --git a/mbpo/optimizers/policy_optimizers/sac/sac_brax_env.py b/mbpo/optimizers/policy_optimizers/sac/sac_brax_env.py new file mode 100644 index 0000000..8557718 --- /dev/null +++ b/mbpo/optimizers/policy_optimizers/sac/sac_brax_env.py @@ -0,0 +1,481 @@ +import functools +import math +import time +from functools import partial +from typing import Any, Callable, Optional, Tuple, Sequence + +import chex +import flax.linen as nn +import jax.numpy as jnp +import jax.random as jr +import jax.tree_util as jtu +import optax +import wandb +from brax import envs +from brax.training import networks +from brax.training import replay_buffers +from brax.training import types +from brax.training.acme import running_statistics +from brax.training.acme import specs +from brax.training.types import Params +from flax import struct +from jax import jit +from jax.lax import scan +from jaxtyping import PyTree + +from brax.envs.wrappers.training import wrap as wrap_for_training + +from mbpo.optimizers.policy_optimizers.sac import acting +from mbpo.optimizers.policy_optimizers.sac.losses import SACLosses +from mbpo.optimizers.policy_optimizers.sac.sac_networks import SACNetworksModel, make_inference_fn +from mbpo.optimizers.policy_optimizers.sac.utils import gradient_update_fn, metrics_to_float + +Metrics = types.Metrics +Transition = types.Transition +InferenceParams = Tuple[running_statistics.NestedMeanStd, Params] + +ReplayBufferState = Any + + +@struct.dataclass +class TrainingState: + """Contains training state for the learner.""" + policy_optimizer_state: optax.OptState + policy_params: Params + q_optimizer_state: optax.OptState + q_params: Params + target_q_params: Params + gradient_steps: jnp.ndarray + env_steps: jnp.ndarray + alpha_optimizer_state: optax.OptState + alpha_params: Params + normalizer_params: running_statistics.RunningStatisticsState + + def get_policy_params(self): + return self.normalizer_params, self.policy_params + + +class SAC: + def __init__(self, + environment: envs.Env, + num_timesteps: int, + episode_length: int, + action_repeat: int = 1, + num_env_steps_between_updates: int = 2, + num_envs: int = 1, + num_eval_envs: int = 128, + lr_alpha: float = 1e-4, + lr_policy: float = 1e-4, + lr_q: float = 1e-4, + wd_alpha: float = 0., + wd_policy: float = 0., + wd_q: float = 0., + max_grad_norm: float = 1e5, + discounting: float = 0.9, + batch_size: int = 256, + num_evals: int = 1, + normalize_observations: bool = False, + reward_scaling: float = 1., + tau: float = 0.005, + min_replay_size: int = 0, + max_replay_size: Optional[int] = None, + grad_updates_per_step: int = 1, + deterministic_eval: bool = True, + init_log_alpha: float = 0., + target_entropy: float | None = None, + policy_hidden_layer_sizes: Sequence[int] = (64, 64, 64), + policy_activation: networks.ActivationFn = nn.swish, + critic_hidden_layer_sizes: Sequence[int] = (64, 64, 64), + critic_activation: networks.ActivationFn = nn.swish, + wandb_logging: bool = False, + return_best_model: bool = False, + eval_environment: envs.Env | None = None, + episode_length_eval: int | None = None, + eval_key_fixed: bool = False, + ): + if min_replay_size >= num_timesteps: + raise ValueError( + 'No training will happen because min_replay_size >= num_timesteps') + + self.eval_key_fixed = eval_key_fixed + self.return_best_model = return_best_model + self.target_entropy = target_entropy + self.init_log_alpha = init_log_alpha + 
self.wandb_logging = wandb_logging + self.min_replay_size = min_replay_size + self.num_timesteps = num_timesteps + self.num_envs = num_envs + self.deterministic_eval = deterministic_eval + self.num_eval_envs = num_eval_envs + self.episode_length = episode_length + self.action_repeat = action_repeat + self.num_evals = num_evals + self.num_env_steps_between_updates = num_env_steps_between_updates + + if max_replay_size is None: + max_replay_size = num_timesteps + self.max_replay_size = max_replay_size + # The number of environment steps executed for every `actor_step()` call. + self.env_steps_per_actor_step = action_repeat * num_envs + # equals to ceil(min_replay_size / num_envs) + self.num_prefill_actor_steps = math.ceil(min_replay_size / num_envs) + num_prefill_env_steps = self.num_prefill_actor_steps * self.env_steps_per_actor_step + assert num_timesteps - num_prefill_env_steps >= 0 + self.num_evals_after_init = max(num_evals - 1, 1) + # The number of run_one_sac_epoch calls per run_sac_training. + # equals to + num_env_steps_in_one_train_step = self.num_evals_after_init * self.env_steps_per_actor_step + num_env_steps_in_one_train_step *= num_env_steps_between_updates + self.num_training_steps_per_epoch = math.ceil( + (num_timesteps - num_prefill_env_steps) / num_env_steps_in_one_train_step) + # num_training_steps_per_epoch is how many action we apply in every epoch + + self.grad_updates_per_step = grad_updates_per_step + + self.tau = tau + self.env = wrap_for_training(environment, + episode_length=episode_length, + action_repeat=action_repeat) + + # Prepare env for evaluation + if episode_length_eval is None: + episode_length_eval = episode_length + + self.episode_length_eval = episode_length_eval + if eval_environment is None: + eval_environment = environment + self.eval_env = wrap_for_training(eval_environment, + episode_length=episode_length_eval, + action_repeat=action_repeat) + + self.x_dim = self.env.observation_size + self.u_dim = self.env.action_size + + def normalize_fn(batch: PyTree, _: PyTree) -> PyTree: + return batch + + if normalize_observations: + normalize_fn = running_statistics.normalize + self.normalize_fn = normalize_fn + + self.sac_networks_model = SACNetworksModel( + x_dim=self.x_dim, u_dim=self.u_dim, + preprocess_observations_fn=normalize_fn, + policy_hidden_layer_sizes=policy_hidden_layer_sizes, + policy_activation=policy_activation, + critic_hidden_layer_sizes=critic_hidden_layer_sizes, + critic_activation=critic_activation) + + self.make_policy = make_inference_fn(self.sac_networks_model.get_sac_networks()) + + self.alpha_optimizer = optax.chain( + optax.clip_by_global_norm(max_norm=max_grad_norm), + optax.adamw(learning_rate=lr_alpha, weight_decay=wd_alpha) + ) + self.policy_optimizer = optax.chain( + optax.clip_by_global_norm(max_norm=max_grad_norm), + optax.adamw(learning_rate=lr_policy, weight_decay=wd_policy) + ) + self.q_optimizer = optax.chain( + optax.clip_by_global_norm(max_norm=max_grad_norm), + optax.adamw(learning_rate=lr_q, weight_decay=wd_q) + ) + + # Set this to None unless you want to use parallelism across multiple devices. + self._PMAP_AXIS_NAME = None + + # Setup replay buffer. 
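+ # A dummy transition fixes the pytree structure and shapes stored in the buffer.
+ # Each sample() call returns batch_size * grad_updates_per_step transitions,
+ # which training_step reshapes into grad_updates_per_step minibatches of size batch_size.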
+ dummy_obs = jnp.zeros((self.x_dim,)) + dummy_action = jnp.zeros((self.u_dim,)) + dummy_transition = Transition( + observation=dummy_obs, + action=dummy_action, + reward=jnp.array(0.), + discount=jnp.array(0.), + next_observation=dummy_obs, + extras={'state_extras': {'truncation': 0.}, 'policy_extras': {}}) + + self.replay_buffer = replay_buffers.UniformSamplingQueue( + max_replay_size=max_replay_size, + dummy_data_sample=dummy_transition, + sample_batch_size=batch_size * grad_updates_per_step) + + # Setup optimization + self.losses = SACLosses(sac_network=self.sac_networks_model.get_sac_networks(), reward_scaling=reward_scaling, + discounting=discounting, u_dim=self.u_dim, target_entropy=self.target_entropy) + + self.alpha_update = gradient_update_fn( + self.losses.alpha_loss, self.alpha_optimizer, pmap_axis_name=self._PMAP_AXIS_NAME) + self.critic_update = gradient_update_fn( + self.losses.critic_loss, self.q_optimizer, pmap_axis_name=self._PMAP_AXIS_NAME) + self.actor_update = gradient_update_fn( + self.losses.actor_loss, self.policy_optimizer, pmap_axis_name=self._PMAP_AXIS_NAME) + + @partial(jit, static_argnums=(0,)) + def sgd_step(self, carry: Tuple[TrainingState, chex.PRNGKey], + transitions: Transition) -> Tuple[Tuple[TrainingState, chex.PRNGKey], Metrics]: + training_state, key = carry + + key, key_alpha, key_critic, key_actor = jr.split(key, 4) + + alpha_loss, alpha_params, alpha_optimizer_state = self.alpha_update( + training_state.alpha_params, + training_state.policy_params, + training_state.normalizer_params, + transitions, + key_alpha, + optimizer_state=training_state.alpha_optimizer_state) + alpha = jnp.exp(training_state.alpha_params) + critic_loss, q_params, q_optimizer_state = self.critic_update( + training_state.q_params, + training_state.policy_params, + training_state.normalizer_params, + training_state.target_q_params, + alpha, + transitions, + key_critic, + optimizer_state=training_state.q_optimizer_state) + actor_loss, policy_params, policy_optimizer_state = self.actor_update( + training_state.policy_params, + training_state.normalizer_params, + training_state.q_params, + alpha, + transitions, + key_actor, + optimizer_state=training_state.policy_optimizer_state) + + new_target_q_params = jtu.tree_map(lambda x, y: x * (1 - self.tau) + y * self.tau, + training_state.target_q_params, q_params) + + metrics = { + 'critic_loss': critic_loss, + 'actor_loss': actor_loss, + 'alpha_loss': alpha_loss, + 'alpha': jnp.exp(alpha_params), + } + + new_training_state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + policy_params=policy_params, + q_optimizer_state=q_optimizer_state, + q_params=q_params, + target_q_params=new_target_q_params, + gradient_steps=training_state.gradient_steps + 1, + env_steps=training_state.env_steps, + alpha_optimizer_state=alpha_optimizer_state, + alpha_params=alpha_params, + normalizer_params=training_state.normalizer_params) + return (new_training_state, key), metrics + + def get_experience(self, normalizer_params: running_statistics.RunningStatisticsState, policy_params: Params, + env_state: envs.State, buffer_state: ReplayBufferState, key: chex.PRNGKey + ) -> Tuple[running_statistics.RunningStatisticsState, envs.State, ReplayBufferState]: + policy = self.make_policy((normalizer_params, policy_params)) + + def f(carry, _): + k, es = carry + k, k_t = jr.split(k) + new_es, new_trans = acting.actor_step(self.env, es, policy, k_t, extra_fields=('truncation',)) + return (k, new_es), new_trans + + (key, env_state), transitions = 
scan(f, (key, env_state), jnp.arange(self.num_env_steps_between_updates)) + + transitions = jtu.tree_map(jnp.concatenate, transitions) + + normalizer_params = running_statistics.update( + normalizer_params, + transitions.observation, + pmap_axis_name=self._PMAP_AXIS_NAME) + + buffer_state = self.replay_buffer.insert(buffer_state, transitions) + return normalizer_params, env_state, buffer_state + + def training_step(self, training_state: TrainingState, env_state: envs.State, buffer_state: ReplayBufferState, + key: chex.PRNGKey) -> Tuple[TrainingState, envs.State, ReplayBufferState, Metrics]: + experience_key, training_key = jr.split(key) + + normalizer_params, env_state, buffer_state = self.get_experience( + training_state.normalizer_params, training_state.policy_params, + env_state, buffer_state, experience_key) + + training_state = training_state.replace( + normalizer_params=normalizer_params, + env_steps=training_state.env_steps + self.env_steps_per_actor_step * self.num_env_steps_between_updates) + + buffer_state, transitions = self.replay_buffer.sample(buffer_state) + # Change the front dimension of transitions so 'update_step' is called + # grad_updates_per_step times by the scan. + transitions = jtu.tree_map( + lambda x: jnp.reshape(x, (self.grad_updates_per_step, -1) + x.shape[1:]), + transitions) + (training_state, _), metrics = scan(self.sgd_step, (training_state, training_key), transitions) + + metrics['buffer_current_size'] = jnp.array(self.replay_buffer.size(buffer_state)) + return training_state, env_state, buffer_state, metrics + + @partial(jit, static_argnums=(0,)) + def prefill_replay_buffer(self, training_state: TrainingState, env_state: envs.State, + buffer_state: ReplayBufferState, key: chex.PRNGKey + ) -> Tuple[TrainingState, envs.State, ReplayBufferState, chex.PRNGKey]: + + def f(carry, _): + tr_step, e_state, b_state, ky = carry + ky, new_key = jr.split(ky) + new_normalizer_params, e_state, b_state = self.get_experience( + tr_step.normalizer_params, tr_step.policy_params, + e_state, b_state, ky) + new_training_state = tr_step.replace( + normalizer_params=new_normalizer_params, + env_steps=tr_step.env_steps + self.env_steps_per_actor_step) + return (new_training_state, e_state, b_state, new_key), () + + return scan(f, (training_state, env_state, buffer_state, key), (), length=self.num_prefill_actor_steps)[0] + + @partial(jit, static_argnums=(0,)) + def training_epoch(self, training_state: TrainingState, env_state: envs.State, buffer_state: ReplayBufferState, + key: chex.PRNGKey) -> Tuple[TrainingState, envs.State, ReplayBufferState, Metrics]: + + def f(carry, _): + ts, es, bs, k = carry + k, new_key = jr.split(k) + ts, es, bs, metr = self.training_step(ts, es, bs, k) + return (ts, es, bs, new_key), metr + + (training_state, env_state, buffer_state, key), metrics = scan( + f, (training_state, env_state, buffer_state, key), (), + length=self.num_training_steps_per_epoch) + metrics = jtu.tree_map(jnp.mean, metrics) + return training_state, env_state, buffer_state, metrics + + def training_epoch_with_timing(self, training_state: TrainingState, env_state: envs.State, + buffer_state: ReplayBufferState, key: chex.PRNGKey + ) -> Tuple[TrainingState, envs.State, ReplayBufferState, Metrics]: + t = time.time() + training_state, env_state, buffer_state, metrics = self.training_epoch(training_state, env_state, buffer_state, + key) + + epoch_training_time = time.time() - t + sps = (self.env_steps_per_actor_step * self.num_training_steps_per_epoch) / epoch_training_time + metrics = 
{'training/sps': jnp.array(sps), + **{f'training/{name}': jnp.array(value) for name, value in metrics.items()}} + return training_state, env_state, buffer_state, metrics + + def init_training_state(self, key: chex.PRNGKey) -> TrainingState: + """Inits the training state and replicates it over devices.""" + key_policy, key_q = jr.split(key) + log_alpha = jnp.asarray(self.init_log_alpha, dtype=jnp.float32) + alpha_optimizer_state = self.alpha_optimizer.init(log_alpha) + + policy_params = self.sac_networks_model.get_policy_network().init(key_policy) + policy_optimizer_state = self.policy_optimizer.init(policy_params) + + q_params = self.sac_networks_model.get_q_network().init(key_q) + q_optimizer_state = self.q_optimizer.init(q_params) + + normalizer_params = running_statistics.init_state( + specs.Array((self.x_dim,), jnp.float32)) + + training_state = TrainingState( + policy_optimizer_state=policy_optimizer_state, + policy_params=policy_params, + q_optimizer_state=q_optimizer_state, + q_params=q_params, + target_q_params=q_params, + gradient_steps=jnp.zeros(()), + env_steps=jnp.zeros(()), + alpha_optimizer_state=alpha_optimizer_state, + alpha_params=log_alpha, + normalizer_params=normalizer_params) + return training_state + + def run_training(self, key: chex.PRNGKey, progress_fn: Callable[[int, Metrics], None] = lambda *args: None): + key, subkey = jr.split(key) + training_state = self.init_training_state(subkey) + + key, rb_key, env_key, eval_key = jr.split(key, 4) + + # Initialize initial env state + env_keys = jr.split(env_key, self.num_envs) + env_state = self.env.reset(env_keys) + + # Initialize replay buffer + buffer_state = self.replay_buffer.init(rb_key) + + evaluator = acting.Evaluator( + self.eval_env, functools.partial(self.make_policy, deterministic=self.deterministic_eval), + num_eval_envs=self.num_eval_envs, episode_length=self.episode_length_eval, action_repeat=self.action_repeat, + key=eval_key) + + # Run initial eval + all_metrics = [] + metrics = {} + highest_eval_episode_reward = jnp.array(-jnp.inf) + best_params = (training_state.normalizer_params, training_state.policy_params) + if self.num_evals > 1: + metrics = evaluator.run_evaluation((training_state.normalizer_params, training_state.policy_params), + training_metrics={}) + if metrics['eval/episode_reward'] > highest_eval_episode_reward: + highest_eval_episode_reward = metrics['eval/episode_reward'] + best_params = (training_state.normalizer_params, training_state.policy_params) + if self.wandb_logging: + metrics = metrics_to_float(metrics) + wandb.log(metrics) + all_metrics.append(metrics) + progress_fn(0, metrics) + + # Create and initialize the replay buffer. + key, prefill_key = jr.split(key) + training_state, env_state, buffer_state, _ = self.prefill_replay_buffer( + training_state, env_state, buffer_state, prefill_key) + + replay_size = self.replay_buffer.size(buffer_state) + if self.wandb_logging: + wandb.log(metrics_to_float({'replay size after prefill': replay_size})) + # assert replay_size >= self.min_replay_size + + current_step = 0 + + if self.eval_key_fixed: + key, eval_key = jr.split(key) + + for _ in range(self.num_evals_after_init): + if self.wandb_logging: + wandb.log(metrics_to_float({'step': current_step})) + + # Optimization + key, epoch_key = jr.split(key) + training_state, env_state, buffer_state, training_metrics = self.training_epoch_with_timing(training_state, + env_state, + buffer_state, + epoch_key) + # current_step = int(training_state.env_steps) + + # Eval and logging + # Run evals. 
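+ # With eval_key_fixed=True the same eval_key is reused for every evaluation, so
+ # episode rewards are comparable across epochs; otherwise a fresh key is drawn
+ # here for each evaluation.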
+ if not self.eval_key_fixed: + key, eval_key = jr.split(key) + metrics = evaluator.run_evaluation((training_state.normalizer_params, training_state.policy_params), + training_metrics, unroll_key=eval_key) + + if metrics['eval/episode_reward'] > highest_eval_episode_reward: + highest_eval_episode_reward = metrics['eval/episode_reward'] + best_params = (training_state.normalizer_params, training_state.policy_params) + + if self.wandb_logging: + metrics = metrics_to_float(metrics) + wandb.log(metrics) + all_metrics.append(metrics) + progress_fn(training_state.env_steps, metrics) + + # total_steps = current_step + # assert total_steps >= self.num_timesteps + last_params = (training_state.normalizer_params, training_state.policy_params) + + if self.return_best_model: + params_to_return = best_params + else: + params_to_return = last_params + + if self.wandb_logging: + wandb.log(metrics_to_float({'total steps': int(training_state.env_steps)})) + return params_to_return, all_metrics diff --git a/setup.py b/setup.py index 747a922..c6d1bbc 100644 --- a/setup.py +++ b/setup.py @@ -16,8 +16,8 @@ 'argparse-dataclass>=0.2.1', 'jaxutils', 'chex', + 'brax', 'distrax @ git+https://github.com/deepmind/distrax.git', - 'brax @ git+https://git@github.com/lenarttreven/brax.git', 'trajax @ git+https://git@github.com/lenarttreven/trajax.git', 'bsm @ git+https://github.com/lasgroup/bayesian_statistical_models.git', ] @@ -25,7 +25,7 @@ extras = {} setup( name='mbpo', - version='0.0.1', + version='0.0.2', license="MIT", packages=find_packages(), python_requires='>=3.9',
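A usage sketch of the new SAC optimizer, mirroring experiments/train_inverted_pendulum/exp.py: it trains on the brax inverted_pendulum task and then rolls the returned policy out for a few steps. It assumes, as in sac_brax_env.py above, that run_training returns (params, metrics) and that make_policy follows the brax inference convention of mapping an observation and a PRNG key to (action, extras); the environment name, seeds, and rollout length are illustrative only.

import jax
import jax.random as jr
from brax import envs

from mbpo.optimizers.policy_optimizers.sac.sac_brax_env import SAC

# Train on the brax inverted_pendulum task (hyperparameters follow exp.py).
env = envs.get_environment(env_name='inverted_pendulum', backend='positional')
optimizer = SAC(environment=env, num_timesteps=20_000, episode_length=1000,
                num_env_steps_between_updates=10, num_envs=4, num_eval_envs=32,
                batch_size=32, num_evals=20, min_replay_size=100,
                max_replay_size=100_000, grad_updates_per_step=320,
                normalize_observations=True, wandb_logging=False)
params, metrics = optimizer.run_training(key=jr.PRNGKey(0))

# Roll the trained policy out; `params` is the (normalizer_params, policy_params)
# tuple returned above, and the inference function is assumed to map
# (observation, key) -> (action, extras) as in brax.
inference_fn = optimizer.make_policy(params, deterministic=True)
step_fn = jax.jit(env.step)
state = jax.jit(env.reset)(rng=jr.PRNGKey(1))
rollout_key = jr.PRNGKey(2)
for _ in range(10):
    rollout_key, action_key = jr.split(rollout_key)
    action, _ = inference_fn(state.obs, action_key)
    state = step_fn(state, action)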