diff --git a/mava/configs/default/rec_iql_sebulba.yaml b/mava/configs/default/rec_iql_sebulba.yaml
new file mode 100644
index 000000000..7d8cf3bef
--- /dev/null
+++ b/mava/configs/default/rec_iql_sebulba.yaml
@@ -0,0 +1,11 @@
+defaults:
+  - logger: logger
+  - arch: sebulba
+  - system: q_learning/rec_iql
+  - network: rnn # [rnn, rcnn]
+  - env: lbf_gym # [rware_gym, lbf_gym, smac_gym]
+  - _self_
+
+hydra:
+  searchpath:
+    - file://mava/configs
diff --git a/mava/configs/system/q_learning/rec_iql.yaml b/mava/configs/system/q_learning/rec_iql.yaml
index 059c55dec..caf733a17 100644
--- a/mava/configs/system/q_learning/rec_iql.yaml
+++ b/mava/configs/system/q_learning/rec_iql.yaml
@@ -15,9 +15,9 @@ rollout_length: 2 # Number of environment steps per vectorised environment.
 epochs: 2 # Number of learn epochs per training data batch.
 
 # sizes
-buffer_size: 5000 # size of the replay buffer. Note: total size is this * num_devices
+buffer_size: 1000 # size of the replay buffer. Note: total size is this * num_devices
 sample_batch_size: 32 # size of training data batch sampled from the buffer
-sample_sequence_length: 20 # 20 transitions are sampled, giving 19 complete data points
+sample_sequence_length: 32 # 32 transitions are sampled, giving 31 complete data points
 
 # learning rates
 q_lr: 3e-4 # the learning rate of the Q network network optimizer
@@ -31,3 +31,7 @@ gamma: 0.99 # discount factor
 
 eps_min: 0.05
 eps_decay: 1e5
+
+# --- Sebulba parameters ---
+replay_ratio: 4 # Average number of times the learner should sample each item from the replay buffer.
+error_tolerance: ~ # Tolerance for how much the learner/actor can sample/insert before being blocked. Must be greater than 2 to avoid deadlocks.
diff --git a/mava/evaluator.py b/mava/evaluator.py
index f157b42d0..557a2da5b 100644
--- a/mava/evaluator.py
+++ b/mava/evaluator.py
@@ -290,8 +290,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
         if config.env.log_win_rate:
             metrics["won_episode"] = timesteps.extras["won_episode"]
 
-        # find the first instance of done to get the metrics at that timestep, we don't
-        # care about subsequent steps because we only the results from the first episode
+        # Find the first instance of done to get the metrics at that timestep.
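+        # np.argmax returns the index of the first True along the time axis,
+        # e.g. if one env's last() over time is [False, True, True], its done_idx is 1.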
done_idx = np.argmax(timesteps.last(), axis=0) metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) del metrics["is_terminal_step"] # uneeded for logging diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index c79b654d2..97bd99590 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -51,11 +51,13 @@ ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer -from mava.utils.config import check_sebulba_config, check_total_timesteps +from mava.utils.config import check_total_timesteps +from mava.utils.config import ppo_sebulba_checks as check_sebulba_config from mava.utils.jax_utils import merge_leading_dims, switch_leading_axes from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head -from mava.utils.sebulba import ParamsSource, Pipeline, RecordTimeTo, ThreadLifetime +from mava.utils.sebulba.pipelines import Pipeline +from mava.utils.sebulba.utils import ParamsSource, RecordTimeTo from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics from mava.wrappers.gym import GymToJumanji @@ -70,7 +72,7 @@ def rollout( apply_fns: Tuple[ActorApply, CriticApply], actor_device: int, seeds: List[int], - thread_lifetime: ThreadLifetime, + stop_event: threading.Event, ) -> None: """Runs rollouts to collect trajectories from the environment. @@ -110,7 +112,7 @@ def act_fn( dones = np.repeat(timestep.last(), num_agents).reshape(num_envs, -1) # Loop till the desired num_updates is reached. - while not thread_lifetime.should_stop(): + while not stop_event.is_set(): # Rollout traj: List[PPOTransition] = [] episode_metrics: List[Dict] = [] @@ -596,20 +598,18 @@ def run_experiment(_config: DictConfig) -> float: inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate # the rollout queue/ the pipe between actor and learner - pipe_lifetime = ThreadLifetime() - pipe = Pipeline(config.arch.rollout_queue_size, learner_sharding, pipe_lifetime) + pipe = Pipeline(config.arch.rollout_queue_size, learner_sharding) pipe.start() params_sources: List[ParamsSource] = [] actor_threads: List[threading.Thread] = [] - actor_lifetime = ThreadLifetime() - params_sources_lifetime = ThreadLifetime() + actors_stop_event = threading.Event() # Create the actor threads print(f"{Fore.BLUE}{Style.BRIGHT}Starting up actor threads...{Style.RESET_ALL}") for actor_device in actor_devices: # Create 1 params source per device - params_source = ParamsSource(inital_params, actor_device, params_sources_lifetime) + params_source = ParamsSource(inital_params, actor_device) params_source.start() params_sources.append(params_source) # Create multiple rollout threads per actor device @@ -630,7 +630,7 @@ def run_experiment(_config: DictConfig) -> float: apply_fns, actor_device, seeds, - actor_lifetime, + actors_stop_event, ), name=f"Actor-{actor_device}-{thread_id}", ) @@ -708,7 +708,7 @@ def run_experiment(_config: DictConfig) -> float: # Stop all the threads. 
logger.stop() - actor_lifetime.stop() + actors_stop_event.set() pipe.clear() # We clear the pipeline before stopping the actor threads to avoid deadlock print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared{Style.RESET_ALL}") print(f"{Fore.RED}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") @@ -716,11 +716,11 @@ def run_experiment(_config: DictConfig) -> float: actor.join() print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}") print(f"{Fore.RED}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}") - pipe_lifetime.stop() + pipe.stop() pipe.join() print(f"{Fore.RED}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}") - params_sources_lifetime.stop() for params_source in params_sources: + params_source.stop() params_source.join() print(f"{Fore.RED}{Style.BRIGHT}All threads stopped...{Style.RESET_ALL}") diff --git a/mava/systems/q_learning/anakin/rec_iql.py b/mava/systems/q_learning/anakin/rec_iql.py index c8a7b4cf4..4016ac488 100644 --- a/mava/systems/q_learning/anakin/rec_iql.py +++ b/mava/systems/q_learning/anakin/rec_iql.py @@ -44,7 +44,7 @@ TrainState, Transition, ) -from mava.types import MarlEnv, Observation +from mava.types import MarlEnv, MavaObservation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps @@ -209,7 +209,7 @@ def make_update_fns( # ---- Acting functions ---- def select_eps_greedy_action( - action_selection_state: ActionSelectionState, obs: Observation, term_or_trunc: Array + action_selection_state: ActionSelectionState, obs: MavaObservation, term_or_trunc: Array ) -> Tuple[ActionSelectionState, Array]: """Select action to take in epsilon-greedy way. Batch and agent dims are included.""" params, hidden_state, t, key = action_selection_state @@ -279,7 +279,7 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]: # ---- Training functions ---- - def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array: + def prep_inputs_to_scannedrnn(obs: MavaObservation, term_or_trunc: chex.Array) -> chex.Array: """Prepares the inputs to the RNN network for either getting q values or the eps-greedy distribution. diff --git a/mava/systems/q_learning/anakin/rec_qmix.py b/mava/systems/q_learning/anakin/rec_qmix.py index af7bfc62b..b55c588fb 100644 --- a/mava/systems/q_learning/anakin/rec_qmix.py +++ b/mava/systems/q_learning/anakin/rec_qmix.py @@ -44,7 +44,7 @@ TrainState, Transition, ) -from mava.types import MarlEnv, Observation +from mava.types import MarlEnv, MavaObservation from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps @@ -241,7 +241,7 @@ def make_update_fns( ) -> Callable[[LearnerState[QMIXParams]], Tuple[LearnerState[QMIXParams], Tuple[Metrics, Metrics]]]: def select_eps_greedy_action( action_selection_state: ActionSelectionState, - obs: Observation, + obs: MavaObservation, term_or_trunc: Array, ) -> Tuple[ActionSelectionState, Array]: """Select action to take in eps-greedy way. 
Batch and agent dims are included.""" @@ -310,7 +310,7 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]: return new_act_state, next_timestep.extras["episode_metrics"] - def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array: + def prep_inputs_to_scannedrnn(obs: MavaObservation, term_or_trunc: chex.Array) -> chex.Array: """Prepares the inputs to the RNN network for either getting q values or the eps-greedy distribution. diff --git a/mava/systems/q_learning/sebulba/rec_iql.py b/mava/systems/q_learning/sebulba/rec_iql.py new file mode 100644 index 000000000..c997ce944 --- /dev/null +++ b/mava/systems/q_learning/sebulba/rec_iql.py @@ -0,0 +1,803 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import queue +import threading +import warnings +from collections import defaultdict +from queue import Queue +from typing import Any, Dict, List, Sequence, Tuple + +import chex +import hydra +import jax +import jax.lax as lax +import jax.numpy as jnp +import numpy as np +import optax +from colorama import Fore, Style +from flax.core.scope import FrozenVariableDict +from flax.linen import FrozenDict +from jax import Array, tree +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding, PartitionSpec, Sharding +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from mava.evaluator import get_sebulba_eval_fn as get_eval_fn +from mava.evaluator import make_rec_eval_act_fn +from mava.networks import RecQNetwork, ScannedRNN +from mava.systems.q_learning.types import Metrics, QNetParams, Transition +from mava.systems.q_learning.types import SebulbaLearnerState as LearnerState +from mava.types import Observation, SebulbaLearnerFn +from mava.utils import make_env as environments +from mava.utils.checkpointing import Checkpointer +from mava.utils.config import base_sebulba_checks as check_sebulba_config +from mava.utils.config import check_total_timesteps +from mava.utils.jax_utils import switch_leading_axes +from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.sebulba.pipelines import OffPolicyPipeline as Pipeline +from mava.utils.sebulba.rate_limiters import BlockingRatioLimiter, RateLimiter, SampleToInsertRatio +from mava.utils.sebulba.utils import ParamsSource, RecordTimeTo +from mava.wrappers.episode_metrics import get_final_step_metrics +from mava.wrappers.gym import GymToJumanji + + +def rollout( + key: chex.PRNGKey, + env: GymToJumanji, + config: DictConfig, + rollout_queue: Pipeline, + params_source: ParamsSource, + q_net: RecQNetwork, + actor_device: int, + seeds: List[int], + stop_event: threading.Event, + actor_id: int, +) -> None: + """Collects trajectories from the environment by running rollouts. + + Args: + key: The PRNG key for stochasticity. + env: The environment to interact with. + config: Configuration settings for rollout and environment. 
+        rollout_queue: Queue for sending collected trajectories to the learner.
+        params_source: Provides the latest network parameters from the learner.
+        q_net: The Q-network.
+        actor_device: Index of the actor device to use for rollout.
+        seeds: Seeds for environment initialization.
+        stop_event: Event used to signal the rollout thread to stop.
+        actor_id: Unique identifier for the actor.
+    """
+    name = threading.current_thread().name
+    print(f"{Fore.BLUE}{Style.BRIGHT}Thread {name} started{Style.RESET_ALL}")
+    num_agents = config.system.num_agents
+    move_to_device = lambda x: jax.device_put(x, device=actor_device)
+
+    @jax.jit
+    def select_eps_greedy_action(
+        params: FrozenDict,
+        hidden_state: jax.Array,
+        obs: Observation,
+        term_or_trunc: Array,
+        key: chex.PRNGKey,
+        t: int,
+    ) -> Tuple[Array, Array, int]:
+        """Selects an action epsilon-greedily.
+
+        Args:
+            params: Network parameters.
+            hidden_state: Current RNN hidden state.
+            obs: Observation from the environment.
+            term_or_trunc: Termination or truncation flag.
+            key: PRNG key for sampling.
+            t: Current timestep (used for epsilon decay).
+
+        Returns:
+            Tuple containing the chosen action, next hidden state, and updated timestep.
+        """
+
+        eps = jnp.maximum(
+            config.system.eps_min, 1 - (t / config.system.eps_decay) * (1 - config.system.eps_min)
+        )
+
+        obs = tree.map(lambda x: x[jnp.newaxis, ...], obs)
+        term_or_trunc = tree.map(lambda x: x[jnp.newaxis, ...], term_or_trunc)
+
+        next_hidden_state, eps_greedy_dist = q_net.apply(
+            params, hidden_state, (obs, term_or_trunc), eps
+        )
+
+        action = eps_greedy_dist.sample(seed=key).squeeze(0)  # (B, A)
+
+        return action, next_hidden_state, t + config.arch.num_envs
+
+    next_timestep = env.reset(seed=seeds)
+    next_dones = next_timestep.last()[..., jnp.newaxis]
+
+    # Initialise hidden states.
+    hstate = ScannedRNN.initialize_carry(
+        (config.arch.num_envs, num_agents), config.network.hidden_state_dim
+    )
+    hstate_tpu = tree.map(move_to_device, hstate)
+    step_count = 0
+
+    # Loop till the desired num_updates is reached.
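+    # With the defaults in rec_iql.yaml (eps_min=0.05, eps_decay=1e5), epsilon anneals linearly
+    # from 1.0 to 0.05 over the first 1e5 environment steps, since `step_count` advances by
+    # num_envs on every vectorised step.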
+ while not stop_event.is_set(): + # Rollout + episode_metrics: List[Dict] = [] + traj: List[Transition] = [] + actor_timings: Dict[str, List[float]] = defaultdict(list) + with RecordTimeTo(actor_timings["rollout_time"]): + for _ in range(config.system.rollout_length): + with RecordTimeTo(actor_timings["get_params_time"]): + params = params_source.get() # Get the latest parameters from the learner + + timestep = next_timestep + obs_tpu = move_to_device(timestep.observation) + + dones = move_to_device(next_dones) + + # Get action and value + with RecordTimeTo(actor_timings["compute_action_time"]): + key, act_key = jax.random.split(key) + action, hstate_tpu, step_count = select_eps_greedy_action( + params, hstate_tpu, obs_tpu, dones, act_key, step_count + ) + cpu_action = jax.device_get(action) + + # Step environment + with RecordTimeTo(actor_timings["env_step_time"]): + next_timestep = env.step(cpu_action) + + # Prepare the transation + terminal = (1 - timestep.discount[..., 0, jnp.newaxis]).astype(bool) + next_dones = next_timestep.last()[..., jnp.newaxis] + + # Append data to storage + traj.append( + Transition( + timestep.observation, + action, + next_timestep.reward, + terminal, + next_dones, + next_timestep.extras["real_next_obs"], + ) + ) + + episode_metrics.append(timestep.extras["episode_metrics"]) + + # send trajectories to learner + with RecordTimeTo(actor_timings["rollout_put_time"]): + try: + rollout_queue.put(traj, (actor_timings, episode_metrics), actor_id) + except queue.Full: + err = "Waited too long to add to the rollout queue, killing the actor thread" + warnings.warn(err, stacklevel=2) + break + + env.close() + + +def get_learner_step_fn( + q_net: RecQNetwork, + update_fn: optax.TransformUpdateFn, + config: DictConfig, +) -> SebulbaLearnerFn[LearnerState, Transition]: + """Get the learner function.""" + + def _update_step( + learner_state: LearnerState, + traj_batch: Transition, + ) -> Tuple[LearnerState, Metrics]: + """Performs a single network update. + + Calculates targets based on the input trajectories and updates the Q-network + parameters accordingly. + + Args: + learner_state: Current learner state. + traj_batch: Batch of transitions for training. + """ + + def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array: + """Prepares inputs for the ScannedRNN network. + + Switches leading axes of observations and termination/truncation flags to match the + (T, B, ...) format expected by the RNN. The replay buffer outputs data in (B, T, ...) + format. + + Args: + obs: Observation data. + term_or_trunc: Termination/truncation flags. + + Returns: + Tuple containing the initial hidden state and the formatted input data. + """ + hidden_state = ScannedRNN.initialize_carry( + (obs.agents_view.shape[0], obs.agents_view.shape[2]), + config.network.hidden_state_dim, + ) + # the rb outputs (B, T, ... ) the RNN takes in (T, B, ...) 
+ obs = switch_leading_axes(obs) # (B, T) -> (T, B) + term_or_trunc = switch_leading_axes(term_or_trunc) # (B, T) -> (T, B) + obs_term_or_trunc = (obs, term_or_trunc) + + return hidden_state, obs_term_or_trunc + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def q_loss_fn( + q_online_params: FrozenVariableDict, + obs: Array, + term_or_trunc: Array, + action: Array, + target: Array, + ) -> Tuple[Array, Metrics]: + # axes switched here to scan over time + hidden_state, obs_term_or_trunc = prep_inputs_to_scannedrnn(obs, term_or_trunc) + + # get online q values of all actions + _, q_online = q_net.apply( + q_online_params, hidden_state, obs_term_or_trunc, method="get_q_values" + ) + q_online = switch_leading_axes(q_online) # (T, B, ...) -> (B, T, ...) + # get the q values of the taken actions and remove extra dim + q_online = jnp.squeeze( + jnp.take_along_axis(q_online, action[..., jnp.newaxis], axis=-1), axis=-1 + ) + q_error = jnp.square(q_online - target) + q_loss = jnp.mean(q_error) # mse + + # pack metrics for logging + loss_info = { + "q_loss": q_loss, + "mean_q": jnp.mean(q_online), + "mean_target": jnp.mean(target), + } + + return q_loss, loss_info + + params, opt_states, t_train, traj_batch = update_state + + # Get data aligned with current/next timestep + data_first = tree.map(lambda x: x[:, :-1, ...], traj_batch) + data_next = tree.map(lambda x: x[:, 1:, ...], traj_batch) + + obs = data_first.obs + term_or_trunc = data_first.term_or_trunc + reward = data_first.reward + action = data_first.action + + # The three following variables all come from the same time step. + # They are stored and accessed in this way because of the `AutoResetWrapper`. + # At the end of an episode `data_first.next_obs` and `data_next.obs` will be + # different, which is why we need to store both. Thus `data_first.next_obs` + # aligns with the `terminal` from `data_next`. + next_obs = data_first.next_obs + next_term_or_trunc = data_next.term_or_trunc + next_terminal = data_next.terminal + + # Scan over each sample + hidden_state, next_obs_term_or_trunc = prep_inputs_to_scannedrnn( + next_obs, next_term_or_trunc + ) + + # eps defaults to 0 + _, next_online_greedy_dist = q_net.apply( + params.online, hidden_state, next_obs_term_or_trunc + ) + + _, next_q_vals_target = q_net.apply( + params.target, hidden_state, next_obs_term_or_trunc, method="get_q_values" + ) + + # Get the greedy action + next_action = next_online_greedy_dist.mode() # (T, B, ...) + + # Double q-value selection + next_q_val = jnp.squeeze( + jnp.take_along_axis(next_q_vals_target, next_action[..., jnp.newaxis], axis=-1), + axis=-1, + ) + + next_q_val = switch_leading_axes(next_q_val) # (T, B, ...) -> (B, T, ...) + + # TD Target + target_q_val = reward + (1.0 - next_terminal) * config.system.gamma * next_q_val + + # Update Q function. + q_grad_fn = jax.grad(q_loss_fn, has_aux=True) + q_grads, q_loss_info = q_grad_fn( + params.online, obs, term_or_trunc, action, target_q_val + ) + + # Mean over the device and batch dimension. 
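+            # ("learner_devices" is the mesh axis that `shard_map` maps this function over in
+            # learner_setup, so gradients and metrics are averaged across all learner devices.)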
+ q_grads, q_loss_info = lax.pmean((q_grads, q_loss_info), axis_name="learner_devices") + q_updates, next_opt_state = update_fn(q_grads, opt_states) + next_online_params = optax.apply_updates(params.online, q_updates) + + if config.system.hard_update: + next_target_params = optax.periodic_update( + next_online_params, params.target, t_train, config.system.update_period + ) + else: + next_target_params = optax.incremental_update( + next_online_params, params.target, config.system.tau + ) + + # Repack params and opt_states. + next_params = QNetParams(next_online_params, next_target_params) + + # Repack. + next_state = (next_params, next_opt_state, t_train + 1, traj_batch) + + return next_state, q_loss_info + + update_state = (*learner_state, traj_batch) + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.epochs + ) + + params, opt_states, train_step, _ = update_state + learner_state = LearnerState(params, opt_states, train_step) + return learner_state, loss_info + + def learner_fn( + learner_state: LearnerState, traj_batch: Transition + ) -> Tuple[LearnerState, Metrics]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized across learner devices. + + Args: + learner_state (NamedTuple): + - params (Params): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. + - step_counter int): Number of learning steps. + traj_batch (Transition): The collected trainig data. + """ + learner_state, loss_info = _update_step(learner_state, traj_batch) + + return learner_state, loss_info + + return learner_fn + + +def learner_thread( + learn_fn: SebulbaLearnerFn[LearnerState, Transition], + learner_state: LearnerState, + config: DictConfig, + eval_queue: Queue, + pipeline: Pipeline, + params_sources: Sequence[ParamsSource], +) -> None: + for _ in range(config.arch.num_evaluation): + # Create the lists to store metrics and timings for this learning iteration. + ep_metrics_list: List[Dict] = [] + train_metrics: List[Dict] = [] + rollout_times_list: List[Dict] = [] + learn_times: Dict[str, List[float]] = defaultdict(list) + + with RecordTimeTo(learn_times["learner_time_per_eval"]): + for _ in range(config.system.num_updates_per_eval): + # Get the trajectory batch from the pipeline + # This is blocking so it will wait until the pipeline has data. + with RecordTimeTo(learn_times["rollout_get_time"]): + traj_batch, (rollout_time, ep_metric) = pipeline.get() + # Update the networks + with RecordTimeTo(learn_times["learning_time"]): + learner_state, train_metric = learn_fn(learner_state, traj_batch) + + train_metrics.append(train_metric) + if ep_metric is not None: + ep_metrics_list.append(ep_metric) + rollout_times_list.append(rollout_time) + + # Update all the params sources so all actors can get the latest params + params = jax.block_until_ready(learner_state.params) + for source in params_sources: + source.update(params.online) + + # Pass all the metrics and params to the main thread (evaluator) for logging and evaluation + + if ep_metrics_list: + # [{metric1 : (num_envs, ...), ...} * n_rollouts] --> + # {metric1 : (n_rollouts, num_envs, ...), ...] 
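+            # e.g. tree.map over [{"episode_return": a}, {"episode_return": b}] stacks the
+            # leaves into {"episode_return": np.asarray([a, b])}.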
+ ep_metrics = tree.map(lambda *x: np.asarray(x), *ep_metrics_list) + + # [{metric1: value1, ...} * n_rollouts] --> + # {metric1: mean(value1_rollout1, value1_rollout2, ...), ...} + rollout_times = tree.map(lambda *x: np.mean(x), *rollout_times_list) + else: + rollout_times = {} + ep_metrics = {} + + train_metrics = tree.map(lambda *x: np.asarray(x), *train_metrics) + + # learn_times : {metric1: (1,) or (num_updates_per_eval,), ...} + time_metrics = rollout_times | learn_times + # time_metrics : {metric1: Array, ...} - > {metric1: mean(Array), ...} + time_metrics = tree.map(np.mean, time_metrics, is_leaf=lambda x: isinstance(x, list)) + + eval_queue.put((ep_metrics, train_metrics, learner_state, time_metrics)) + + +def learner_setup( + key: chex.PRNGKey, config: DictConfig, learner_devices: List +) -> Tuple[ + SebulbaLearnerFn[LearnerState, Transition], + RecQNetwork, + LearnerState, + Sharding, + Transition, +]: + """Initialise learner_fn, network and learner state.""" + + # create temporary environments. + env = environments.make_gym_env(config, 1) + # Get number of agents and actions. + action_space = env.single_action_space + config.system.num_agents = len(action_space) + config.system.num_actions = int(action_space[0].n) + + devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) + mesh = Mesh(devices, axis_names=("learner_devices")) + model_spec = PartitionSpec() + data_spec = PartitionSpec("learner_devices") + learner_sharding = NamedSharding(mesh, model_spec) + + key, q_key = jax.random.split(key, 2) + # Shape legend: + # T: Time (dummy dimension size = 1) + # B: Batch (dummy dimension size = 1) + # A: Agent + # Make dummy inputs to init recurrent Q network -> need shape (T, B, A, ...) + init_agents_view = jnp.array(env.single_observation_space.sample()) + init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) + init_obs = Observation(init_agents_view, init_action_mask) # (A, ...) + # (B, T, A, ...) + init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs) + dones = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1) + init_x = (init_obs_batched, dones) # pack the RNN dummy inputs + # (B, A, ...) + init_hidden_state = ScannedRNN.initialize_carry( + (config.system.sample_batch_size, config.system.num_agents), config.network.hidden_state_dim + ) + + # Making recurrent Q network. + pre_torso = hydra.utils.instantiate(config.network.q_network.pre_torso) + post_torso = hydra.utils.instantiate(config.network.q_network.post_torso) + q_net = RecQNetwork( + pre_torso, + post_torso, + config.system.num_actions, + config.network.hidden_state_dim, + ) + q_params = q_net.init(q_key, init_hidden_state, init_x) # epsilon defaults to 0 + q_target_params = q_net.init(q_key, init_hidden_state, init_x) # ensure parameters are separate + + # Pack Q network params + params = QNetParams(q_params, q_target_params) + + # Making optimiser and state + opt = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(learning_rate=config.system.q_lr, eps=1e-5), + ) + opt_state = opt.init(params.online) + + # Create dummy transition Used to initialiwe the pipeline's Buffer + init_acts = env.single_action_space.sample() # (A,) + init_transition = Transition( + obs=init_obs, # (A, ...) 
+ action=init_acts, + reward=jnp.zeros((config.system.num_agents,), dtype=float), + terminal=jnp.zeros((1,), dtype=bool), # one flag for all agents + term_or_trunc=jnp.zeros((1,), dtype=bool), + next_obs=init_obs, + ) + + learn_state_spec = LearnerState(model_spec, model_spec, model_spec) + learn = get_learner_step_fn(q_net, opt.update, config) + learn = jax.jit( + shard_map( + learn, + mesh=mesh, + in_specs=(learn_state_spec, data_spec), + out_specs=(learn_state_spec, data_spec), + ) + ) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(input_params=params) + # Update the params + params = restored_params + + # Duplicate learner across Learner devices. + params, opt_state = jax.device_put((params, opt_state), learner_sharding) + + # Initial learner state. + init_learner_state = LearnerState(params, opt_state, 0) + + env.close() + return learn, q_net, init_learner_state, learner_sharding, init_transition + + +def run_experiment(_config: DictConfig) -> float: + """Runs experiment.""" + config = copy.deepcopy(_config) + + local_devices = jax.local_devices() + devices = jax.devices() + err = "Local and global devices must be the same, we dont support multihost yet" + assert len(local_devices) == len(devices), err + learner_devices = [devices[d_id] for d_id in config.arch.learner_device_ids] + actor_devices = [local_devices[device_id] for device_id in config.arch.actor_device_ids] + + # JAX and numpy RNGs + key = jax.random.PRNGKey(config.system.seed) + np_rng = np.random.default_rng(config.system.seed) + + # Setup learner. + learn, q_net, learner_state, learner_sharding, init_transition = learner_setup( + key, config, learner_devices + ) + + # Setup evaluator. + # One key per device for evaluation. + eval_act_fn = make_rec_eval_act_fn(q_net.apply, config) + + evaluator, evaluator_envs = get_eval_fn( + environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False + ) + + # Calculate total timesteps. 
+    config = check_total_timesteps(config)
+    check_sebulba_config(config)
+
+    steps_per_rollout = (
+        config.system.rollout_length
+        * config.arch.num_envs
+        * config.system.num_updates_per_eval
+        * len(config.arch.actor_device_ids)
+        * config.arch.n_threads_per_executor
+    )
+
+    # Setup RateLimiter
+    # Replay_ratio = num_gradient_updates / num_env_steps
+    # num_gradient_updates = sample_batch_size * epochs * rollout_length * samples_per_insert
+    num_updates_per_insert = (
+        config.system.epochs * config.system.sample_batch_size * config.system.rollout_length
+    )
+    num_steps_per_insert = (
+        config.system.sample_sequence_length
+        * config.arch.num_envs
+        * len(config.arch.actor_device_ids)
+        * config.arch.n_threads_per_executor
+    )
+    config.system.sample_per_insert = (
+        num_steps_per_insert * config.system.replay_ratio
+    ) / num_updates_per_insert
+
+    min_num_inserts = max(
+        config.system.sample_sequence_length // config.system.rollout_length,
+        config.system.min_buffer_size // config.system.rollout_length,
+        1,
+    )
+
+    rate_limiter: RateLimiter
+    if config.system.error_tolerance:
+        rate_limiter = SampleToInsertRatio(
+            config.system.sample_per_insert, min_num_inserts, config.system.error_tolerance
+        )
+    else:
+        rate_limiter = BlockingRatioLimiter(config.system.sample_per_insert, min_num_inserts)
+
+    # Setup logger
+    logger = MavaLogger(config)
+    print_cfg: Dict = OmegaConf.to_container(config, resolve=True)
+    print_cfg["arch"]["devices"] = jax.devices()
+    pprint(print_cfg)
+
+    # Set up checkpointer
+    save_checkpoint = config.logger.checkpointing.save_model
+    if save_checkpoint:
+        checkpointer = Checkpointer(
+            metadata=config,  # Save all config as metadata in the checkpoint
+            model_name=config.logger.system_name,
+            **config.logger.checkpointing.save_args,  # Checkpoint args
+        )
+
+    # Executor setup and launch.
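+    # One ParamsSource is created per actor device and shared by its n_threads_per_executor
+    # rollout threads; each rollout thread gets a unique actor_id, which maps to its own
+    # replay buffer inside the OffPolicyPipeline.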
+ inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate + + # Setup Pipeline + pipe = Pipeline(config, learner_sharding, key, rate_limiter, init_transition) + pipe.start() + + params_sources: List[ParamsSource] = [] + actor_threads: List[threading.Thread] = [] + actors_stop_event = threading.Event() + # Create the actor threads + print(f"{Fore.BLUE}{Style.BRIGHT}Starting up actor threads...{Style.RESET_ALL}") + for device_idx, actor_device in enumerate(actor_devices): + # Create 1 params source per device + params_source = ParamsSource(inital_params.online, actor_device) + params_source.start() + params_sources.append(params_source) + # Create multiple rollout threads per actor device + for thread_id in range(config.arch.n_threads_per_executor): + key, act_key = jax.random.split(key) + seeds = np_rng.integers(np.iinfo(np.int32).max, size=config.arch.num_envs).tolist() + act_key = jax.device_put(key, actor_device) + actor_id = device_idx * config.arch.n_threads_per_executor + thread_id + actor = threading.Thread( + target=rollout, + args=( + act_key, + # We have to do this here, creating envs inside actor threads causes deadlocks + environments.make_gym_env(config, config.arch.num_envs), + config, + pipe, + params_source, + q_net, + actor_device, + seeds, + actors_stop_event, + actor_id, + ), + name=f"Actor-{actor_device}-{thread_id}", + ) + actor_threads.append(actor) + + # Start the actors simultaneously + for actor in actor_threads: + actor.start() + + eval_queue: Queue = Queue() + threading.Thread( + target=learner_thread, + name="Learner", + args=(learn, learner_state, config, eval_queue, pipe, params_sources), + ).start() + + max_episode_return = -np.inf + best_params_cpu = jax.device_get(inital_params.online) + + eval_hs = ScannedRNN.initialize_carry( + (min(config.arch.num_eval_episodes, config.arch.num_envs), config.system.num_agents), + config.network.hidden_state_dim, + ) + + # This is the main loop, all it does is evaluation and logging. + # Acting and learning is happening in their own threads. + # This loop waits for the learner to finish an update before evaluation and logging. + for eval_step in range(config.arch.num_evaluation): + # Sync with the learner - the get() is blocking so it keeps eval and learning in step. 
+ episode_metrics, train_metrics, learner_state, time_metrics = eval_queue.get() + + t = int(steps_per_rollout * (eval_step + 1)) + time_metrics |= {"timestep": t, "pipline_size": pipe.qsize()} + logger.log(time_metrics, t, eval_step, LogEvent.MISC) + + if episode_metrics: + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / time_metrics["rollout_time"] + if ep_completed: + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + + train_metrics["learner_step"] = (eval_step + 1) * config.system.num_updates_per_eval + train_metrics["learner_steps_per_second"] = ( + config.system.num_updates_per_eval + ) / time_metrics["learner_time_per_eval"] + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) + + learner_state_cpu = jax.device_get(learner_state) + key, eval_key = jax.random.split(key, 2) + eval_metrics = evaluator( + learner_state_cpu.params.online, eval_key, {"hidden_state": eval_hs} + ) + logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) + + episode_return = np.mean(eval_metrics["episode_return"]) + + if save_checkpoint: # Save a checkpoint of the learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=learner_state_cpu, + episode_return=episode_return, + ) + + if config.arch.absolute_metric and (max_episode_return <= episode_return): + best_params_cpu = copy.deepcopy(learner_state_cpu.params.online) + max_episode_return = float(episode_return) + + evaluator_envs.close() + eval_performance = float(np.mean(eval_metrics[config.env.eval_metric])) + + # Measure absolute metric. + if config.arch.absolute_metric: + print(f"{Fore.BLUE}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}") + abs_metric_evaluator, abs_metric_evaluator_envs = get_eval_fn( + environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=True + ) + key, eval_key = jax.random.split(key, 2) + eval_hs = ScannedRNN.initialize_carry( + ( + min(config.arch.num_absolute_metric_eval_episodes, config.arch.num_envs), + config.system.num_agents, + ), + config.network.hidden_state_dim, + ) + eval_metrics = abs_metric_evaluator(best_params_cpu, eval_key, {"hidden_state": eval_hs}) + + t = int(steps_per_rollout * (eval_step + 1)) + logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) + abs_metric_evaluator_envs.close() + + # Stop all the threads. + logger.stop() + actors_stop_event.set() + pipe.clear() # We clear the pipeline before stopping the actor threads to avoid deadlock + print(f"{Fore.RED}{Style.BRIGHT}Pipe cleared{Style.RESET_ALL}") + print(f"{Fore.RED}{Style.BRIGHT}Stopping actor threads...{Style.RESET_ALL}") + for actor in actor_threads: + actor.join() + print(f"{Fore.RED}{Style.BRIGHT}{actor.name} stopped{Style.RESET_ALL}") + print(f"{Fore.RED}{Style.BRIGHT}Stopping pipeline...{Style.RESET_ALL}") + pipe.stop() + pipe.join() + print(f"{Fore.RED}{Style.BRIGHT}Stopping params sources...{Style.RESET_ALL}") + for params_source in params_sources: + params_source.stop() + params_source.join() + print(f"{Fore.RED}{Style.BRIGHT}All threads stopped...{Style.RESET_ALL}") + + return eval_performance + + +@hydra.main( + config_path="../../../configs/default", + config_name="rec_iql_sebulba.yaml", + version_base="1.2", +) +def hydra_entry_point(cfg: DictConfig) -> float: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + cfg.logger.system_name = "rec_iql" + + # Run experiment. 
+ final_return = run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}IDQN experiment completed{Style.RESET_ALL}") + + return float(final_return) + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/mava/systems/q_learning/types.py b/mava/systems/q_learning/types.py index 8e0cd8125..6c5fc3f5b 100644 --- a/mava/systems/q_learning/types.py +++ b/mava/systems/q_learning/types.py @@ -21,7 +21,7 @@ from jumanji.env import State from typing_extensions import NamedTuple, TypeAlias -from mava.types import Observation +from mava.types import MavaObservation Metrics = Dict[str, Array] @@ -29,14 +29,14 @@ class Transition(NamedTuple): """Transition for recurrent Q-learning.""" - obs: Observation + obs: MavaObservation action: Array reward: Array terminal: Array term_or_trunc: Array # Even though we use a trajectory buffer we need to store both obs and next_obs. # This is because of how the `AutoResetWrapper` returns obs at the end of an episode. - next_obs: Observation + next_obs: MavaObservation BufferState: TypeAlias = TrajectoryBufferState[Transition] @@ -64,7 +64,7 @@ class ActionState(NamedTuple): action_selection_state: ActionSelectionState env_state: State buffer_state: BufferState - obs: Observation + obs: MavaObservation terminal: Array term_or_trunc: Array @@ -83,7 +83,7 @@ class LearnerState(NamedTuple, Generic[QLearningParams]): """State of the learner in an interaction-training loop.""" # Interaction vars - obs: Observation + obs: MavaObservation terminal: Array term_or_trunc: Array hidden_state: Array @@ -108,3 +108,11 @@ class TrainState(NamedTuple, Generic[QLearningParams]): opt_state: optax.OptState train_steps: Array key: PRNGKey + + +class SebulbaLearnerState(NamedTuple): + """State of the learner for the Sebulba architecture.""" + + params: QNetParams + opt_states: optax.OptState + step_counter: int diff --git a/mava/utils/config.py b/mava/utils/config.py index 23484311b..31d01d48b 100644 --- a/mava/utils/config.py +++ b/mava/utils/config.py @@ -18,7 +18,7 @@ from omegaconf import DictConfig -def check_sebulba_config(config: DictConfig) -> None: +def base_sebulba_checks(config: DictConfig) -> None: """Checks that the given config does not have conflicting values.""" assert ( config.system.num_updates > config.arch.num_evaluation @@ -30,6 +30,10 @@ def check_sebulba_config(config: DictConfig) -> None: + "The output of each actor is equally split across the learners." ) + +def ppo_sebulba_checks(config: DictConfig) -> None: + base_sebulba_checks(config) + num_eval_samples = ( int(config.arch.num_envs / len(config.arch.learner_device_ids)) * config.system.rollout_length diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py deleted file mode 100644 index 1fe441e2a..000000000 --- a/mava/utils/sebulba.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import queue -import threading -import time -from typing import Any, Dict, List, Sequence, Tuple, Union - -import jax -import jax.numpy as jnp -from colorama import Fore, Style -from jax import tree -from jax.sharding import Sharding -from jumanji.types import TimeStep - -# todo: remove the ppo dependencies when we make sebulba for other systems -from mava.systems.ppo.types import Params, PPOTransition -from mava.types import Metrics - -QUEUE_PUT_TIMEOUT = 100 - - -class ThreadLifetime: - """Simple class for a mutable boolean that can be used to signal a thread to stop.""" - - def __init__(self) -> None: - self._stop = False - - def should_stop(self) -> bool: - return self._stop - - def stop(self) -> None: - self._stop = True - - -@jax.jit -def _stack_trajectory(trajectory: List[PPOTransition]) -> PPOTransition: - """Stack a list of parallel_env transitions into a single - transition of shape [rollout_len, num_envs, ...].""" - return tree.map(lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), *trajectory) # type: ignore - - -# Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py -class Pipeline(threading.Thread): - """ - The `Pipeline` shards trajectories into learner devices, - ensuring trajectories are consumed in the right order to avoid being off-policy - and limit the max number of samples in device memory at one time to avoid OOM issues. - """ - - def __init__(self, max_size: int, learner_sharding: Sharding, lifetime: ThreadLifetime): - """ - Initializes the pipeline with a maximum size and the devices to shard trajectories across. - - Args: - max_size: The maximum number of trajectories to keep in the pipeline. - learner_sharding: The sharding used for the learner's update function. - lifetime: A `ThreadLifetime` which is used to stop this thread. - """ - super().__init__(name="Pipeline") - - self.sharding = learner_sharding - self.tickets_queue: queue.Queue = queue.Queue() - self._queue: queue.Queue = queue.Queue(maxsize=max_size) - self.lifetime = lifetime - - def run(self) -> None: - """This function ensures that trajectories on the queue are consumed in the right order. The - start_condition and end_condition are used to ensure that only 1 thread is processing an - item from the queue at one time, ensuring predictable memory usage. - """ - while not self.lifetime.should_stop(): - try: - start_condition, end_condition = self.tickets_queue.get(timeout=1) - with end_condition: - with start_condition: - start_condition.notify() - end_condition.wait() - except queue.Empty: - continue - - def put( - self, traj: Sequence[PPOTransition], timestep: TimeStep, metrics: Tuple[Dict, List[Dict]] - ) -> None: - """Put a trajectory on the queue to be consumed by the learner.""" - start_condition, end_condition = (threading.Condition(), threading.Condition()) - with start_condition: - self.tickets_queue.put((start_condition, end_condition)) - start_condition.wait() # wait to be allowed to start - - # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] - traj = _stack_trajectory(traj) - traj, timestep = jax.device_put((traj, timestep), device=self.sharding) - - time_dict, episode_metrics = metrics - # [{'metric1' : value1, ...} * rollout_len -> {'metric1' : [value1, value2, ...], ...} - episode_metrics = _stack_trajectory(episode_metrics) - - # We block on the `put` to ensure that actors wait for the learners to catch up. 
- # This ensures two things: - # The actors don't get too far ahead of the learners, which could lead to off-policy data. - # The actors don't "waste" samples by generating samples that the learners can't consume. - # However, we put a timeout of 100 seconds to avoid deadlocks in case the learner - # is not consuming the data. This is a safety measure and should not normally occur. - # We use a try-finally so the lock is released even if an exception is raised. - try: - self._queue.put( - (traj, timestep, time_dict, episode_metrics), - block=True, - timeout=QUEUE_PUT_TIMEOUT, - ) - except queue.Full: - print( - f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, " - f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}" - ) - finally: - with end_condition: - end_condition.notify() # notify that we have finished - - def qsize(self) -> int: - """Returns the number of trajectories in the pipeline.""" - return self._queue.qsize() - - def get( - self, block: bool = True, timeout: Union[float, None] = None - ) -> Tuple[PPOTransition, TimeStep, Dict, Metrics]: - """Get a trajectory from the pipeline.""" - return self._queue.get(block, timeout) # type: ignore - - def clear(self) -> None: - """Clear the pipeline.""" - while not self._queue.empty(): - try: - self._queue.get(block=False) - except queue.Empty: - break - - -class ParamsSource(threading.Thread): - """A `ParamSource` is a component that allows networks params to be passed from a - `Learner` component to `Actor` components. - """ - - def __init__(self, init_value: Params, device: jax.Device, lifetime: ThreadLifetime): - super().__init__(name=f"ParamsSource-{device.id}") - self.value: Params = jax.device_put(init_value, device) - self.device = device - self.new_value: queue.Queue = queue.Queue() - self.lifetime = lifetime - - def run(self) -> None: - """This function is responsible for updating the value of the `ParamSource` when a new value - is available. - """ - while not self.lifetime.should_stop(): - try: - waiting = self.new_value.get(block=True, timeout=1) - self.value = jax.device_put(waiting, self.device) - except queue.Empty: - continue - - def update(self, new_params: Params) -> None: - """Update the value of the `ParamSource` with a new value. - - Args: - new_params: The new value to update the `ParamSource` with. - """ - self.new_value.put(new_params) - - def get(self) -> Params: - """Get the current value of the `ParamSource`.""" - return self.value - - -class RecordTimeTo: - """Context manager to record the runtime in a `with` block""" - - def __init__(self, to: Any): - self.to = to - - def __enter__(self) -> None: - self.start = time.monotonic() - - def __exit__(self, *args: Any) -> None: - end = time.monotonic() - self.to.append(end - self.start) diff --git a/mava/utils/sebulba/pipelines.py b/mava/utils/sebulba/pipelines.py new file mode 100644 index 000000000..38453af2d --- /dev/null +++ b/mava/utils/sebulba/pipelines.py @@ -0,0 +1,300 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +import threading +from typing import Any, Dict, List, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp +import numpy as np +from colorama import Fore, Style +from flashbax import make_trajectory_buffer +from jax import tree +from jax.sharding import Sharding +from jumanji.types import TimeStep +from omegaconf import DictConfig + +from mava.systems.ppo.types import PPOTransition +from mava.systems.q_learning.types import Transition +from mava.types import Metrics +from mava.utils.sebulba.rate_limiters import RateLimiter + +QUEUE_PUT_TIMEOUT = 100 + + +@jax.jit +def _stack_trajectory( + trajectory: Union[List[PPOTransition], List[Transition]], +) -> Union[PPOTransition, Transition]: + """Stack a list of parallel_env transitions into a single + transition of shape [rollout_len, num_envs, ...].""" + return tree.map(lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), *trajectory) # type: ignore + + +# Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py +class Pipeline(threading.Thread): + """ + The `Pipeline` shards trajectories into learner devices, + ensuring trajectories are consumed in the right order to avoid being off-policy + and limit the max number of samples in device memory at one time to avoid OOM issues. + """ + + def __init__(self, max_size: int, learner_sharding: Sharding): + """ + Initializes the pipeline with a maximum size and the devices to shard trajectories across. + + Args: + max_size: The maximum number of trajectories to keep in the pipeline. + learner_sharding: The sharding used for the learner's update function. + lifetime: A `ThreadLifetime` which is used to stop this thread. + """ + super().__init__(name="Pipeline") + + self.sharding = learner_sharding + self.tickets_queue: queue.Queue = queue.Queue() + self._queue: queue.Queue = queue.Queue(maxsize=max_size) + self._stop_event = threading.Event() + + def run(self) -> None: + """This function ensures that trajectories on the queue are consumed in the right order. The + start_condition and end_condition are used to ensure that only 1 thread is processing an + item from the queue at one time, ensuring predictable memory usage. + """ + while not self._stop_event.is_set(): + try: + start_condition, end_condition = self.tickets_queue.get(timeout=1) + with end_condition: + with start_condition: + start_condition.notify() + end_condition.wait() + except queue.Empty: + continue + + def put( + self, traj: Sequence[PPOTransition], timestep: TimeStep, metrics: Tuple[Dict, List[Dict]] + ) -> None: + """Put a trajectory on the queue to be consumed by the learner.""" + start_condition, end_condition = (threading.Condition(), threading.Condition()) + with start_condition: + self.tickets_queue.put((start_condition, end_condition)) + start_condition.wait() # wait to be allowed to start + + # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] + traj = _stack_trajectory(traj) + traj, timestep = jax.device_put((traj, timestep), device=self.sharding) + + time_dict, episode_metrics = metrics + # [{'metric1' : value1, ...} * rollout_len -> {'metric1' : [value1, value2, ...], ...} + episode_metrics = _stack_trajectory(episode_metrics) + + # We block on the `put` to ensure that actors wait for the learners to catch up. + # This ensures two things: + # The actors don't get too far ahead of the learners, which could lead to off-policy data. 
+        # The actors don't "waste" samples by generating samples that the learners can't consume.
+        # However, we put a timeout of 100 seconds to avoid deadlocks in case the learner
+        # is not consuming the data. This is a safety measure and should not normally occur.
+        # We use a try-finally so the lock is released even if an exception is raised.
+        try:
+            self._queue.put(
+                (traj, timestep, time_dict, episode_metrics),
+                block=True,
+                timeout=QUEUE_PUT_TIMEOUT,
+            )
+        except queue.Full:
+            print(
+                f"{Fore.RED}{Style.BRIGHT}Pipeline is full and actor has timed out, "
+                f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}"
+            )
+        finally:
+            with end_condition:
+                end_condition.notify()  # notify that we have finished
+
+    def qsize(self) -> int:
+        """Returns the number of trajectories in the pipeline."""
+        return self._queue.qsize()
+
+    def get(
+        self, block: bool = True, timeout: Union[float, None] = None
+    ) -> Tuple[PPOTransition, TimeStep, Dict, Metrics]:
+        """Get a trajectory from the pipeline."""
+        return self._queue.get(block, timeout)  # type: ignore
+
+    def clear(self) -> None:
+        """Clear the pipeline."""
+        while not self._queue.empty():
+            try:
+                self._queue.get(block=False)
+            except queue.Empty:
+                break
+
+    def stop(self) -> None:
+        """Signal the thread to stop."""
+        self._stop_event.set()
+
+
+class OffPolicyPipeline(threading.Thread):
+    """
+    The `OffPolicyPipeline` maintains one replay buffer per actor thread: actors insert
+    trajectories into their own buffer and the learner samples batches across all buffers,
+    with a `RateLimiter` controlling the ratio of learner samples to actor inserts.
+    """
+
+    def __init__(
+        self,
+        config: DictConfig,
+        learner_sharding: Sharding,
+        key: jax.random.PRNGKey,
+        rate_limiter: RateLimiter,
+        init_transition: Transition,
+    ):
+        """
+        Initializes the pipeline with its replay buffers and the devices to shard batches across.
+
+        Args:
+            config: Configuration settings for buffers.
+            learner_sharding: The sharding used for the learner's update function.
+            key: The PRNG key for stochasticity.
+            rate_limiter: A `RateLimiter` used to manage how often we are allowed to
+                sample from the buffers.
+            init_transition: A sample transition used to initialize the buffers.
+        """
+        super().__init__(name="Pipeline")
+        self.cpu = jax.devices("cpu")[0]
+
+        self.tickets_queue: queue.Queue = queue.Queue()
+        self.metrics_queue: queue.Queue = queue.Queue(maxsize=100)
+        self._stop_event = threading.Event()
+
+        self.num_buffers = len(config.arch.actor_device_ids) * config.arch.n_threads_per_executor
+        self.rate_limiter = rate_limiter
+        self.sharding = learner_sharding
+        self.key = key
+
+        assert config.system.sample_batch_size % self.num_buffers == 0, (
+            f"The sample batch size ({config.system.sample_batch_size}) must be divisible "
+            f"by the total number of actors ({self.num_buffers})."
+ ) + + # Setup Buffers + rb = make_trajectory_buffer( + sample_sequence_length=config.system.sample_sequence_length + 1, + period=1, + add_batch_size=config.arch.num_envs, + sample_batch_size=config.system.sample_batch_size // self.num_buffers, + max_length_time_axis=config.system.buffer_size, + min_length_time_axis=config.system.min_buffer_size, + ) + self.buffer_states = [rb.init(init_transition) for _ in range(self.num_buffers)] + self.buffer_adds_count = [0] * self.num_buffers + + # Setup functions + self.buffer_add = jax.jit(rb.add, device=self.cpu) + self.buffer_sample = jax.jit(rb.sample, device=self.cpu) + + def run(self) -> None: + """This function ensures that trajectories on the queue are consumed in the right order. The + start_condition and end_condition are used to ensure that only 1 thread is processing an + item from the queue at one time, ensuring predictable memory usage. + """ + while not self._stop_event.is_set(): + try: + start_condition, end_condition = self.tickets_queue.get(timeout=1) + with end_condition: + with start_condition: + start_condition.notify() + end_condition.wait() + except queue.Empty: + continue + + def put(self, traj: Sequence[Transition], metrics: Tuple, actor_id: int) -> None: + start_condition, end_condition = (threading.Condition(), threading.Condition()) + with start_condition: + self.tickets_queue.put((start_condition, end_condition)) + start_condition.wait() + + try: + self.rate_limiter.await_can_insert(timeout=QUEUE_PUT_TIMEOUT) + except TimeoutError: + print( + f"{Fore.RED}{Style.BRIGHT}Actor has timed out on insertion, " + f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}" + ) + + traj = jax.device_get(traj) + # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] + traj = _stack_trajectory(traj) + + time_dict, episode_metrics = metrics + # [{'metric1' : value1, ...} * rollout_len -> {'metric1' : [value1, value2, ...], ...} + episode_metrics = _stack_trajectory(episode_metrics) + + self.buffer_states[actor_id] = self.buffer_add(self.buffer_states[actor_id], traj) + self.buffer_adds_count[actor_id] += 1 + + if self.metrics_queue.full(): + self.metrics_queue.get() # remove the oldest entry + + self.metrics_queue.put((time_dict, episode_metrics)) + + self.rate_limiter.insert(1 / self.num_buffers) + + with end_condition: + end_condition.notify() # notify that we have finished + + def get(self, timeout: Union[float, None] = None) -> Tuple[Transition, Any]: + """Get a trajectory from the pipeline.""" + self.key, sample_key = jax.random.split(self.key) + + # wait until we can sample the data + try: + self.rate_limiter.await_can_sample(timeout=timeout) + except TimeoutError: + print( + f"{Fore.RED}{Style.BRIGHT}Learner has timed out on sampling, " + f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}" + ) + + # Sample the data + # Potential deadlock risk here. Although it hasn't occurred during testing. + # if an unexplained deadlock happens, it is likely due to this section. 
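+        # Each buffer contributes sample_batch_size // num_buffers sequences; they are
+        # concatenated below into the full training batch before being sharded to the learner.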
+ sampled_batch: List[Transition] = [ + self.buffer_sample(state, sample_key).experience for state in self.buffer_states + ] + transitions: Transition = tree.map(lambda *x: np.concatenate(x), *sampled_batch) + transitions = jax.device_put(transitions, device=self.sharding) + + self.rate_limiter.sample() + + if not self.metrics_queue.empty(): + return transitions, self.metrics_queue.get() + + return transitions, (None, None) + + def clear(self) -> None: + """Clear the pipeline.""" + while not self.metrics_queue.empty(): + try: + self.metrics_queue.get(block=False) + except queue.Empty: + break + + def qsize(self) -> int: + """Returns the number of trajectories in the pipeline.""" + return self.metrics_queue.qsize() + + def stop(self) -> None: + """Signal the thread to stop.""" + self._stop_event.set() diff --git a/mava/utils/sebulba/rate_limiters.py b/mava/utils/sebulba/rate_limiters.py new file mode 100644 index 000000000..3eb443178 --- /dev/null +++ b/mava/utils/sebulba/rate_limiters.py @@ -0,0 +1,290 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from math import ceil +from typing import Optional, Tuple, Union + +from colorama import Fore, Style + + +# from https://github.com/EdanToledo/Stoix/blob/feat/sebulba-dqn/stoix/utils/rate_limiters.py +class RateLimiter: + """ + Rate limiter to control the ratio of samples to inserts. + + This class is designed to regulate the rate at which samples are drawn + compared to the rate at which new data is inserted. + + Args: + samples_per_insert (float): The target ratio of samples to inserts. + For example, a value of 4.0 means the system aims to sample 4 times + for every insertion. + min_size_to_sample (int): The minimum number of inserts required before + sampling is allowed. + min_diff (float): The minimum acceptable difference between the expected + number of samples (based on inserts and `samples_per_insert`) and + the actual number of samples. Sampling is allowed if the difference + is greater than or equal to this value. + max_diff (float): The maximum acceptable difference between the expected + number of samples and the actual number of samples. Inserting is allowed + if the difference is less than or equal to this value. 
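+
+    For example, with samples_per_insert=4.0, min_diff=-4.0 and max_diff=4.0 (and assuming
+    min_size_to_sample has already been reached), the error
+    `inserts * samples_per_insert - samples` is kept within [-4, 4]: after 10 inserts the
+    learner can draw at most 44 samples before sampling blocks, and an actor may only insert
+    again once enough samples have been taken for the post-insert error to stay at or below 4.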
+    """
+
+    def __init__(
+        self, samples_per_insert: float, min_size_to_sample: float, min_diff: float, max_diff: float
+    ):
+        assert min_size_to_sample > 0, "min_size_to_sample must be greater than 0"
+        assert samples_per_insert > 0, "samples_per_insert must be greater than 0"
+
+        self.samples_per_insert = samples_per_insert
+        self.min_diff = min_diff
+        self.max_diff = max_diff
+        self.min_size_to_sample = min_size_to_sample
+
+        self.inserts = 0.0
+        self.samples = 0
+
+        self.mutex = threading.Lock()
+        self.condition = threading.Condition(self.mutex)
+
+    def num_inserts(self) -> float:
+        """Returns the number of inserts."""
+        with self.mutex:
+            return self.inserts
+
+    def num_samples(self) -> int:
+        """Returns the number of samples."""
+        with self.mutex:
+            return self.samples
+
+    def insert(self, insert_fraction: float = 1) -> None:
+        """Increment the number of inserts and notify all waiting threads."""
+        with self.mutex:
+            self.inserts += insert_fraction
+            self.condition.notify_all()  # Notify all waiting threads
+
+    def sample(self) -> None:
+        """Increment the number of samples and notify all waiting threads."""
+        with self.mutex:
+            self.samples += 1
+            self.condition.notify_all()  # Notify all waiting threads
+
+    def can_insert(self, num_inserts: int) -> bool:
+        """Check if the caller can insert `num_inserts` items."""
+        # Assume lock is already held by the caller
+        if num_inserts <= 0:
+            return False
+        if ceil(self.inserts) + num_inserts <= self.min_size_to_sample:
+            return True
+        diff = (num_inserts + ceil(self.inserts)) * self.samples_per_insert - self.samples
+        return diff <= self.max_diff
+
+    def can_sample(self, num_samples: int) -> bool:
+        """Check if the caller can sample `num_samples` items."""
+        # Assume lock is already held by the caller
+        if num_samples <= 0:
+            return False
+        if ceil(self.inserts) < self.min_size_to_sample:
+            return False
+        diff = ceil(self.inserts) * self.samples_per_insert - self.samples - num_samples
+        return diff >= self.min_diff
+
+    def await_can_insert(self, num_inserts: int = 1, timeout: Optional[float] = None) -> bool:
+        """Wait until the caller can insert `num_inserts` items."""
+        with self.condition:
+            result = self.condition.wait_for(lambda: self.can_insert(num_inserts), timeout)
+        if not result:
+            raise TimeoutError(f"Timeout occurred while waiting to insert {num_inserts} items.")
+        return result
+
+    def await_can_sample(self, num_samples: int = 1, timeout: Optional[float] = None) -> bool:
+        """Wait until the caller can sample `num_samples` items."""
+        with self.condition:
+            result = self.condition.wait_for(lambda: self.can_sample(num_samples), timeout)
+        if not result:
+            raise TimeoutError(f"Timeout occurred while waiting to sample {num_samples} items.")
+        return result
+
+    def __repr__(self) -> str:
+        return (
+            f"RateLimiter(samples_per_insert={self.samples_per_insert}, "
+            f"min_size_to_sample={self.min_size_to_sample}, "
+            f"min_diff={self.min_diff}, max_diff={self.max_diff})"
+        )
+
+
+class SampleToInsertRatio(RateLimiter):
+    """Maintains a specified ratio between samples and inserts.
+
+    The limiter works in two stages:
+
+    Stage 1. The size of the table is less than `min_size_to_sample`.
+    Stage 2. The size of the table is greater than or equal to `min_size_to_sample`.
+
+    During stage 1 the limiter acts as a pure minimum-size limiter, i.e. it allows
+    all insert calls and blocks all sample calls. Note that it is possible to
+    transition back into stage 1 from stage 2 when items are removed from the table.
+
+    During stage 2 the limiter attempts to maintain the `samples_per_insert`
+    ratio between the samples and inserts. This is done by
+    measuring the `error`, calculated as:
+
+        error = number_of_inserts * samples_per_insert - number_of_samples
+
+    and making sure that `error` stays within `allowed_range`. Any operation
+    which would move `error` outside of the `allowed_range` is blocked.
+    This approach allows small deviations from the target `samples_per_insert`,
+    which reduces excessive blocking of insert/sample operations and improves
+    performance.
+
+    If `error_buffer` is a tuple of two numbers then `allowed_range` is defined as
+
+        (error_buffer[0], error_buffer[1])
+
+    When `error_buffer` is a single number then the range is defined as
+
+        (
+            min_size_to_sample * samples_per_insert - error_buffer,
+            min_size_to_sample * samples_per_insert + error_buffer
+        )
+    """
+
+    def __init__(
+        self,
+        samples_per_insert: float,
+        min_size_to_sample: int,
+        error_buffer: Union[float, Tuple[float, float]],
+    ):
+        """Constructor of SampleToInsertRatio.
+
+        Args:
+            samples_per_insert: The average number of times the learner should sample
+                each item in the replay buffer during the item's entire lifetime.
+            min_size_to_sample: The minimum number of items that the table must
+                contain before transitioning into stage 2.
+            error_buffer: Maximum size of the "error" before calls should be blocked.
+                When a single value is provided, the inferred range is
+                    (
+                        min_size_to_sample * samples_per_insert - error_buffer,
+                        min_size_to_sample * samples_per_insert + error_buffer
+                    )
+                The offset is added so that the tracked insert/sample-ratio error only
+                takes into account operations occurring AFTER stage 1. If a range
+                (a two-float tuple) is given, the values are used without any offset.
+
+        Raises:
+            ValueError: If error_buffer is smaller than max(1.0, samples_per_insert).
+        """
+        if isinstance(error_buffer, (int, float)):
+            offset = samples_per_insert * min_size_to_sample
+            min_diff = offset - error_buffer
+            max_diff = offset + error_buffer
+        else:
+            min_diff, max_diff = error_buffer
+
+        if samples_per_insert <= 0:
+            raise ValueError(f"samples_per_insert ({samples_per_insert}) must be > 0")
+
+        if max_diff - min_diff < 2 * max(1.0, samples_per_insert):
+            raise ValueError(
+                "The size of error_buffer must be >= max(1.0, samples_per_insert) as "
+                "smaller values could completely block samples and/or insert calls."
+            )
+
+        if max_diff < samples_per_insert * min_size_to_sample:
+            print(
+                f"{Fore.YELLOW}{Style.BRIGHT}The range covered by error_buffer is below "
+                "samples_per_insert * min_size_to_sample. If the sampler cannot "
+                "sample concurrently, this will result in a deadlock as soon as "
+                f"min_size_to_sample items have been inserted.{Style.RESET_ALL}"
+            )
+        if min_diff > samples_per_insert * min_size_to_sample:
+            raise ValueError(
+                "The range covered by error_buffer is above "
+                "samples_per_insert * min_size_to_sample. This will result in a "
+                "deadlock as soon as min_size_to_sample items have been inserted."
+            )
+
+        if min_size_to_sample < 1:
+            raise ValueError(
+                f"min_size_to_sample ({min_size_to_sample}) must be a positive integer"
+            )
+
+        super().__init__(
+            samples_per_insert=samples_per_insert,
+            min_size_to_sample=min_size_to_sample,
+            min_diff=min_diff,
+            max_diff=max_diff,
+        )
+
+
+class BlockingRatioLimiter(RateLimiter):
+    """
+    Blocking rate limiter that enforces a ratio of X samples per insert and
+    1/X inserts per sample.
+
+    Args:
+        sample_insert_ratio (float): The ratio of samples to inserts (X).
+            For example, a value of 2.0 means for every insert, up to 2 samples are allowed,
+            and for every 2 samples, up to 1 insert is allowed. Must be greater than 0.
+        min_num_inserts (float): The minimum number of inserts that must occur before
+            any sampling is allowed.
+    """
+
+    def __init__(self, sample_insert_ratio: float, min_num_inserts: float):
+        if sample_insert_ratio <= 0:
+            raise ValueError("sample_insert_ratio must be greater than 0")
+        super().__init__(
+            samples_per_insert=sample_insert_ratio,
+            min_size_to_sample=min_num_inserts,
+            min_diff=float("-inf"),
+            max_diff=float("inf"),
+        )
+        self.available_inserts = 1.0
+        self.available_samples = 0.0
+        self.sample_insert_ratio = sample_insert_ratio
+
+    def insert(self, insert_fraction: float = 1.0) -> None:
+        """
+        Consumes insert_fraction of the insert credit and grants
+        insert_fraction * sample_insert_ratio sample credits. During warm-up
+        (while min_size_to_sample > 0) inserts only count towards the minimum size.
+        """
+        with self.mutex:
+            if self.min_size_to_sample > 0:
+                self.min_size_to_sample -= insert_fraction
+            else:
+                self.available_samples += insert_fraction * self.sample_insert_ratio
+                self.available_inserts -= insert_fraction
+
+            self.inserts += insert_fraction
+            self.condition.notify_all()
+
+    def sample(self, num_samples: int = 1) -> None:
+        """
+        Consumes num_samples sample credits and grants
+        num_samples / sample_insert_ratio insert credits.
+        """
+        with self.mutex:
+            self.available_inserts += num_samples / self.sample_insert_ratio
+            self.available_samples -= num_samples
+            self.samples += num_samples
+            self.condition.notify_all()
+
+    def can_insert(self, num_inserts: float = 1.0) -> bool:
+        """
+        Checks if it is possible to insert num_inserts, based on available insert credits.
+        """
+        # Assume lock is already held by the caller
+        return self.available_inserts >= num_inserts
+
+    def can_sample(self, num_samples: int = 1) -> bool:
+        """
+        Checks if it is possible to sample num_samples, based on available sample credits.
+        """
+        # Assume lock is already held by the caller
+        return self.available_samples >= num_samples
diff --git a/mava/utils/sebulba/utils.py b/mava/utils/sebulba/utils.py
new file mode 100644
index 000000000..3a52cc0c2
--- /dev/null
+++ b/mava/utils/sebulba/utils.py
@@ -0,0 +1,76 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import queue
+import threading
+import time
+from typing import Any
+
+import jax
+
+from mava.systems.ppo.types import Params
+
+
+class RecordTimeTo:
+    """Context manager that appends the runtime of a `with` block to the given container."""
+
+    def __init__(self, to: Any):
+        self.to = to
+
+    def __enter__(self) -> None:
+        self.start = time.monotonic()
+
+    def __exit__(self, *args: Any) -> None:
+        end = time.monotonic()
+        self.to.append(end - self.start)
+
+
+class ParamsSource(threading.Thread):
+    """A `ParamsSource` is a component that allows network params to be passed from a
+    `Learner` component to `Actor` components.
+    """
+
+    def __init__(self, init_value: Params, device: jax.Device):
+        super().__init__(name=f"ParamsSource-{device.id}")
+        self.value: Params = jax.device_put(init_value, device)
+        self.device = device
+        self.new_value: queue.Queue = queue.Queue()
+        self._stop_event = threading.Event()
+
+    def run(self) -> None:
+        """Updates the value of the `ParamsSource` whenever a new value
+        becomes available.
+        """
+        while not self._stop_event.is_set():
+            try:
+                waiting = self.new_value.get(block=True, timeout=1)
+                self.value = jax.device_put(waiting, self.device)
+            except queue.Empty:
+                continue
+
+    def update(self, new_params: Params) -> None:
+        """Update the value of the `ParamsSource` with a new value.
+
+        Args:
+            new_params: The new value to update the `ParamsSource` with.
+        """
+        self.new_value.put(new_params)
+
+    def get(self) -> Params:
+        """Get the current value of the `ParamsSource`."""
+        return self.value
+
+    def stop(self) -> None:
+        """Signal the thread to stop."""
+        self._stop_event.set()
diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py
index 9258bde6a..3bd48dbfa 100644
--- a/mava/wrappers/gym.py
+++ b/mava/wrappers/gym.py
@@ -29,7 +29,7 @@
 from gymnasium.vector.utils import write_to_shared_memory
 from numpy.typing import NDArray
 
-from mava.types import Observation, ObservationGlobalState
+from mava.types import MavaObservation, Observation, ObservationGlobalState
 
 if TYPE_CHECKING:  # https://github.com/python/mypy/issues/6239
     from dataclasses import dataclass
@@ -54,7 +54,7 @@ class TimeStep:
     step_type: StepType
     reward: NDArray
     discount: NDArray
-    observation: Union[Observation, ObservationGlobalState]
+    observation: MavaObservation
    extras: Dict = field(default_factory=dict)
 
     def first(self) -> NDArray:
@@ -78,12 +78,13 @@ def __init__(
         use_shared_rewards: bool = True,
         add_global_state: bool = False,
     ):
-        """Initialise the gym wrapper
+        """Initialize the gym wrapper
         Args:
             env (gymnasium.env): gymnasium env instance.
             use_shared_rewards (bool, optional): Use individual or shared rewards.
                 Defaults to False.
-            add_global_state (bool, optional) : Create global observations. Defaults to False.
+            add_global_state (bool, optional): Add global state information
+                to observations. Defaults to False.
""" super().__init__(env) self._env = env @@ -141,7 +142,7 @@ def get_global_obs(self, obs: NDArray) -> NDArray: class SmacWrapper(UoeWrapper): - """A wrapper that converts actions step to integers.""" + """A wrapper that converts actions to integers.""" def reset( self, seed: Optional[int] = None, options: Optional[dict] = None @@ -251,18 +252,18 @@ def __init__(self, env: gymnasium.vector.VectorEnv): self.env = env self.single_action_space = env.unwrapped.single_action_space self.single_observation_space = env.unwrapped.single_observation_space + self.num_agents = len(self.env.single_action_space) def reset(self, seed: Optional[list[int]] = None, options: Optional[dict] = None) -> TimeStep: obs, info = self.env.reset(seed=seed, options=options) # type: ignore - num_agents = len(self.env.single_action_space) # type: ignore num_envs = self.env.num_envs step_type = np.full(num_envs, StepType.FIRST) - rewards = np.zeros((num_envs, num_agents), dtype=float) - teminated = np.zeros(num_envs, dtype=float) + rewards = np.zeros((num_envs, self.num_agents), dtype=float) + terminated = np.zeros((num_envs, self.num_agents), dtype=float) - timestep = self._create_timestep(obs, step_type, teminated, rewards, info) + timestep = self._create_timestep(obs, step_type, terminated, rewards, info) return timestep @@ -271,21 +272,27 @@ def step(self, action: list) -> TimeStep: ep_done = np.logical_or(terminated, truncated) step_type = np.where(ep_done, StepType.LAST, StepType.MID) + terminated = np.repeat( + terminated[..., np.newaxis], repeats=self.num_agents, axis=-1 + ) # (B,) --> (B, N) timestep = self._create_timestep(obs, step_type, terminated, rewards, info) return timestep def _format_observation( - self, obs: NDArray, info: Dict + self, + obs: NDArray, + action_mask: Tuple[NDArray], + global_obs: Tuple[Union[NDArray, None]] = (None,), ) -> Union[Observation, ObservationGlobalState]: """Create an observation from the raw observation and environment state.""" - action_mask = np.stack(info["action_mask"]) + action_mask = np.stack(action_mask) obs_data = {"agents_view": obs, "action_mask": action_mask} - if "global_obs" in info: - global_obs = np.array(info["global_obs"]) + if global_obs[0] is not None: + global_obs = np.array(global_obs) obs_data["global_state"] = global_obs return ObservationGlobalState(**obs_data) else: @@ -294,17 +301,23 @@ def _format_observation( def _create_timestep( self, obs: NDArray, step_type: NDArray, terminated: NDArray, rewards: NDArray, info: Dict ) -> TimeStep: - observation = self._format_observation(obs, info) + observation = self._format_observation( + obs, info["action_mask"], info.get("global_obs", (None,)) + ) # Filter out the masks and auxiliary data extras = {} extras["episode_metrics"] = { key: value for key, value in info["metrics"].items() if key[0] != "_" } + extras["real_next_obs"] = self._format_observation( # type: ignore + info["real_next_obs"], info["real_next_action_mask"], info["real_next_global_obs"] + ) + if "won_episode" in info: extras["won_episode"] = info["won_episode"] return TimeStep( - step_type=step_type, # type: ignore + step_type=step_type, reward=rewards, discount=1.0 - terminated, observation=observation, @@ -315,7 +328,7 @@ def close(self) -> None: self.env.close() -# Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py +# Copied from Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents # Note: The worker handles auto-resetting the environments. 
# Each environment resets when all of its agents have either terminated or been truncated. @@ -338,13 +351,16 @@ def async_multiagent_worker( # CCR001 if command == "reset": observation, info = env.reset(**data) + info["real_next_obs"] = observation + info["real_next_action_mask"] = info["action_mask"] + info["real_next_global_obs"] = info.get("global_obs", None) if shared_memory: write_to_shared_memory(observation_space, index, observation, shared_memory) observation = None pipe.send(((observation, info), True)) elif command == "step": # Modified the step function to align with 'AutoResetWrapper'. - # The environment resets immediately upon termination or truncation. + # The environment resets when all agents have either terminated or truncated. ( observation, reward, @@ -352,9 +368,13 @@ def async_multiagent_worker( # CCR001 truncated, info, ) = env.step(data) + info["real_next_obs"] = observation + info["real_next_action_mask"] = info["action_mask"] + info["real_next_global_obs"] = info.get("global_obs", None) if np.logical_or(terminated, truncated).all(): observation, new_info = env.reset() info["action_mask"] = new_info["action_mask"] + info["global_obs"] = new_info.get("global_obs", None) if shared_memory: write_to_shared_memory(observation_space, index, observation, shared_memory)
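
A note on the pipeline's ordering mechanism: the FIFO hand-off used by Pipeline.run and Pipeline.put can be easier to follow in isolation. The toy below reproduces just that ticketing pattern with plain threading and queue primitives: each producer enqueues a (start_condition, end_condition) pair, the coordinator wakes producers strictly in arrival order, and each producer signals its end condition when its work is done, so at most one producer is ever inside the critical section at a time. All names here (coordinator, producer, the sleeps) are illustrative stand-ins, not part of this change; only the synchronisation pattern mirrors the pipeline code.

import queue
import threading
import time

tickets_queue: "queue.Queue[tuple[threading.Condition, threading.Condition]]" = queue.Queue()
stop_event = threading.Event()
processed: list = []


def coordinator() -> None:
    # Mirrors Pipeline.run: release waiting producers one at a time, in FIFO order.
    while not stop_event.is_set():
        try:
            start_condition, end_condition = tickets_queue.get(timeout=1)
            with end_condition:
                with start_condition:
                    start_condition.notify()
                end_condition.wait()
        except queue.Empty:
            continue


def producer(name: str) -> None:
    # Mirrors Pipeline.put: take a ticket, wait to be released, do the work, signal completion.
    start_condition, end_condition = threading.Condition(), threading.Condition()
    with start_condition:
        tickets_queue.put((start_condition, end_condition))
        start_condition.wait()
    processed.append(name)  # the "critical section" (device_get + buffer add in the real code)
    time.sleep(0.01)
    with end_condition:
        end_condition.notify()


coord = threading.Thread(target=coordinator, daemon=True)
coord.start()
workers = [threading.Thread(target=producer, args=(f"actor-{i}",)) for i in range(4)]
for w in workers:
    w.start()
    time.sleep(0.005)  # stagger arrivals so the FIFO order is visible
for w in workers:
    w.join()
stop_event.set()
print(processed)  # actors are serviced in arrival order, one at a time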
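As a usage reference for the rate limiters added above, the minimal sketch below drives a SampleToInsertRatio limiter from a toy actor thread (insert side) and a toy learner thread (sample side). It imports the module introduced in this diff; the in-memory buffer, the thread counts, and the limiter constants are illustrative assumptions rather than values used by the system, and in practice the Pipeline makes these calls on the actors' and learner's behalf.

import random
import threading

from mava.utils.sebulba.rate_limiters import SampleToInsertRatio

# Aim to sample each insert roughly 4 times, only after 32 inserts, with some slack.
limiter = SampleToInsertRatio(samples_per_insert=4.0, min_size_to_sample=32, error_buffer=128.0)
buffer: list = []  # stand-in for the per-actor trajectory buffers
lock = threading.Lock()


def actor(num_rollouts: int) -> None:
    for i in range(num_rollouts):
        limiter.await_can_insert(timeout=30)  # blocks if the learner falls too far behind
        with lock:
            buffer.append(i)  # stand-in for Pipeline.put / buffer_add
        limiter.insert()


def learner(num_updates: int) -> None:
    for _ in range(num_updates):
        limiter.await_can_sample(timeout=30)  # blocks until enough data has been inserted
        with lock:
            _batch = random.choice(buffer)  # stand-in for Pipeline.get / buffer_sample
        limiter.sample()


threads = [
    threading.Thread(target=actor, args=(100,)),
    threading.Thread(target=learner, args=(400,)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(f"inserts={limiter.num_inserts()}, samples={limiter.num_samples()}")

With samples_per_insert=4.0 the learner ends up taking about four samples per insert once the 32-insert warm-up has passed, and either side blocks inside await_can_insert/await_can_sample whenever the running error would leave the allowed band.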