From a535e3d67be138a30f39595191a6e8febbd1bf60 Mon Sep 17 00:00:00 2001 From: Dries Date: Mon, 27 Sep 2021 10:41:43 +0200 Subject: [PATCH 001/104] Add rough recurrent code for MAPPO. --- mava/systems/tf/mappo/builder.py | 58 +++++++--- mava/systems/tf/mappo/execution.py | 169 ++++++++++++++++++++++++++++- mava/systems/tf/mappo/system.py | 28 +++++ mava/systems/tf/mappo/training.py | 135 ++++++++++++++++++----- 4 files changed, 344 insertions(+), 46 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 428257fd5..d06fb7c83 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -16,14 +16,15 @@ """MAPPO system builder implementation.""" import dataclasses -from typing import Dict, Iterator, List, Optional, Type, Union +import copy +from typing import Dict, Iterator, List, Optional, Type, Union, Any +import tensorflow as tf import reverb import sonnet as snt -import tensorflow as tf -from acme import datasets -from acme import types as acme_types from acme.tf import utils as tf2_utils +from acme import datasets +from acme.specs import EnvironmentSpec from acme.tf import variable_utils from acme.utils import counting @@ -98,6 +99,7 @@ def __init__( config: MAPPOConfig, trainer_fn: Type[training.MAPPOTrainer] = training.MAPPOTrainer, executor_fn: Type[core.Executor] = execution.MAPPOFeedForwardExecutor, + extra_specs: Dict[str, Any] = {} ): """Initialise the system. @@ -117,6 +119,35 @@ def __init__( self._agent_types = self._config.environment_spec.get_agent_types() self._trainer_fn = trainer_fn self._executor_fn = executor_fn + self._extras_specs = extra_specs + + def add_log_prob_to_spec( + self, environment_spec: specs.MAEnvironmentSpec + ) -> specs.MAEnvironmentSpec: + """convert discrete action space to bounded continuous action space + Args: + environment_spec (specs.MAEnvironmentSpec): description of + the action, observation spaces etc. for each agent in the system. + Returns: + specs.MAEnvironmentSpec: updated environment spec. + """ + + env_adder_spec: specs.MAEnvironmentSpec = copy.deepcopy(environment_spec) + keys = env_adder_spec._keys + for key in keys: + agent_spec = env_adder_spec._specs[key] + new_act_spec = {"actions": agent_spec.actions} + + # Make dummy log_probs + new_act_spec["log_probs"] = tf2_utils.squeeze_batch_dim(tf.ones(shape=(1,), dtype=tf.float32)) + + env_adder_spec._specs[key] = EnvironmentSpec( + observations=agent_spec.observations, + actions=new_act_spec, + rewards=agent_spec.rewards, + discounts=agent_spec.discounts, + ) + return env_adder_spec def make_replay_tables( self, @@ -132,23 +163,18 @@ def make_replay_tables( List[reverb.Table]: a list of data tables for inserting data. """ - agent_specs = environment_spec.get_agent_specs() - extras_spec: Dict[str, Dict[str, acme_types.NestedArray]] = {"log_probs": {}} - - for agent, spec in agent_specs.items(): - # Make dummy log_probs - extras_spec["log_probs"][agent] = tf.ones(shape=(1,), dtype=tf.float32) - - # Squeeze the batch dim. - extras_spec = tf2_utils.squeeze_batch_dim(extras_spec) + # Create system architecture with target networks. 
+ adder_env_spec = self._builder.add_log_prob_to_spec( + environment_spec + ) replay_table = reverb.Table.queue( name=self._config.replay_table_name, max_size=self._config.max_queue_size, signature=reverb_adders.ParallelSequenceAdder.signature( - environment_spec, + adder_env_spec, sequence_length=self._config.sequence_length, - extras_spec=extras_spec, + extras_spec=self._extras_specs, ), ) @@ -200,7 +226,7 @@ def make_adder( client=replay_client, period=self._config.sequence_period, sequence_length=self._config.sequence_length, - use_next_extras=False, + use_next_extras=True, ) def make_executor( diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index afa94c65d..74a704eaa 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -15,7 +15,7 @@ """MAPPO system executor implementation.""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import dm_env import numpy as np @@ -23,10 +23,12 @@ import tensorflow as tf import tensorflow_probability as tfp from acme import types +from acme.specs import EnvironmentSpec from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils from mava import adders, core +from mava.systems.tf import executors tfd = tfp.distributions @@ -202,3 +204,168 @@ def update(self, wait: bool = False) -> None: if self._variable_client: self._variable_client.update(wait) + +class MADDPGRecurrentExecutor(executors.RecurrentExecutor): + """A recurrent executor for MADDPG. + An executor based on a recurrent policy for each agent in the system. + """ + + def __init__( + self, + policy_networks: Dict[str, snt.Module], + agent_net_keys: Dict[str, str], + adder: Optional[adders.ParallelAdder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + ): + """Initialise the system executor + Args: + policy_networks (Dict[str, snt.Module]): policy networks for each agent in + the system. + adder (Optional[adders.ParallelAdder], optional): adder which sends data + to a replay buffer. Defaults to None. + variable_client (Optional[tf2_variable_utils.VariableClient], optional): + client to copy weights from the trainer. Defaults to None. + agent_net_keys: (dict, optional): specifies what network each agent uses. + Defaults to {}. + """ + + super().__init__( + policy_networks=policy_networks, + agent_net_keys=agent_net_keys, + adder=adder, + variable_client=variable_client, + store_recurrent_state=True, + ) + + @tf.function + def _policy( + self, + agent: str, + observation: types.NestedTensor, + state: types.NestedTensor, + ) -> Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: + """Agent specific policy function + + Args: + agent (str): agent id + observation (types.NestedTensor): observation tensor received from the + environment. + state (types.NestedTensor): recurrent network state. + + Raises: + NotImplementedError: unknown action space + + Returns: + Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: + action, policy and new recurrent hidden state + """ + + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_observation = tf2_utils.add_batch_dim(observation) + + # index network either on agent type or on agent id + agent_key = self._agent_net_keys[agent] + + # Compute the policy, conditioned on the observation. 
+ policy, new_state = self._policy_networks[agent_key](batched_observation, state) + + # Sample from the policy and compute the log likelihood. + action = policy.sample() + log_prob = policy.log_prob(action) + + # Cast for compatibility with reverb. + # sample() returns a 'int32', which is a problem. + if isinstance(policy, tfp.distributions.Categorical): + action = tf.cast(action, "int64") + + return action, log_prob, new_state + + def select_action( + self, agent: str, observation: types.NestedArray + ) -> types.NestedArray: + """select an action for a single agent in the system + + Args: + agent (str): agent id + observation (types.NestedArray): observation tensor received from the + environment. + + Returns: + types.NestedArray: action and policy. + """ + + # TODO Mask actions here using observation.legal_actions + # Initialize the RNN state if necessary. + if self._states[agent] is None: + # index network either on agent type or on agent id + agent_key = self._agent_net_keys[agent] + self._states[agent] = self._policy_networks[agent_key].initia_state(1) + + # Step the recurrent policy forward given the current observation and state. + action, log_prob, new_state = self._policy( + agent, observation.observation, self._states[agent] + ) + + # Bookkeeping of recurrent states for the observe method. + self._update_state(agent, new_state) + + # Return a numpy array with squeezed out batch dimension. + action = tf2_utils.to_numpy_squeeze(action) + log_prob = tf2_utils.to_numpy_squeeze(log_prob) + return action, log_prob + + def select_actions( + self, observations: Dict[str, types.NestedArray] + ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + """select the actions for all agents in the system + + Args: + observations (Dict[str, types.NestedArray]): agent observations from the + environment. + + Returns: + Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + actions and policies for all agents in the system. + """ + + actions = {} + log_probs = {} + for agent, observation in observations.items(): + actions[agent], log_probs[agent] = self.select_action(agent, observation) + return actions, log_probs + + def observe( + self, + actions: Dict[str, types.NestedArray], + next_timestep: dm_env.TimeStep, + next_extras: Optional[Dict[str, types.NestedArray]] = {}, + ) -> None: + """record observed timestep from the environment + + Args: + actions (Dict[str, types.NestedArray]): system agents' actions. + next_timestep (dm_env.TimeStep): data emitted by an environment during + interaction. + next_extras (Dict[str, types.NestedArray], optional): possible extra + information to record during the transition. Defaults to {}. 
+ """ + + if not self._adder: + return + actions, log_probs = actions + + adder_actions = {} + for agent in actions.keys(): + adder_actions[agent] = {} + adder_actions[agent]["actions"] = actions[agent] + adder_actions[agent]["log_probs"] = log_probs[agent] + + numpy_states = { + agent: tf2_utils.to_numpy_squeeze(_state) + for agent, _state in self._states.items() + } + if next_extras: + next_extras.update({"core_states": numpy_states}) + self._adder.add(adder_actions, next_timestep, next_extras) # type: ignore + else: + self._adder.add(adder_actions, next_timestep, numpy_states) # type: ignore diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 14f56ae05..2af1bd732 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -25,6 +25,7 @@ import sonnet as snt from acme import specs as acme_specs from acme.utils import counting, loggers +from acme.tf import utils as tf2_utils import mava from mava import core @@ -195,6 +196,11 @@ def __init__( self._eval_loop_fn_kwargs = eval_loop_fn_kwargs self._checkpoint_minute_interval = checkpoint_minute_interval + if issubclass(executor_fn, execution.MAPPOFeedForwardExecutor): + extra_specs = self._get_extra_specs() + else: + extra_specs = {} + self._builder = builder.MAPPOBuilder( config=builder.MAPPOConfig( environment_spec=environment_spec, @@ -218,7 +224,29 @@ def __init__( ), trainer_fn=trainer_fn, executor_fn=executor_fn, + extra_specs=extra_specs, + ) + + def _get_extra_specs(self) -> Any: + """helper to establish specs for extra information + Returns: + Dict[str, Any]: dictionary containing extra specs + """ + + agents = self._environment_spec.get_agent_ids() + core_state_specs = {} + networks = self._network_factory( # type: ignore + environment_spec=self._environment_spec, + agent_net_keys=self._agent_net_keys, ) + for agent in agents: + agent_type = agent.split("_")[0] + core_state_specs[agent] = ( + tf2_utils.squeeze_batch_dim( + networks["policies"][agent_type].initial_state(1) + ), + ) + return {"core_states": core_state_specs} def replay(self) -> Any: """Replay data storage. diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 27c6cc971..3dc168298 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -18,6 +18,7 @@ import copy import os import time +import tree from typing import Any, Dict, List, Optional, Sequence, Union import numpy as np @@ -55,10 +56,10 @@ def __init__( critic_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], agent_net_keys: Dict[str, str], checkpoint_minute_interval: int, - discount: float = 0.99, - lambda_gae: float = 1.0, - entropy_cost: float = 0.0, - baseline_cost: float = 1.0, + discount: float = 0.999, + lambda_gae: float = 1.0, # Question (dries): What is this used for? + entropy_cost: float = 0.01, + baseline_cost: float = 0.5, clipping_epsilon: float = 0.2, max_gradient_norm: Optional[float] = None, counter: counting.Counter = None, @@ -286,11 +287,15 @@ def _forward(self, inputs: Any) -> None: data.extras, ) + if 'core_state' in extras: + core_state = tree.map_structure(lambda s: s[:, 0, :], extras["core_states"]) + # transform observation using observation networks observations_trans = self._transform_observations(observations) # Get log_probs. - log_probs = extras["log_probs"] + actions = actions["actions"] + behaviour_logits = actions["log_probs"] # Store losses. 
policy_losses: Dict[str, Any] = {} @@ -299,15 +304,15 @@ def _forward(self, inputs: Any) -> None: with tf.GradientTape(persistent=True) as tape: for agent in self._agents: - action, reward, discount, behaviour_log_prob = ( + action, reward, discount, behaviour_logits = ( actions[agent], rewards[agent], discounts[agent], - log_probs[agent], + behaviour_logits[agent], ) actor_observation = observations_trans[agent] - critic_observation = self._get_critic_feed(observations_trans, agent) + critic_observation = self._get_critic_feed(observations_trans, extras, agent) # Chop off final timestep for bootstrapping value reward = reward[:-1] @@ -326,13 +331,28 @@ def _forward(self, inputs: Any) -> None: critic_observation = snt.merge_leading_dims( critic_observation, num_dims=2 ) - policy = policy_network(actor_observation) - values = critic_network(critic_observation) - - # Reshape the outputs. - policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") - values = tf.reshape(values, dims, name="value") + if 'core_state' in extras: + # Unroll current policy over actor_observation. + (logits, values), _ = snt.static_unroll(self._actor_critic_network, actor_observation, + core_state) + policy = tfd.Categorical(logits=logits[:-1]) + policy_log_prob = policy.log_prob(actions) + # Entropy regulariser. + entropy_loss = trfl.policy_entropy_loss(policy).loss + else: + policy = policy_network(actor_observation) + values = critic_network(critic_observation) + # Reshape the outputs. + policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") + values = tf.reshape(values, dims, name="value") + policy_log_prob = policy.log_prob(action) + + # Entropy regularization. Only implemented for categorical dist. + policy_entropy = tf.reduce_mean(policy.entropy()) + entropy_loss = -self._entropy_cost * policy_entropy + # Compute importance sampling weights: current policy / behavior policy. + # Values along the sequence T. bootstrap_value = values[-1] state_values = values[:-1] @@ -354,7 +374,9 @@ def _forward(self, inputs: Any) -> None: ) # Compute importance sampling weights: current policy / behavior policy. - log_rhos = policy.log_prob(action) - behaviour_log_prob + behaviour_log_prob = tfd.Categorical(logits=behaviour_logits[:-1]) + + log_rhos = policy_log_prob - behaviour_log_prob importance_ratio = tf.exp(log_rhos)[:-1] clipped_importance_ratio = tf.clip_by_value( importance_ratio, @@ -375,19 +397,12 @@ def _forward(self, inputs: Any) -> None: name="PolicyGradientLoss", ) - # Entropy regularization. Only implemented for categorical dist. - try: - policy_entropy = tf.reduce_mean(policy.entropy()) - except NotImplementedError: - policy_entropy = tf.convert_to_tensor(0.0) - - entropy_loss = -self._entropy_cost * policy_entropy - - # Combine weighted sum of actor & entropy regularization. + # Combine weighted sum of actor & entropy regularization. policy_loss = policy_gradient_loss + entropy_loss - policy_losses[agent] = policy_loss - critic_losses[agent] = critic_loss + # Multiply by discounts to not train on padded data. 
+ policy_losses[agent] = policy_loss*discount + critic_losses[agent] = critic_loss*discount self.policy_losses = policy_losses self.critic_losses = critic_losses @@ -489,10 +504,10 @@ def __init__( critic_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], agent_net_keys: Dict[str, str], checkpoint_minute_interval: int, - discount: float = 0.99, + discount: float = 0.999, lambda_gae: float = 1.0, - entropy_cost: float = 0.0, - baseline_cost: float = 1.0, + entropy_cost: float = 0.01, + baseline_cost: float = 0.5, clipping_epsilon: float = 0.2, max_gradient_norm: Optional[float] = None, counter: counting.Counter = None, @@ -527,9 +542,71 @@ def __init__( def _get_critic_feed( self, observations_trans: Dict[str, np.ndarray], + extras: Dict[str, np.ndarray], agent: str, ) -> tf.Tensor: # Centralised based observation_feed = tf.stack([x for x in observations_trans.values()], 2) return observation_feed + +class StateBasedMAPPOTrainer(MAPPOTrainer): + """MAPPO trainer for a centralised architecture.""" + + def __init__( + self, + agents: List[Any], + agent_types: List[str], + observation_networks: Dict[str, snt.Module], + policy_networks: Dict[str, snt.Module], + critic_networks: Dict[str, snt.Module], + dataset: tf.data.Dataset, + policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + critic_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + agent_net_keys: Dict[str, str], + checkpoint_minute_interval: int, + discount: float = 0.999, + lambda_gae: float = 1.0, + entropy_cost: float = 0.01, + baseline_cost: float = 0.5, + clipping_epsilon: float = 0.2, + max_gradient_norm: Optional[float] = None, + counter: counting.Counter = None, + logger: loggers.Logger = None, + checkpoint: bool = False, + checkpoint_subpath: str = "Checkpoints", + ): + + super().__init__( + agents=agents, + agent_types=agent_types, + policy_networks=policy_networks, + critic_networks=critic_networks, + observation_networks=observation_networks, + dataset=dataset, + agent_net_keys=agent_net_keys, + checkpoint_minute_interval=checkpoint_minute_interval, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + discount=discount, + lambda_gae=lambda_gae, + entropy_cost=entropy_cost, + baseline_cost=baseline_cost, + clipping_epsilon=clipping_epsilon, + max_gradient_norm=max_gradient_norm, + counter=counter, + logger=logger, + checkpoint=checkpoint, + checkpoint_subpath=checkpoint_subpath, + ) + + def _get_critic_feed( + self, + observations_trans: Dict[str, np.ndarray], + extras: Dict[str, np.ndarray], + agent: str, + ) -> tf.Tensor: + # State based + observation_feed = extras["env_state"] + + return observation_feed From 0c6405b4edc060e9469c0d5ecca8b3a164d19082 Mon Sep 17 00:00:00 2001 From: Dries Date: Mon, 27 Sep 2021 16:11:46 +0200 Subject: [PATCH 002/104] Save progress. 
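Note on the trainer changes in the previous patch: the policy update there is a clipped-surrogate (PPO) loss computed against the stored behaviour logits. Below is a minimal, standalone sketch of that computation on dummy time-major tensors; the function name, the [T, B] shapes and the precomputed advantages are illustrative assumptions, not the Mava API.

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    def clipped_surrogate_loss(logits, behaviour_logits, actions, advantages,
                               clipping_epsilon=0.2):
        # Current and behaviour policies over discrete actions, logits [T, B, A].
        pi_target = tfd.Categorical(logits=logits)
        pi_behaviour = tfd.Categorical(logits=behaviour_logits)

        # Importance ratio current / behaviour, computed in log space.
        log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(actions)
        rhos = tf.exp(log_rhos)
        clipped_rhos = tf.clip_by_value(
            rhos, 1.0 - clipping_epsilon, 1.0 + clipping_epsilon)

        # Pessimistic (clipped) surrogate, averaged over time and batch.
        surrogate = tf.minimum(rhos * advantages, clipped_rhos * advantages)
        return -tf.reduce_mean(surrogate)

    # Dummy data: T=5 steps, B=2 sequences, A=3 discrete actions.
    T, B, A = 5, 2, 3
    logits = tf.random.normal([T, B, A])
    behaviour_logits = tf.random.normal([T, B, A])
    actions = tf.random.uniform([T, B], maxval=A, dtype=tf.int32)
    advantages = tf.random.normal([T, B])
    print(clipped_surrogate_loss(logits, behaviour_logits, actions, advantages))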
--- mava/components/tf/architectures/__init__.py | 1 + .../tf/architectures/state_based.py | 47 +++++++++++++++++++ mava/components/tf/networks/__init__.py | 1 + mava/components/tf/networks/convolution.py | 29 ++++++++++++ mava/systems/tf/maddpg/system.py | 4 +- mava/systems/tf/maddpg/training.py | 2 + mava/systems/tf/mappo/__init__.py | 4 +- mava/systems/tf/mappo/builder.py | 4 +- mava/systems/tf/mappo/execution.py | 9 ++-- mava/systems/tf/mappo/system.py | 15 +++--- mava/systems/tf/mappo/training.py | 30 ++++++++---- mava/utils/training_utils.py | 27 +++++++---- mava/wrappers/pettingzoo.py | 6 ++- 13 files changed, 144 insertions(+), 35 deletions(-) create mode 100644 mava/components/tf/networks/convolution.py diff --git a/mava/components/tf/architectures/__init__.py b/mava/components/tf/architectures/__init__.py index fe9385f5a..a84167fce 100644 --- a/mava/components/tf/architectures/__init__.py +++ b/mava/components/tf/architectures/__init__.py @@ -38,6 +38,7 @@ ) from mava.components.tf.architectures.state_based import ( StateBasedPolicyActor, + StateBasedValueActorCritic, StateBasedQValueActorCritic, StateBasedQValueCritic, ) diff --git a/mava/components/tf/architectures/state_based.py b/mava/components/tf/architectures/state_based.py index f0a9a7ac5..21e3bf1ea 100644 --- a/mava/components/tf/architectures/state_based.py +++ b/mava/components/tf/architectures/state_based.py @@ -24,6 +24,7 @@ from mava import specs as mava_specs from mava.components.tf.architectures.decentralised import ( + DecentralisedValueActorCritic, DecentralisedPolicyActor, DecentralisedQValueActorCritic, ) @@ -146,3 +147,49 @@ def __init__( critic_networks=critic_networks, agent_net_keys=agent_net_keys, ) + +class StateBasedValueActorCritic( # type: ignore + DecentralisedValueActorCritic +): + """Multi-agent actor critic architecture where both actor policies + and critics use environment state information""" + + def __init__( + self, + environment_spec: mava_specs.MAEnvironmentSpec, + observation_networks: Dict[str, snt.Module], + policy_networks: Dict[str, snt.Module], + critic_networks: Dict[str, snt.Module], + agent_net_keys: Dict[str, str], + ): + DecentralisedValueActorCritic.__init__( + self, + environment_spec=environment_spec, + observation_networks=observation_networks, + policy_networks=policy_networks, + critic_networks=critic_networks, + agent_net_keys=agent_net_keys, + ) + def _get_critic_specs( + self, + ) -> Tuple[Dict[str, acme_specs.Array], Dict[str, acme_specs.Array]]: + # Create one critic per agent. Each critic gets + # absolute state information of the environment. + critic_env_state_spec = list( + self._env_spec.get_extra_specs()["env_states"].values() + )[0] + + critic_obs_spec = [] + for spec in critic_env_state_spec: + critic_obs_spec.append( + tf.TensorSpec( + shape=spec.shape, + dtype=tf.dtypes.float32, + ) + ) + + critic_obs_specs = {} + for agent_key in self._agents: + # Get observation spec for critic. 
+ critic_obs_specs[agent_key] = critic_obs_spec + return critic_obs_specs, None diff --git a/mava/components/tf/networks/__init__.py b/mava/components/tf/networks/__init__.py index ee14261c7..05fcebafa 100644 --- a/mava/components/tf/networks/__init__.py +++ b/mava/components/tf/networks/__init__.py @@ -27,6 +27,7 @@ LayerNormMLP, NearZeroInitializedLinear, ) +from mava.components.tf.networks.convolution import Conv1DNetwork from mava.components.tf.networks.epsilon_greedy import epsilon_greedy_action_selector from mava.components.tf.networks.fingerprints import ObservationNetworkWithFingerprint from mava.components.tf.networks.mad4pg import ( diff --git a/mava/components/tf/networks/convolution.py b/mava/components/tf/networks/convolution.py new file mode 100644 index 000000000..7d4e9ac6d --- /dev/null +++ b/mava/components/tf/networks/convolution.py @@ -0,0 +1,29 @@ +from typing import List + +import sonnet as snt +import tensorflow as tf + + +class Conv1DNetwork(snt.Module): + """Simple 1D convolutional network setup.""" + + def __init__( + self, + channels: List[int] = [32, 64, 64], + kernel: List[int] = [8, 4, 3], + stride: List[int] = [4, 2, 2], + name: str = None, + ): + super(Conv1DNetwork, self).__init__(name=name) + assert len(channels) == len(kernel) == len(stride) + seq_list = [] + for i in range(len(channels)): + seq_list.append( + snt.Conv1D(channels[i], kernel_shape=kernel[i], stride=stride[i]) + ) + seq_list.append(tf.nn.relu) + seq_list.append(snt.Flatten()) + self._network = snt.Sequential(seq_list) + + def __call__(self, inputs: tf.Tensor) -> tf.Tensor: + return self._network(inputs) \ No newline at end of file diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index 38d13c3df..2417ff06f 100644 --- a/mava/systems/tf/maddpg/system.py +++ b/mava/systems/tf/maddpg/system.py @@ -275,10 +275,10 @@ def _get_extra_specs(self) -> Any: agent_net_keys=self._agent_net_keys, ) for agent in agents: - agent_type = agent.split("_")[0] + net_key = self._agent_net_keys[agent] core_state_specs[agent] = ( tf2_utils.squeeze_batch_dim( - networks["policies"][agent_type].initial_state(1) + networks["policies"][net_key].initial_state(1) ), ) return {"core_states": core_state_specs} diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 440e30f9d..76870cbfc 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -1146,6 +1146,8 @@ def _transform_observations( for agent in self._agents: agent_key = self._agent_net_keys[agent] + print("Observation: ", observations[agent].observation) + exit() reshaped_obs, dims = train_utils.combine_dim( observations[agent].observation ) diff --git a/mava/systems/tf/mappo/__init__.py b/mava/systems/tf/mappo/__init__.py index ee36f200a..af3ccdb05 100644 --- a/mava/systems/tf/mappo/__init__.py +++ b/mava/systems/tf/mappo/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from mava.systems.tf.mappo.execution import MAPPOFeedForwardExecutor +from mava.systems.tf.mappo.execution import MAPPOFeedForwardExecutor, MAPPORecurrentExecutor from mava.systems.tf.mappo.networks import make_default_networks from mava.systems.tf.mappo.system import MAPPO -from mava.systems.tf.mappo.training import CentralisedMAPPOTrainer, MAPPOTrainer +from mava.systems.tf.mappo.training import CentralisedMAPPOTrainer, MAPPOTrainer, StateBasedMAPPOTrainer diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index d06fb7c83..52f031412 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -164,7 +164,7 @@ def make_replay_tables( """ # Create system architecture with target networks. - adder_env_spec = self._builder.add_log_prob_to_spec( + adder_env_spec = self.add_log_prob_to_spec( environment_spec ) @@ -283,6 +283,7 @@ def make_trainer( self, networks: Dict[str, Dict[str, snt.Module]], dataset: Iterator[reverb.ReplaySample], + can_sample: Any, counter: Optional[counting.Counter] = None, logger: Optional[types.NestedLogger] = None, ) -> core.Trainer: @@ -314,6 +315,7 @@ def make_trainer( trainer = self._trainer_fn( agents=agents, agent_types=agent_types, + can_sample=can_sample, observation_networks=observation_networks, policy_networks=policy_networks, critic_networks=critic_networks, diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index 74a704eaa..471f79f68 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -57,7 +57,6 @@ def __init__( variable_client (Optional[tf2_variable_utils.VariableClient], optional): client to copy weights from the trainer. Defaults to None. """ - # Store these for later use. self._adder = adder self._variable_client = variable_client @@ -165,7 +164,6 @@ def observe_first( extras (Dict[str, types.NestedArray], optional): possible extra information to record during the first step. Defaults to {}. """ - if self._adder: self._adder.add_first(timestep) @@ -205,7 +203,7 @@ def update(self, wait: bool = False) -> None: if self._variable_client: self._variable_client.update(wait) -class MADDPGRecurrentExecutor(executors.RecurrentExecutor): +class MAPPORecurrentExecutor(executors.RecurrentExecutor): """A recurrent executor for MADDPG. An executor based on a recurrent policy for each agent in the system. """ @@ -228,7 +226,6 @@ def __init__( agent_net_keys: (dict, optional): specifies what network each agent uses. Defaults to {}. """ - super().__init__( policy_networks=policy_networks, agent_net_keys=agent_net_keys, @@ -275,8 +272,8 @@ def _policy( # Cast for compatibility with reverb. # sample() returns a 'int32', which is a problem. 
- if isinstance(policy, tfp.distributions.Categorical): - action = tf.cast(action, "int64") + # if isinstance(policy, tfp.distributions.Categorical): + # action = tf.cast(action, "int64") return action, log_prob, new_state diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 2af1bd732..6d588b6ad 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -69,9 +69,9 @@ def __init__( baseline_cost: float = 0.5, max_gradient_norm: Optional[float] = None, max_queue_size: int = 100000, - batch_size: int = 256, - sequence_length: int = 10, - sequence_period: int = 5, + batch_size: int = 32, + sequence_length: int = 50, + sequence_period: int = 10, max_executor_steps: int = None, checkpoint: bool = True, checkpoint_subpath: str = "~/mava/", @@ -196,7 +196,7 @@ def __init__( self._eval_loop_fn_kwargs = eval_loop_fn_kwargs self._checkpoint_minute_interval = checkpoint_minute_interval - if issubclass(executor_fn, execution.MAPPOFeedForwardExecutor): + if issubclass(executor_fn, execution.MAPPORecurrentExecutor): extra_specs = self._get_extra_specs() else: extra_specs = {} @@ -240,10 +240,10 @@ def _get_extra_specs(self) -> Any: agent_net_keys=self._agent_net_keys, ) for agent in agents: - agent_type = agent.split("_")[0] + net_keys = self._agent_net_keys[agent] core_state_specs[agent] = ( tf2_utils.squeeze_batch_dim( - networks["policies"][agent_type].initial_state(1) + networks["policies"][net_keys].initial_state(1) ), ) return {"core_states": core_state_specs} @@ -327,7 +327,7 @@ def trainer( trainer_logger = self._logger_factory( # type: ignore "trainer", **trainer_logger_config ) - + can_sample = None # lambda: replay.can_sample(self._builder._config.batch_size) dataset = self._builder.make_dataset_iterator(replay) counter = counting.Counter(counter, "trainer") @@ -336,6 +336,7 @@ def trainer( dataset=dataset, counter=counter, logger=trainer_logger, + can_sample=can_sample, ) def executor( diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 3dc168298..72343a2b4 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -48,6 +48,7 @@ def __init__( self, agents: List[Any], agent_types: List[str], + can_sample: Any, observation_networks: Dict[str, snt.Module], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], @@ -111,6 +112,9 @@ def __init__( self._agents = agents self._agent_types = agent_types self._checkpoint = checkpoint + + # Can sample + self._can_sample = can_sample # Store agent_net_keys. self._agent_net_keys = agent_net_keys @@ -224,7 +228,7 @@ def _get_critic_feed( return observation_feed def _transform_observations( - self, observation: Dict[str, np.ndarray] + self, observations: Dict[str, np.ndarray] ) -> Dict[str, np.ndarray]: """apply the observation networks to the raw observations from the dataset @@ -239,12 +243,18 @@ def _transform_observations( observation_trans = {} for agent in self._agents: agent_key = self._agent_net_keys[agent] - observation_trans[agent] = self._observation_networks[agent_key]( - observation[agent].observation + + reshaped_obs, dims = train_utils.combine_dim( + observations[agent].observation + ) + + observation_trans[agent] = train_utils.extract_dim( + self._observation_networks[agent_key](reshaped_obs), dims ) return observation_trans - @tf.function + print("Add this back in. But this breaks the next(itter check) as and infinite number of pulls are allowed now. 
Fix it.") + # @tf.function def _step( self, ) -> Dict[str, Dict[str, Any]]: @@ -279,7 +289,7 @@ def _forward(self, inputs: Any) -> None: data = tf2_utils.batch_to_sequence(inputs.data) # Unpack input data as follows: - observations, actions, rewards, discounts, extras = ( + observations, actions_data, rewards, discounts, extras = ( data.observations, data.actions, data.rewards, @@ -294,8 +304,8 @@ def _forward(self, inputs: Any) -> None: observations_trans = self._transform_observations(observations) # Get log_probs. - actions = actions["actions"] - behaviour_logits = actions["log_probs"] + actions = {agent: actions_data[agent]["actions"] for agent in actions_data.keys()} + behaviour_logits = {agent: actions_data[agent]["log_probs"] for agent in actions_data.keys()} # Store losses. policy_losses: Dict[str, Any] = {} @@ -496,6 +506,7 @@ def __init__( self, agents: List[Any], agent_types: List[str], + can_sample: Any, observation_networks: Dict[str, snt.Module], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], @@ -519,6 +530,7 @@ def __init__( super().__init__( agents=agents, agent_types=agent_types, + can_sample=can_sample, policy_networks=policy_networks, critic_networks=critic_networks, observation_networks=observation_networks, @@ -557,6 +569,7 @@ def __init__( self, agents: List[Any], agent_types: List[str], + can_sample: Any, observation_networks: Dict[str, snt.Module], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], @@ -580,6 +593,7 @@ def __init__( super().__init__( agents=agents, agent_types=agent_types, + can_sample=can_sample, policy_networks=policy_networks, critic_networks=critic_networks, observation_networks=observation_networks, @@ -607,6 +621,6 @@ def _get_critic_feed( agent: str, ) -> tf.Tensor: # State based - observation_feed = extras["env_state"] + observation_feed = extras["env_states"] return observation_feed diff --git a/mava/utils/training_utils.py b/mava/utils/training_utils.py index 01d57dcc7..3c1c5070a 100644 --- a/mava/utils/training_utils.py +++ b/mava/utils/training_utils.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Iterable, Sequence +from typing import Any, Dict, Iterable, Sequence, Union, List, Tuple import sonnet as snt import tensorflow as tf @@ -39,13 +39,24 @@ def map_losses_per_agent_ac(critic_losses: Dict, policy_losses: Dict) -> Dict: return logged_losses -def combine_dim(tensor: tf.Tensor) -> tf.Tensor: - dims = tensor.shape[:2] - return snt.merge_leading_dims(tensor, num_dims=2), dims - - -def extract_dim(tensor: tf.Tensor, dims: tf.Tensor) -> tf.Tensor: - return tf.reshape(tensor, [dims[0], dims[1], -1]) +def combine_dim(inputs: Union[tf.Tensor, List, Tuple]) -> tf.Tensor: + if isinstance(inputs, tf.Tensor): + dims = inputs.shape[:2] + return snt.merge_leading_dims(inputs, num_dims=2), dims + else: + dims = None + return_list = [] + for tensor in inputs: + comb_one, dims = combine_dim(tensor) + return_list.append(comb_one) + return return_list, dims + + +def extract_dim(inputs: Union[tf.Tensor, List, Tuple], dims: tf.Tensor) -> tf.Tensor: + if isinstance(inputs, tf.Tensor): + return tf.reshape(inputs, [dims[0], dims[1], -1]) + else: + return [extract_dim(tensor, dims) for tensor in inputs] # Require correct tensor ranks---as long as we have shape information diff --git a/mava/wrappers/pettingzoo.py b/mava/wrappers/pettingzoo.py index a462279a6..9953b7fa1 100644 --- a/mava/wrappers/pettingzoo.py +++ b/mava/wrappers/pettingzoo.py @@ -23,7 +23,11 @@ from acme import specs 
from acme.wrappers.gym_wrapper import _convert_to_spec from pettingzoo.utils.env import AECEnv, ParallelEnv -from supersuit import black_death_v1 + +try: + from supersuit import black_death_v1 +except ImportError: + black_death_v1 = None from mava import types from mava.utils.wrapper_utils import ( From 71f4dc6506c1785ea218bfb696f138106dd3cbd7 Mon Sep 17 00:00:00 2001 From: Dries Date: Mon, 27 Sep 2021 16:29:06 +0200 Subject: [PATCH 003/104] Save recurrent PPO progress. --- mava/systems/tf/mappo/training.py | 41 ++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 72343a2b4..a40159dfe 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -297,8 +297,8 @@ def _forward(self, inputs: Any) -> None: data.extras, ) - if 'core_state' in extras: - core_state = tree.map_structure(lambda s: s[:, 0, :], extras["core_states"]) + if 'core_states' in extras: + core_states = tree.map_structure(lambda s: s[:, 0, :], extras["core_states"]) # transform observation using observation networks observations_trans = self._transform_observations(observations) @@ -334,28 +334,38 @@ def _forward(self, inputs: Any) -> None: critic_network = self._critic_networks[agent_key] # Reshape inputs. - dims = actor_observation.shape[:2] - actor_observation = snt.merge_leading_dims( - actor_observation, num_dims=2 - ) - critic_observation = snt.merge_leading_dims( - critic_observation, num_dims=2 + critic_observation, dims = train_utils.combine_dim( + critic_observation ) - if 'core_state' in extras: + + if 'core_states' in extras: # Unroll current policy over actor_observation. - (logits, values), _ = snt.static_unroll(self._actor_critic_network, actor_observation, - core_state) + # policy, _ = snt.static_unroll(policy_network, actor_observation, + # core_states) + + transposed_obs = tf2_utils.batch_to_sequence(actor_observation) + outputs, updated_states = snt.static_unroll( + policy_network, + transposed_obs, + core_states, + ) + + outputs = tf2_utils.batch_to_sequence(outputs) + + print("outputs: ", outputs) + exit() + policy = tfd.Categorical(logits=logits[:-1]) policy_log_prob = policy.log_prob(actions) # Entropy regulariser. entropy_loss = trfl.policy_entropy_loss(policy).loss else: policy = policy_network(actor_observation) - values = critic_network(critic_observation) + # Reshape the outputs. policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") - values = tf.reshape(values, dims, name="value") + policy_log_prob = policy.log_prob(action) # Entropy regularization. Only implemented for categorical dist. @@ -363,6 +373,9 @@ def _forward(self, inputs: Any) -> None: entropy_loss = -self._entropy_cost * policy_entropy # Compute importance sampling weights: current policy / behavior policy. + values = critic_network(critic_observation) + values = tf.reshape(values, dims, name="value") + # Values along the sequence T. bootstrap_value = values[-1] state_values = values[:-1] @@ -621,6 +634,6 @@ def _get_critic_feed( agent: str, ) -> tf.Tensor: # State based - observation_feed = extras["env_states"] + observation_feed = extras["env_states"][agent] return observation_feed From cc42e9807ffbed0a97c1b029e45bcd1ef22d2e56 Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 28 Sep 2021 08:30:04 +0200 Subject: [PATCH 004/104] Recurrent PPO is running. 
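The recurrent trainer in the previous patch unrolls the current policy over the whole stored sequence with snt.static_unroll before computing losses. A minimal sketch of that pattern, assuming a time-major [T, B, obs_dim] batch and a toy Sonnet DeepRNN policy (sizes are placeholders, not the Mava networks):

    import sonnet as snt
    import tensorflow as tf

    T, B, obs_dim, num_actions = 10, 4, 8, 5

    # Toy recurrent policy: MLP torso, LSTM core, linear head producing logits.
    policy_network = snt.DeepRNN([
        snt.Linear(32), tf.nn.relu,
        snt.LSTM(32),
        snt.Linear(num_actions),
    ])

    observations = tf.random.normal([T, B, obs_dim])  # time-major sequence
    core_state = policy_network.initial_state(B)      # recurrent state at t=0

    # Unroll the policy over the whole sequence in a single call.
    logits, final_state = snt.static_unroll(policy_network, observations, core_state)
    print(logits.shape)  # (T, B, num_actions)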
--- mava/systems/tf/maddpg/training.py | 2 - mava/systems/tf/mappo/builder.py | 16 ++-- mava/systems/tf/mappo/execution.py | 15 ++-- mava/systems/tf/mappo/system.py | 8 +- mava/systems/tf/mappo/training.py | 113 +++++++++++++---------------- 5 files changed, 73 insertions(+), 81 deletions(-) diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 76870cbfc..440e30f9d 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -1146,8 +1146,6 @@ def _transform_observations( for agent in self._agents: agent_key = self._agent_net_keys[agent] - print("Observation: ", observations[agent].observation) - exit() reshaped_obs, dims = train_utils.combine_dim( observations[agent].observation ) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 52f031412..b6d872028 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -99,7 +99,7 @@ def __init__( config: MAPPOConfig, trainer_fn: Type[training.MAPPOTrainer] = training.MAPPOTrainer, executor_fn: Type[core.Executor] = execution.MAPPOFeedForwardExecutor, - extra_specs: Dict[str, Any] = {} + extra_specs: Dict[str, Any] = {}, ): """Initialise the system. @@ -121,7 +121,7 @@ def __init__( self._executor_fn = executor_fn self._extras_specs = extra_specs - def add_log_prob_to_spec( + def add_logits_to_spec( self, environment_spec: specs.MAEnvironmentSpec ) -> specs.MAEnvironmentSpec: """convert discrete action space to bounded continuous action space @@ -137,9 +137,11 @@ def add_log_prob_to_spec( for key in keys: agent_spec = env_adder_spec._specs[key] new_act_spec = {"actions": agent_spec.actions} - - # Make dummy log_probs - new_act_spec["log_probs"] = tf2_utils.squeeze_batch_dim(tf.ones(shape=(1,), dtype=tf.float32)) + + # Make dummy logits + new_act_spec["logits"] = tf.ones( + shape=(agent_spec.actions.num_values), dtype=tf.float32 + ) env_adder_spec._specs[key] = EnvironmentSpec( observations=agent_spec.observations, @@ -164,9 +166,7 @@ def make_replay_tables( """ # Create system architecture with target networks. - adder_env_spec = self.add_log_prob_to_spec( - environment_spec - ) + adder_env_spec = self.add_logits_to_spec(environment_spec) replay_table = reverb.Table.queue( name=self._config.replay_table_name, diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index 471f79f68..c84f5af59 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -183,6 +183,10 @@ def observe( information to record during the transition. Defaults to {}. """ + raise NotImplementedError( + "Need to fix the FeedForward case to reflect the recurrent updates." + ) + if not self._adder: return @@ -203,6 +207,7 @@ def update(self, wait: bool = False) -> None: if self._variable_client: self._variable_client.update(wait) + class MAPPORecurrentExecutor(executors.RecurrentExecutor): """A recurrent executor for MADDPG. An executor based on a recurrent policy for each agent in the system. @@ -264,18 +269,18 @@ def _policy( agent_key = self._agent_net_keys[agent] # Compute the policy, conditioned on the observation. - policy, new_state = self._policy_networks[agent_key](batched_observation, state) + logits, new_state = self._policy_networks[agent_key](batched_observation, state) # Sample from the policy and compute the log likelihood. 
- action = policy.sample() - log_prob = policy.log_prob(action) + action = tfd.Categorical(logits).sample() + # action = tf2_utils.to_numpy_squeeze(action) # Cast for compatibility with reverb. # sample() returns a 'int32', which is a problem. # if isinstance(policy, tfp.distributions.Categorical): # action = tf.cast(action, "int64") - - return action, log_prob, new_state + + return action, logits, new_state def select_action( self, agent: str, observation: types.NestedArray diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 6d588b6ad..c3afc32ea 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -62,8 +62,8 @@ def __init__( snt.Optimizer, Dict[str, snt.Optimizer] ] = snt.optimizers.Adam(learning_rate=5e-4), critic_optimizer: snt.Optimizer = snt.optimizers.Adam(learning_rate=1e-5), - discount: float = 0.99, - lambda_gae: float = 0.99, + discount: float = 0.999, + lambda_gae: float = 1.0, clipping_epsilon: float = 0.2, entropy_cost: float = 0.01, baseline_cost: float = 0.5, @@ -226,7 +226,7 @@ def __init__( executor_fn=executor_fn, extra_specs=extra_specs, ) - + def _get_extra_specs(self) -> Any: """helper to establish specs for extra information Returns: @@ -327,7 +327,7 @@ def trainer( trainer_logger = self._logger_factory( # type: ignore "trainer", **trainer_logger_config ) - can_sample = None # lambda: replay.can_sample(self._builder._config.batch_size) + can_sample = None # lambda: replay.can_sample(self._builder._config.batch_size) dataset = self._builder.make_dataset_iterator(replay) counter = counting.Counter(counter, "trainer") diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index a40159dfe..44ff2efcc 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -30,6 +30,7 @@ from acme.utils import counting, loggers import mava +from mava.adders.reverb.base import Trajectory from mava.systems.tf import savers as tf2_savers from mava.utils import training_utils as train_utils @@ -58,7 +59,7 @@ def __init__( agent_net_keys: Dict[str, str], checkpoint_minute_interval: int, discount: float = 0.999, - lambda_gae: float = 1.0, # Question (dries): What is this used for? + lambda_gae: float = 1.0, # Question (dries): What is this used for? entropy_cost: float = 0.01, baseline_cost: float = 0.5, clipping_epsilon: float = 0.2, @@ -112,7 +113,7 @@ def __init__( self._agents = agents self._agent_types = agent_types self._checkpoint = checkpoint - + # Can sample self._can_sample = can_sample @@ -253,7 +254,9 @@ def _transform_observations( ) return observation_trans - print("Add this back in. But this breaks the next(itter check) as and infinite number of pulls are allowed now. Fix it.") + print( + "Add this back in. But this breaks the next(itter check) as and infinite number of pulls are allowed now. Fix it." + ) # @tf.function def _step( self, @@ -284,12 +287,11 @@ def _forward(self, inputs: Any) -> None: inputs (Any): input data from the data table (transitions) """ - # Convert to sequence data - # TODO(Kale-ab): Is this the recurrent trainer? 
- data = tf2_utils.batch_to_sequence(inputs.data) + data: Trajectory = inputs.data # Unpack input data as follows: - observations, actions_data, rewards, discounts, extras = ( + data = tf2_utils.batch_to_sequence(inputs.data) + observations, actions, rewards, discounts, extras = ( data.observations, data.actions, data.rewards, @@ -297,16 +299,12 @@ def _forward(self, inputs: Any) -> None: data.extras, ) - if 'core_states' in extras: - core_states = tree.map_structure(lambda s: s[:, 0, :], extras["core_states"]) + if "core_states" in extras: + core_states = tree.map_structure(lambda s: s[0], extras["core_states"]) # transform observation using observation networks observations_trans = self._transform_observations(observations) - # Get log_probs. - actions = {agent: actions_data[agent]["actions"] for agent in actions_data.keys()} - behaviour_logits = {agent: actions_data[agent]["log_probs"] for agent in actions_data.keys()} - # Store losses. policy_losses: Dict[str, Any] = {} critic_losses: Dict[str, Any] = {} @@ -315,16 +313,19 @@ def _forward(self, inputs: Any) -> None: for agent in self._agents: action, reward, discount, behaviour_logits = ( - actions[agent], + actions[agent]["actions"], rewards[agent], discounts[agent], - behaviour_logits[agent], + actions[agent]["logits"], ) actor_observation = observations_trans[agent] - critic_observation = self._get_critic_feed(observations_trans, extras, agent) + critic_observation = self._get_critic_feed( + observations_trans, extras, agent + ) # Chop off final timestep for bootstrapping value + action = action[:-1] reward = reward[:-1] discount = discount[:-1] @@ -333,46 +334,32 @@ def _forward(self, inputs: Any) -> None: policy_network = self._policy_networks[agent_key] critic_network = self._critic_networks[agent_key] - # Reshape inputs. - critic_observation, dims = train_utils.combine_dim( - critic_observation - ) - - - if 'core_states' in extras: + if "core_states" in extras: # Unroll current policy over actor_observation. # policy, _ = snt.static_unroll(policy_network, actor_observation, # core_states) - transposed_obs = tf2_utils.batch_to_sequence(actor_observation) - outputs, updated_states = snt.static_unroll( + # transposed_obs = tf2_utils.batch_to_sequence(actor_observation) + agent_core_state = core_states[agent][0] + logits, updated_states = snt.static_unroll( policy_network, - transposed_obs, - core_states, + actor_observation, + agent_core_state, ) + # logits = tf2_utils.batch_to_sequence(logits) - outputs = tf2_utils.batch_to_sequence(outputs) - - print("outputs: ", outputs) - exit() - - policy = tfd.Categorical(logits=logits[:-1]) - policy_log_prob = policy.log_prob(actions) - # Entropy regulariser. - entropy_loss = trfl.policy_entropy_loss(policy).loss else: - policy = policy_network(actor_observation) - + logits = policy_network(actor_observation) + # Reshape the outputs. - policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") - - policy_log_prob = policy.log_prob(action) - - # Entropy regularization. Only implemented for categorical dist. - policy_entropy = tf.reduce_mean(policy.entropy()) - entropy_loss = -self._entropy_cost * policy_entropy - # Compute importance sampling weights: current policy / behavior policy. - + # policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") + + # Compute importance sampling weights: current policy / behavior policy. + pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1]) + pi_target = tfd.Categorical(logits=logits[:-1]) + + # Calculate critic values. 
+ critic_observation, dims = train_utils.combine_dim(critic_observation) values = critic_network(critic_observation) values = tf.reshape(values, dims, name="value") @@ -384,7 +371,7 @@ def _forward(self, inputs: Any) -> None: td_loss, td_lambda_extra = trfl.td_lambda( state_values=state_values, rewards=reward, - pcontinues=discount, + pcontinues=discount, # Question (dries): Why is self._discount not used. Why is discount not 0.0/1.0 but actually has discount values. bootstrap_value=bootstrap_value, lambda_=self._lambda_gae, name="CriticLoss", @@ -392,15 +379,13 @@ def _forward(self, inputs: Any) -> None: # Do not use the loss provided by td_lambda as they sum the losses over # the sequence length rather than averaging them. - critic_loss = self._baseline_cost * tf.reduce_mean( - tf.square(td_lambda_extra.temporal_differences), name="CriticLoss" + critic_loss = self._baseline_cost * tf.square( + td_lambda_extra.temporal_differences ) # Compute importance sampling weights: current policy / behavior policy. - behaviour_log_prob = tfd.Categorical(logits=behaviour_logits[:-1]) - - log_rhos = policy_log_prob - behaviour_log_prob - importance_ratio = tf.exp(log_rhos)[:-1] + log_rhos = pi_target.log_prob(action) - pi_behaviour.log_prob(action) + importance_ratio = tf.exp(log_rhos) clipped_importance_ratio = tf.clip_by_value( importance_ratio, 1.0 - self._clipping_epsilon, @@ -412,20 +397,23 @@ def _forward(self, inputs: Any) -> None: mean, variance = tf.nn.moments(gae, axes=[0, 1], keepdims=True) normalized_gae = (gae - mean) / tf.sqrt(variance) - policy_gradient_loss = tf.reduce_mean( - -tf.minimum( - tf.multiply(importance_ratio, normalized_gae), - tf.multiply(clipped_importance_ratio, normalized_gae), - ), + policy_gradient_loss = -tf.minimum( + tf.multiply(importance_ratio, normalized_gae), + tf.multiply(clipped_importance_ratio, normalized_gae), name="PolicyGradientLoss", ) - # Combine weighted sum of actor & entropy regularization. + # Entropy regulariser. + entropy_loss = trfl.policy_entropy_loss(pi_target).loss + + # Combine weighted sum of actor & entropy regularization. policy_loss = policy_gradient_loss + entropy_loss # Multiply by discounts to not train on padded data. - policy_losses[agent] = policy_loss*discount - critic_losses[agent] = critic_loss*discount + loss_mask = discount > 0.0 + # TODO (dries): Is multiplication maybe better here? As assignment might not work with tf.function? + policy_losses[agent] = tf.reduce_mean(policy_loss[loss_mask]) + critic_losses[agent] = tf.reduce_mean(critic_loss[loss_mask]) self.policy_losses = policy_losses self.critic_losses = critic_losses @@ -575,6 +563,7 @@ def _get_critic_feed( return observation_feed + class StateBasedMAPPOTrainer(MAPPOTrainer): """MAPPO trainer for a centralised architecture.""" From 1ae02b5c04d0d20a5d491bf09b960a0bd193af3d Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 28 Sep 2021 16:20:00 +0200 Subject: [PATCH 005/104] Small fixes. 
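The recurrent trainer above averages per-step losses only over valid steps by boolean-masking with discount > 0. On the TODO about using multiplication instead: a mask-and-divide variant avoids the dynamic shapes that boolean indexing introduces under tf.function. A sketch on dummy data (the shapes are assumptions):

    import tensorflow as tf

    per_step_loss = tf.random.normal([9, 4])  # [T-1, B] per-step policy losses
    # Discounts are 0.0 on padded steps and positive on real steps.
    discount = tf.cast(tf.random.uniform([9, 4]) > 0.2, tf.float32)

    mask = tf.cast(discount > 0.0, per_step_loss.dtype)
    masked_mean = tf.reduce_sum(per_step_loss * mask) / tf.maximum(
        tf.reduce_sum(mask), 1.0)
    print(masked_mean)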
--- .../recurrent/state_based/run_mappo.py | 114 ++++++++++++++++++ .../tf/architectures/state_based.py | 22 ++-- .../components/tf/modules/mixing/monotonic.py | 2 +- mava/systems/tf/maddpg/training.py | 8 +- mava/systems/tf/mappo/builder.py | 23 ++-- mava/systems/tf/mappo/execution.py | 2 +- mava/systems/tf/mappo/networks.py | 43 ++++++- mava/systems/tf/mappo/system.py | 3 +- mava/systems/tf/mappo/training.py | 29 ++--- mava/systems/tf/qmix/training.py | 4 +- mava/utils/debugging/environment.py | 2 +- mava/utils/debugging/environments/two_step.py | 10 +- .../robocup_utils/util_functions.py | 2 +- mava/wrappers/debugging_envs.py | 6 +- mava/wrappers/smac.py | 6 +- 15 files changed, 218 insertions(+), 58 deletions(-) create mode 100644 examples/debugging/simple_spread/recurrent/state_based/run_mappo.py diff --git a/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py b/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py new file mode 100644 index 000000000..c84c3885a --- /dev/null +++ b/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py @@ -0,0 +1,114 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running MAPPO on debug MPE environments.""" + +import functools +from datetime import datetime +from typing import Any + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.components.tf import architectures +from mava.systems.tf import mappo +from mava.utils import lp_utils +from mava.utils.enums import ArchitectureType +from mava.utils.environments import debugging_utils +from mava.utils.loggers import logger_utils + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "env_name", + "simple_spread", + "Debugging environment name (str).", +) +flags.DEFINE_string( + "action_space", + "discrete", + "Environment action space type (str).", +) + +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +print("Dries change this back to ~/mava.") +flags.DEFINE_string("base_dir", "/home/mava_logs", "Base dir to store experiments.") + + +def main(_: Any) -> None: + + # Environment. + environment_factory = functools.partial( + debugging_utils.make_environment, + env_name=FLAGS.env_name, + action_space=FLAGS.action_space, + return_state_info=True, + ) + + # Networks. + network_factory = lp_utils.partial_kwargs( + mappo.make_default_networks, + archecture_type=ArchitectureType.recurrent, + ) + + # Checkpointer appends "Checkpoints" to checkpoint_dir + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. 
+ log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # Distributed program + program = mappo.MAPPO( + environment_factory=environment_factory, + network_factory=network_factory, + logger_factory=logger_factory, + num_executors=1, + policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + checkpoint_subpath=checkpoint_dir, + max_gradient_norm=40.0, + executor_fn=mappo.MAPPORecurrentExecutor, + architecture=architectures.StateBasedValueActorCritic, + trainer_fn=mappo.StateBasedMAPPOTrainer, + ).build() + + # Ensure only trainer runs on gpu, while other processes run on cpu. + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + + # Launch. + lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/mava/components/tf/architectures/state_based.py b/mava/components/tf/architectures/state_based.py index 21e3bf1ea..45a1609a9 100644 --- a/mava/components/tf/architectures/state_based.py +++ b/mava/components/tf/architectures/state_based.py @@ -24,9 +24,9 @@ from mava import specs as mava_specs from mava.components.tf.architectures.decentralised import ( - DecentralisedValueActorCritic, DecentralisedPolicyActor, DecentralisedQValueActorCritic, + DecentralisedValueActorCritic, ) @@ -56,7 +56,7 @@ def _get_actor_specs( agents_by_type = self._env_spec.get_agents_by_type() for agent_type, agents in agents_by_type.items(): - actor_state_shape = self._env_spec.get_extra_specs()["s_t"].shape + actor_state_shape = self._env_spec.get_extra_specs()["env_states"].shape obs_specs_per_type[agent_type] = tf.TensorSpec( shape=actor_state_shape, dtype=tf.dtypes.float32, @@ -99,7 +99,7 @@ def _get_critic_specs( # Create one critic per agent. Each critic gets # absolute state information of the environment. - critic_state_shape = self._env_spec.get_extra_specs()["s_t"].shape + critic_state_shape = self._env_spec.get_extra_specs()["env_states"].shape critic_obs_spec = tf.TensorSpec( shape=critic_state_shape, dtype=tf.dtypes.float32, @@ -148,9 +148,8 @@ def __init__( agent_net_keys=agent_net_keys, ) -class StateBasedValueActorCritic( # type: ignore - DecentralisedValueActorCritic -): + +class StateBasedValueActorCritic(DecentralisedValueActorCritic): # type: ignore """Multi-agent actor critic architecture where both actor policies and critics use environment state information""" @@ -170,14 +169,19 @@ def __init__( critic_networks=critic_networks, agent_net_keys=agent_net_keys, ) + def _get_critic_specs( self, ) -> Tuple[Dict[str, acme_specs.Array], Dict[str, acme_specs.Array]]: # Create one critic per agent. Each critic gets # absolute state information of the environment. 
- critic_env_state_spec = list( - self._env_spec.get_extra_specs()["env_states"].values() - )[0] + + critic_env_state_spec = self._env_spec.get_extra_specs()["env_states"] + if type(critic_env_state_spec) == dict: + critic_env_state_spec = list(critic_env_state_spec.values())[0] + + if type(critic_env_state_spec) != list: + critic_env_state_spec = [critic_env_state_spec] critic_obs_spec = [] for spec in critic_env_state_spec: diff --git a/mava/components/tf/modules/mixing/monotonic.py b/mava/components/tf/modules/mixing/monotonic.py index bee0465b8..1c6afe161 100644 --- a/mava/components/tf/modules/mixing/monotonic.py +++ b/mava/components/tf/modules/mixing/monotonic.py @@ -63,7 +63,7 @@ def __init__( def _create_mixing_layer(self, name: str = "mixing") -> snt.Module: """Modify and return system architecture given mixing structure.""" state_specs = self._environment_spec.get_extra_specs() - state_specs = state_specs["s_t"] + state_specs = state_specs["env_states"] q_value_dim = tf.TensorSpec(self._n_agents) diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 440e30f9d..2055ddda1 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -885,8 +885,8 @@ def _get_critic_feed( ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: # State based - o_tm1_feed = e_tm1["s_t"] - o_t_feed = e_t["s_t"] + o_tm1_feed = e_tm1["env_states"] + o_t_feed = e_t["env_states"] a_tm1_feed = tf.stack([a_tm1[agent] for agent in self._agents], 1) a_t_feed = tf.stack([a_t[agent] for agent in self._agents], 1) @@ -1744,8 +1744,8 @@ def _get_critic_feed( ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: # State based - obs_trans_feed = extras["s_t"] - target_obs_trans_feed = extras["s_t"] + obs_trans_feed = extras["env_states"] + target_obs_trans_feed = extras["env_states"] actions_feed = tf.stack([actions[agent] for agent in self._agents], -1) target_actions_feed = tf.stack( [target_actions[agent] for agent in self._agents], -1 diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index b6d872028..7370ebb4c 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -15,16 +15,16 @@ """MAPPO system builder implementation.""" -import dataclasses import copy -from typing import Dict, Iterator, List, Optional, Type, Union, Any -import tensorflow as tf +import dataclasses +from typing import Any, Dict, Iterator, List, Optional, Type, Union import reverb import sonnet as snt -from acme.tf import utils as tf2_utils +import tensorflow as tf from acme import datasets from acme.specs import EnvironmentSpec +from acme.tf import utils as tf2_utils from acme.tf import variable_utils from acme.utils import counting @@ -198,10 +198,19 @@ def make_dataset_iterator( """ # Create tensorflow dataset to interface with reverb - dataset = datasets.make_reverb_dataset( + # dataset = datasets.make_reverb_dataset( + # server_address=replay_client.server_address, + # batch_size=self._config.batch_size, + # ) + + dataset = reverb.TrajectoryDataset.from_table_signature( server_address=replay_client.server_address, - batch_size=self._config.batch_size, - ) + table=self._config.replay_table_name, + max_in_flight_samples_per_worker=1, + ).batch( + self._config.batch_size, drop_remainder=True + ) # .prefetch(self._config.prefetch_size) + # .as_numpy_iterator() return iter(dataset) diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index c84f5af59..02fbe74e8 100644 --- 
a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -278,7 +278,7 @@ def _policy( # Cast for compatibility with reverb. # sample() returns a 'int32', which is a problem. # if isinstance(policy, tfp.distributions.Categorical): - # action = tf.cast(action, "int64") + action = tf.cast(action, "int64") return action, logits, new_state diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index b2f40c647..f5fe78e78 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -24,6 +24,7 @@ from mava import specs as mava_specs from mava.components.tf import networks +from mava.utils.enums import ArchitectureType Array = specs.Array BoundedArray = specs.BoundedArray @@ -40,6 +41,7 @@ def make_default_networks( 256, ), critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256), + archecture_type=ArchitectureType.feedforward, seed: Optional[int] = None, ) -> Dict[str, snt.Module]: """Default networks for mappo. @@ -67,6 +69,21 @@ def make_default_networks( specs = environment_spec.get_agent_specs() specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} + # Set Policy function and layer size + # Default size per arch type. + if archecture_type == ArchitectureType.feedforward: + if not policy_networks_layer_sizes: + policy_networks_layer_sizes = ( + 256, + 256, + 256, + ) + policy_network_func = snt.Sequential + elif archecture_type == ArchitectureType.recurrent: + if not policy_networks_layer_sizes: + policy_networks_layer_sizes = (128, 128) + policy_network_func = snt.DeepRNN + if isinstance(policy_networks_layer_sizes, Sequence): policy_networks_layer_sizes = { key: policy_networks_layer_sizes for key in specs.keys() @@ -87,20 +104,36 @@ def make_default_networks( # Note: The discrete case must be placed first as it inherits from BoundedArray. if isinstance(specs[key].actions, dm_env.specs.DiscreteArray): # discrete num_actions = specs[key].actions.num_values - policy_network = snt.Sequential( - [ + + if archecture_type == ArchitectureType.feedforward: + policy_layers = [ networks.LayerNormMLP( tuple(policy_networks_layer_sizes[key]) + (num_actions,), activate_final=False, seed=seed, ), - tf.keras.layers.Lambda( - lambda logits: tfp.distributions.Categorical(logits=logits) + ] + elif archecture_type == ArchitectureType.recurrent: + policy_layers = [ + networks.LayerNormMLP( + policy_networks_layer_sizes[key][:-1], + activate_final=True, + seed=seed, + ), + snt.LSTM(policy_networks_layer_sizes[key][-1]), + networks.LayerNormMLP( + (num_actions,), + activate_final=False, + seed=seed, ), ] - ) + + policy_network = policy_network_func(policy_layers) + elif isinstance(specs[key].actions, dm_env.specs.BoundedArray): # continuous num_actions = np.prod(specs[key].actions.shape, dtype=int) + # No recurrent setting implemented yet. + assert archecture_type == ArchitectureType.feedforward policy_network = snt.Sequential( [ networks.LayerNormMLP( diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index c3afc32ea..f17408191 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -24,8 +24,8 @@ import reverb import sonnet as snt from acme import specs as acme_specs -from acme.utils import counting, loggers from acme.tf import utils as tf2_utils +from acme.utils import counting, loggers import mava from mava import core @@ -303,7 +303,6 @@ def trainer( Returns: mava.core.Trainer: system trainer. 
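The networks.py change above assembles the recurrent policy as a feedforward torso, an LSTM core and a linear head that emits logits. A minimal Sonnet sketch of that shape (layer sizes are illustrative, and the plain snt.nets.MLP torso stands in for Mava's LayerNormMLP):

import sonnet as snt
import tensorflow as tf

num_actions = 5
policy = snt.DeepRNN([
    snt.nets.MLP([128], activate_final=True),  # feedforward torso
    snt.LSTM(128),                             # recurrent core
    snt.Linear(num_actions),                   # logits head
])

batch_size = 1
state = policy.initial_state(batch_size)
observation = tf.random.normal([batch_size, 18])

# One executor step: condition on the observation and carry the state forward.
logits, state = policy(observation, state)
action = tf.random.categorical(logits, num_samples=1)[:, 0]
print(action.numpy(), logits.shape)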
""" - # Create the networks to optimize (online) networks = self._network_factory( # type: ignore environment_spec=self._environment_spec, diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 44ff2efcc..95962a65d 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -18,13 +18,13 @@ import copy import os import time -import tree from typing import Any, Dict, List, Optional, Sequence, Union import numpy as np import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +import tree import trfl from acme.tf import utils as tf2_utils from acme.utils import counting, loggers @@ -254,10 +254,7 @@ def _transform_observations( ) return observation_trans - print( - "Add this back in. But this breaks the next(itter check) as and infinite number of pulls are allowed now. Fix it." - ) - # @tf.function + # Warning (dries): Do not place a tf.function here. It breaks the itterator. def _step( self, ) -> Dict[str, Dict[str, Any]]: @@ -269,18 +266,20 @@ def _step( # Get data from replay. inputs = next(self._iterator) + self.forward_backward(inputs) - self._forward(inputs) - - self._backward() - + @tf.function + def forward_backward(self, inputs: Any): + self._forward_pass(inputs) + self._backward_pass() # Log losses per agent return train_utils.map_losses_per_agent_ac( self.critic_losses, self.policy_losses ) # Forward pass that calculates loss. - def _forward(self, inputs: Any) -> None: + # This name is changed from _backward to make sure the trainer wrappers are not overwriting the _step function. + def _forward_pass(self, inputs: Any) -> None: """Trainer forward pass Args: @@ -420,7 +419,8 @@ def _forward(self, inputs: Any) -> None: self.tape = tape # Backward pass that calculates gradients and updates network. - def _backward(self) -> None: + # This name is changed from _backward to make sure the trainer wrappers are not overwriting the _step function. 
+ def _backward_pass(self) -> None: """Trainer backward pass updating network parameters""" # Calculate the gradients and update the networks @@ -623,6 +623,7 @@ def _get_critic_feed( agent: str, ) -> tf.Tensor: # State based - observation_feed = extras["env_states"][agent] - - return observation_feed + if type(extras["env_states"]) == dict: + return extras["env_states"][agent] + else: + return extras["env_states"] diff --git a/mava/systems/tf/qmix/training.py b/mava/systems/tf/qmix/training.py index 1897553fb..8aaaca4a7 100644 --- a/mava/systems/tf/qmix/training.py +++ b/mava/systems/tf/qmix/training.py @@ -239,8 +239,8 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: trans.next_extras, ) - s_tm1 = e_tm1["s_t"] - s_t = e_t["s_t"] + s_tm1 = e_tm1["env_states"] + s_t = e_t["env_states"] # Do forward passes through the networks and calculate the losses with tf.GradientTape(persistent=True) as tape: diff --git a/mava/utils/debugging/environment.py b/mava/utils/debugging/environment.py index a27c4f586..30b9dc901 100644 --- a/mava/utils/debugging/environment.py +++ b/mava/utils/debugging/environment.py @@ -180,7 +180,7 @@ def reset(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: agent = self.agents[agent_id] obs_n[agent_id] = self._get_obs(a_i, agent) state_n = self._get_state() - return obs_n, {"s_t": state_n} + return obs_n, {"env_states": state_n} # get info used for benchmarking def _get_info(self, agent: Agent) -> Dict: diff --git a/mava/utils/debugging/environments/two_step.py b/mava/utils/debugging/environments/two_step.py index 89c2d8e53..760161616 100644 --- a/mava/utils/debugging/environments/two_step.py +++ b/mava/utils/debugging/environments/two_step.py @@ -62,7 +62,7 @@ def reset( "agent_0": np.array([0.0], dtype=np.float32), "agent_1": np.array([0.0], dtype=np.float32), } - return self.obs_n, {"s_t": np.array(self.state, dtype=np.int64)} + return self.obs_n, {"env_states": np.array(self.state, dtype=np.int64)} def step( self, action_n: Dict[str, int] @@ -83,7 +83,7 @@ def step( self.obs_n, self.reward_n, self.done_n, - {"s_t": np.array(self.state, dtype=np.int64)}, + {"env_states": np.array(self.state, dtype=np.int64)}, ) # Go to 2A else: self.state = 2 @@ -95,7 +95,7 @@ def step( self.obs_n, self.reward_n, self.done_n, - {"s_t": np.array(self.state, dtype=np.int64)}, + {"env_states": np.array(self.state, dtype=np.int64)}, ) # Go to 2B elif self.state == 1: # State 2A @@ -110,7 +110,7 @@ def step( self.obs_n, self.reward_n, self.done_n, - {"s_t": np.array(self.state, dtype=np.int64)}, + {"env_states": np.array(self.state, dtype=np.int64)}, ) elif self.state == 2: # State 2B @@ -141,7 +141,7 @@ def step( self.obs_n, self.reward_n, self.done_n, - {"s_t": np.array(self.state, dtype=np.int64)}, + {"env_states": np.array(self.state, dtype=np.int64)}, ) else: diff --git a/mava/utils/environments/RoboCup_env/robocup_utils/util_functions.py b/mava/utils/environments/RoboCup_env/robocup_utils/util_functions.py index 5a033139d..effa5d749 100644 --- a/mava/utils/environments/RoboCup_env/robocup_utils/util_functions.py +++ b/mava/utils/environments/RoboCup_env/robocup_utils/util_functions.py @@ -308,7 +308,7 @@ def discount_spec(self) -> Dict[str, specs.BoundedArray]: return discount_specs def extra_spec(self) -> Dict[str, specs.BoundedArray]: - return {"s_t": self._state_spec} + return {"env_states": self._state_spec} def _proc_robocup_obs( self, observations: Dict, done: bool, nn_actions: Dict = None diff --git a/mava/wrappers/debugging_envs.py b/mava/wrappers/debugging_envs.py 
index 1f5d7e232..32d17165a 100644 --- a/mava/wrappers/debugging_envs.py +++ b/mava/wrappers/debugging_envs.py @@ -103,7 +103,7 @@ def step(self, actions: Dict[str, np.ndarray]) -> Tuple[dm_env.TimeStep, np.arra step_type=self._step_type, ) if self.return_state_info: - return timestep, {"s_t": state} + return timestep, {"env_states": state} else: return timestep @@ -159,7 +159,7 @@ def extra_spec(self) -> Dict[str, specs.BoundedArray]: minimum=[float("-inf")] * shape[0], maximum=[float("inf")] * shape[0], ) - extras.update({"s_t": ex_spec}) + extras.update({"env_states": ex_spec}) return extras @@ -205,7 +205,7 @@ def __init__(self, environment: TwoStepEnv) -> None: self.environment.action_spaces = {} self.environment.observation_spaces = {} self.environment.extra_specs = { - "s_t": spaces.Discrete(3) + "env_states": spaces.Discrete(3) } # Global state 1, 2, or 3 for agent_id in self.environment.agent_ids: diff --git a/mava/wrappers/smac.py b/mava/wrappers/smac.py index c9ee9aee0..1426f7c6a 100644 --- a/mava/wrappers/smac.py +++ b/mava/wrappers/smac.py @@ -120,7 +120,7 @@ def reset(self) -> Tuple[dm_env.TimeStep, np.array]: # dm_env timestep timestep = parameterized_restart(rewards, self._discounts, observations) - return timestep, {"s_t": state} + return timestep, {"env_states": state} def step(self, actions: Dict[str, np.ndarray]) -> Tuple[dm_env.TimeStep, np.array]: """Returns observations from ready agents. @@ -189,7 +189,7 @@ def step(self, actions: Dict[str, np.ndarray]) -> Tuple[dm_env.TimeStep, np.arra self.reward = rewards - return timestep, {"s_t": state} + return timestep, {"env_states": state} def env_done(self) -> bool: """ @@ -247,7 +247,7 @@ def extra_spec(self) -> Dict[str, specs.BoundedArray]: state = self._environment.get_state() # TODO (dries): What should the real bounds be of the state spec? return { - "s_t": specs.BoundedArray( + "env_states": specs.BoundedArray( state.shape, np.float32, minimum=float("-inf"), maximum=float("inf") ) } From ac461766c5e7b1b08922c94d3276be4ad3ba436c Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 29 Sep 2021 10:09:33 +0200 Subject: [PATCH 006/104] Recurrent MAPPO trains of the debugging environment! --- mava/systems/tf/mappo/execution.py | 7 ++++--- mava/systems/tf/mappo/training.py | 22 ++++++++++++---------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index 02fbe74e8..9d2a6f479 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -304,7 +304,7 @@ def select_action( self._states[agent] = self._policy_networks[agent_key].initia_state(1) # Step the recurrent policy forward given the current observation and state. - action, log_prob, new_state = self._policy( + action, logits, new_state = self._policy( agent, observation.observation, self._states[agent] ) @@ -313,8 +313,8 @@ def select_action( # Return a numpy array with squeezed out batch dimension. action = tf2_utils.to_numpy_squeeze(action) - log_prob = tf2_utils.to_numpy_squeeze(log_prob) - return action, log_prob + logits = tf2_utils.to_numpy_squeeze(logits) + return action, logits def select_actions( self, observations: Dict[str, types.NestedArray] @@ -330,6 +330,7 @@ def select_actions( actions and policies for all agents in the system. """ + # TODO (dries): Add this to a function and add tf.function here. 
actions = {} log_probs = {} for agent, observation in observations.items(): diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 95962a65d..7f73ed916 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -266,7 +266,9 @@ def _step( # Get data from replay. inputs = next(self._iterator) - self.forward_backward(inputs) + + # Do a forward and backwards pass using tf.function. + return self.forward_backward(inputs) @tf.function def forward_backward(self, inputs: Any): @@ -310,15 +312,14 @@ def _forward_pass(self, inputs: Any) -> None: with tf.GradientTape(persistent=True) as tape: for agent in self._agents: - - action, reward, discount, behaviour_logits = ( + action, reward, discount, behaviour_logits, actor_observation = ( actions[agent]["actions"], rewards[agent], discounts[agent], actions[agent]["logits"], + observations_trans[agent], ) - actor_observation = observations_trans[agent] critic_observation = self._get_critic_feed( observations_trans, extras, agent ) @@ -359,8 +360,7 @@ def _forward_pass(self, inputs: Any) -> None: # Calculate critic values. critic_observation, dims = train_utils.combine_dim(critic_observation) - values = critic_network(critic_observation) - values = tf.reshape(values, dims, name="value") + values = tf.reshape(critic_network(critic_observation), dims) # Values along the sequence T. bootstrap_value = values[-1] @@ -393,12 +393,14 @@ def _forward_pass(self, inputs: Any) -> None: # Generalized Advantage Estimation gae = tf.stop_gradient(td_lambda_extra.temporal_differences) - mean, variance = tf.nn.moments(gae, axes=[0, 1], keepdims=True) - normalized_gae = (gae - mean) / tf.sqrt(variance) + + # Note (dries): Maybe add this in again? But this might be breaking training. + # mean, variance = tf.nn.moments(gae, axes=[0, 1], keepdims=True) + # normalized_gae = (gae - mean) / tf.sqrt(variance) policy_gradient_loss = -tf.minimum( - tf.multiply(importance_ratio, normalized_gae), - tf.multiply(clipped_importance_ratio, normalized_gae), + tf.multiply(importance_ratio, gae), + tf.multiply(clipped_importance_ratio, gae), name="PolicyGradientLoss", ) From 1e4aaae3cd0d8cea841f2f89acdca8cf859d504f Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 29 Sep 2021 10:14:52 +0200 Subject: [PATCH 007/104] Small fix. --- .../debugging/simple_spread/recurrent/state_based/run_mappo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py b/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py index c84c3885a..cf4ad41b6 100644 --- a/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py +++ b/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py @@ -47,8 +47,7 @@ str(datetime.now()), "Experiment identifier that can be used to continue experiments.", ) -print("Dries change this back to ~/mava.") -flags.DEFINE_string("base_dir", "/home/mava_logs", "Base dir to store experiments.") +flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") def main(_: Any) -> None: From fa2b43ce71f53d09eb303ba4a7fe28351a5f9fce Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 30 Sep 2021 13:38:08 +0200 Subject: [PATCH 008/104] Save changes. 
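The training changes in the patches above compute the PPO objective from the behaviour logits stored by the executor: an importance ratio between the current and behaviour Categorical policies, clipped, multiplied by a stop-gradient advantage (advantage normalisation is commented out above on the suspicion that it was hurting training). A self-contained sketch of that clipped surrogate loss (shapes and the clipping epsilon are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def clipped_surrogate_loss(
    behaviour_logits: tf.Tensor,  # [T, B, A] logits recorded at acting time
    target_logits: tf.Tensor,     # [T, B, A] logits from the current policy
    actions: tf.Tensor,           # [T, B] sampled actions
    advantages: tf.Tensor,        # [T, B] e.g. GAE estimates
    clip_epsilon: float = 0.2,
) -> tf.Tensor:
    pi_behaviour = tfd.Categorical(logits=behaviour_logits)
    pi_target = tfd.Categorical(logits=target_logits)

    # Importance ratio pi_target(a|s) / pi_behaviour(a|s), computed in log space.
    ratio = tf.exp(pi_target.log_prob(actions) - pi_behaviour.log_prob(actions))
    clipped_ratio = tf.clip_by_value(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon)

    advantages = tf.stop_gradient(advantages)
    return -tf.reduce_mean(tf.minimum(ratio * advantages, clipped_ratio * advantages))


T, B, A = 9, 4, 5
loss = clipped_surrogate_loss(
    behaviour_logits=tf.random.normal([T, B, A]),
    target_logits=tf.random.normal([T, B, A]),
    actions=tf.random.uniform([T, B], maxval=A, dtype=tf.int32),
    advantages=tf.random.normal([T, B]),
)
print(float(loss))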
--- .../recurrent/state_based/run_mappo.py | 12 +- .../tf/architectures/state_based.py | 4 + mava/systems/tf/mappo/execution.py | 146 +++++++----------- mava/systems/tf/mappo/training.py | 36 +++-- mava/utils/debugging/make_env.py | 8 +- .../debugging/scenarios/simple_spread.py | 19 ++- mava/utils/environments/debugging_utils.py | 5 +- 7 files changed, 116 insertions(+), 114 deletions(-) diff --git a/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py b/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py index cf4ad41b6..2b8e5715f 100644 --- a/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py +++ b/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py @@ -52,18 +52,24 @@ def main(_: Any) -> None: + recurrent_test = True + recurrent_ppo = True + # Environment. environment_factory = functools.partial( debugging_utils.make_environment, env_name=FLAGS.env_name, action_space=FLAGS.action_space, return_state_info=True, + recurrent_test=recurrent_test, ) # Networks. network_factory = lp_utils.partial_kwargs( mappo.make_default_networks, - archecture_type=ArchitectureType.recurrent, + archecture_type=ArchitectureType.recurrent + if recurrent_ppo + else ArchitectureType.feedforward, ) # Checkpointer appends "Checkpoints" to checkpoint_dir @@ -90,7 +96,9 @@ def main(_: Any) -> None: critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, max_gradient_norm=40.0, - executor_fn=mappo.MAPPORecurrentExecutor, + executor_fn=mappo.MAPPORecurrentExecutor + if recurrent_ppo + else mappo.MAPPOFeedForwardExecutor, architecture=architectures.StateBasedValueActorCritic, trainer_fn=mappo.StateBasedMAPPOTrainer, ).build() diff --git a/mava/components/tf/architectures/state_based.py b/mava/components/tf/architectures/state_based.py index 45a1609a9..09f37e6d0 100644 --- a/mava/components/tf/architectures/state_based.py +++ b/mava/components/tf/architectures/state_based.py @@ -192,6 +192,10 @@ def _get_critic_specs( ) ) + # TODO (dries): Fix this + assert len(critic_obs_spec) == 1 + critic_obs_spec = critic_obs_spec[0] + critic_obs_specs = {} for agent_key in self._agents: # Get observation spec for critic. diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index 9d2a6f479..e0eb12072 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -33,9 +33,9 @@ tfd = tfp.distributions -class MAPPOFeedForwardExecutor(core.Executor): - """A feed-forward executor. - An executor based on a feed-forward policy for each agent in the system. +class MAPPOFeedForwardExecutor(executors.FeedForwardExecutor): + """A feedforward executor for MAPPO. + An executor based on a feedforward policy for each agent in the system. """ def __init__( @@ -46,30 +46,29 @@ def __init__( variable_client: Optional[tf2_variable_utils.VariableClient] = None, ): """Initialise the system executor - Args: policy_networks (Dict[str, snt.Module]): policy networks for each agent in the system. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. adder (Optional[adders.ParallelAdder], optional): adder which sends data to a replay buffer. Defaults to None. variable_client (Optional[tf2_variable_utils.VariableClient], optional): client to copy weights from the trainer. Defaults to None. + agent_net_keys: (dict, optional): specifies what network each agent uses. + Defaults to {}. """ - # Store these for later use. 
- self._adder = adder - self._variable_client = variable_client - self._policy_networks = policy_networks - self._agent_net_keys = agent_net_keys - self._prev_log_probs: Dict[str, Any] = {} + super().__init__( + policy_networks=policy_networks, + agent_net_keys=agent_net_keys, + adder=adder, + variable_client=variable_client, + ) @tf.function def _policy( self, agent: str, observation: types.NestedTensor, - ) -> Tuple[types.NestedTensor, types.NestedTensor]: + ) -> Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: """Agent specific policy function Args: @@ -77,29 +76,33 @@ def _policy( observation (types.NestedTensor): observation tensor received from the environment. + Raises: + NotImplementedError: unknown action space + Returns: - Tuple[types.NestedTensor, types.NestedTensor]: log probabilities and action + Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: + action and policy logits """ - # Index network either on agent type or on agent id. - network_key = self._agent_net_keys[agent] - # Add a dummy batch dimension and as a side effect convert numpy to TF. - observation = tf2_utils.add_batch_dim(observation.observation) + batched_observation = tf2_utils.add_batch_dim(observation) + + # index network either on agent type or on agent id + agent_key = self._agent_net_keys[agent] # Compute the policy, conditioned on the observation. - policy = self._policy_networks[network_key](observation) + logits = self._policy_networks[agent_key](batched_observation) # Sample from the policy and compute the log likelihood. - action = policy.sample() - log_prob = policy.log_prob(action) + action = tfd.Categorical(logits).sample() + # action = tf2_utils.to_numpy_squeeze(action) # Cast for compatibility with reverb. # sample() returns a 'int32', which is a problem. - if isinstance(policy, tfp.distributions.Categorical): - action = tf.cast(action, "int64") + # if isinstance(policy, tfp.distributions.Categorical): + action = tf.cast(action, "int64") - return log_prob, action + return action, logits def select_action( self, agent: str, observation: types.NestedArray @@ -107,71 +110,50 @@ def select_action( """select an action for a single agent in the system Args: - agent (str): agent id. + agent (str): agent id observation (types.NestedArray): observation tensor received from the environment. Returns: - types.NestedArray: agent action + types.NestedArray: action and policy. """ - - # Step the recurrent policy/value network forward - # given the current observation and state. - self._prev_log_probs[agent], action = self._policy(agent, observation) + # Step the recurrent policy forward given the current observation and state. + action, logits = self._policy( + agent, + observation.observation, + ) # Return a numpy array with squeezed out batch dimension. action = tf2_utils.to_numpy_squeeze(action) - - # TODO(Kale-ab) : Remove. This is for debugging. - if np.isnan(action).any(): - print( - f"Value error- Log Probs:{self._prev_log_probs[agent]} Action: {action} " # noqa: E501 - ) - - return action + logits = tf2_utils.to_numpy_squeeze(logits) + return action, logits def select_actions( self, observations: Dict[str, types.NestedArray] - ) -> Dict[str, types.NestedArray]: + ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: """select the actions for all agents in the system Args: - observations (Dict[str, types.NestedArray]): agent observations from the + observations (Dict[str, types.NestedArray]): agent observations from the environment. 
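In the feedforward _policy above the network now returns raw logits; the executor samples an action from a Categorical, casts it to int64 (the inline comment notes that the int32 returned by sample() is a problem for the reverb signature) and squeezes out the dummy batch dimension before returning it. The same three steps in isolation:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# A single observation carries a dummy batch dimension of 1.
logits = tf.constant([[0.1, 2.0, -1.0, 0.5]])  # [1, num_actions]

action = tfd.Categorical(logits=logits).sample()  # int32, shape [1]
action = tf.cast(action, "int64")                 # match the adder signature
action = np.squeeze(action.numpy())               # scalar for the environment

print(action, action.dtype)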
Returns: - Dict[str, types.NestedArray]: actions and policies for all agents in - the system. + Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + actions and policies for all agents in the system. """ + # TODO (dries): Add this to a function and add tf.function here. actions = {} + logits = {} for agent, observation in observations.items(): - action = self.select_action(agent, observation) - actions[agent] = action - - return actions - - def observe_first( - self, - timestep: dm_env.TimeStep, - extras: Dict[str, types.NestedArray] = {}, - ) -> None: - """record first observed timestep from the environment - - Args: - timestep (dm_env.TimeStep): data emitted by an environment at first step of - interaction. - extras (Dict[str, types.NestedArray], optional): possible extra information - to record during the first step. Defaults to {}. - """ - if self._adder: - self._adder.add_first(timestep) + actions[agent], logits[agent] = self.select_action(agent, observation) + return actions, logits def observe( self, actions: Dict[str, types.NestedArray], next_timestep: dm_env.TimeStep, - next_extras: Dict[str, types.NestedArray] = {}, + next_extras: Optional[Dict[str, types.NestedArray]] = {}, ) -> None: """record observed timestep from the environment @@ -183,33 +165,21 @@ def observe( information to record during the transition. Defaults to {}. """ - raise NotImplementedError( - "Need to fix the FeedForward case to reflect the recurrent updates." - ) - if not self._adder: return + actions, logits = actions - next_extras.update({"log_probs": self._prev_log_probs}) - - next_extras = tf2_utils.to_numpy_squeeze(next_extras) - - self._adder.add(actions, next_timestep, next_extras) - - def update(self, wait: bool = False) -> None: - """update executor variables - - Args: - wait (bool, optional): whether to stall the executor's request for new - variables. Defaults to False. - """ + adder_actions = {} + for agent in actions.keys(): + adder_actions[agent] = {} + adder_actions[agent]["actions"] = actions[agent] + adder_actions[agent]["logits"] = logits[agent] - if self._variable_client: - self._variable_client.update(wait) + self._adder.add(adder_actions, next_timestep, next_extras) # type: ignore class MAPPORecurrentExecutor(executors.RecurrentExecutor): - """A recurrent executor for MADDPG. + """A recurrent executor for MAPPO. An executor based on a recurrent policy for each agent in the system. """ @@ -332,10 +302,10 @@ def select_actions( # TODO (dries): Add this to a function and add tf.function here. actions = {} - log_probs = {} + logits = {} for agent, observation in observations.items(): - actions[agent], log_probs[agent] = self.select_action(agent, observation) - return actions, log_probs + actions[agent], logits[agent] = self.select_action(agent, observation) + return actions, logits def observe( self, @@ -355,13 +325,13 @@ def observe( if not self._adder: return - actions, log_probs = actions + actions, logits = actions adder_actions = {} for agent in actions.keys(): adder_actions[agent] = {} adder_actions[agent]["actions"] = actions[agent] - adder_actions[agent]["log_probs"] = log_probs[agent] + adder_actions[agent]["logits"] = logits[agent] numpy_states = { agent: tf2_utils.to_numpy_squeeze(_state) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 7f73ed916..f32a36002 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -270,7 +270,8 @@ def _step( # Do a forward and backwards pass using tf.function. 
return self.forward_backward(inputs) - @tf.function + print("Add this tf.function in again!") + # @tf.function def forward_backward(self, inputs: Any): self._forward_pass(inputs) self._backward_pass() @@ -336,32 +337,33 @@ def _forward_pass(self, inputs: Any) -> None: if "core_states" in extras: # Unroll current policy over actor_observation. - # policy, _ = snt.static_unroll(policy_network, actor_observation, - # core_states) - - # transposed_obs = tf2_utils.batch_to_sequence(actor_observation) agent_core_state = core_states[agent][0] + # Question (dries): Maybe we can save training time by + # only unroll the LSTM part and doing a batch forward + # pass for the feedforward part of the policy? logits, updated_states = snt.static_unroll( policy_network, actor_observation, agent_core_state, ) - # logits = tf2_utils.batch_to_sequence(logits) - else: - logits = policy_network(actor_observation) - - # Reshape the outputs. - # policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") + # Pass observations through the feedforward policy + actor_observation, dims = train_utils.combine_dim(actor_observation) + logits = train_utils.extract_dim( + policy_network(actor_observation), dims + ) # Compute importance sampling weights: current policy / behavior policy. pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1]) pi_target = tfd.Categorical(logits=logits[:-1]) # Calculate critic values. - critic_observation, dims = train_utils.combine_dim(critic_observation) - values = tf.reshape(critic_network(critic_observation), dims) - + dims = critic_observation.shape[:2] + critic_observation = tf.reshape( + critic_observation, (dims[0] * dims[1], -1) + ) + critic_output = critic_network(critic_observation) + values = tf.reshape(critic_output, dims) # Values along the sequence T. bootstrap_value = values[-1] state_values = values[:-1] @@ -413,8 +415,10 @@ def _forward_pass(self, inputs: Any) -> None: # Multiply by discounts to not train on padded data. loss_mask = discount > 0.0 # TODO (dries): Is multiplication maybe better here? As assignment might not work with tf.function? - policy_losses[agent] = tf.reduce_mean(policy_loss[loss_mask]) - critic_losses[agent] = tf.reduce_mean(critic_loss[loss_mask]) + pol_loss = policy_loss[loss_mask] + crit_loss = critic_loss[loss_mask] + policy_losses[agent] = tf.reduce_mean(pol_loss) + critic_losses[agent] = tf.reduce_mean(crit_loss) self.policy_losses = policy_losses self.critic_losses = critic_losses diff --git a/mava/utils/debugging/make_env.py b/mava/utils/debugging/make_env.py index 69065026b..96fdd6d23 100644 --- a/mava/utils/debugging/make_env.py +++ b/mava/utils/debugging/make_env.py @@ -22,7 +22,11 @@ def make_debugging_env( - scenario_name: str, action_space: Space, num_agents: int, seed: Optional[int] = None + scenario_name: str, + action_space: Space, + num_agents: int, + recurrent_test: bool, + seed: Optional[int] = None, ) -> MultiAgentEnv: """ Creates a MultiAgentEnv object as env. 
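The trainer hunk above unrolls the recurrent policy over the whole stored sequence with snt.static_unroll, seeded with the core state recorded at the first timestep. A standalone sketch with a toy core (time-major [T, B, ...] inputs, matching the stored trajectories; sizes are illustrative):

import sonnet as snt
import tensorflow as tf

T, B, obs_dim, num_actions = 10, 2, 18, 5

policy = snt.DeepRNN([snt.LSTM(64), snt.Linear(num_actions)])

observations = tf.random.normal([T, B, obs_dim])  # sequence-major trajectory
initial_state = policy.initial_state(B)           # in the trainer this is the stored core state

# One call unrolls all T steps and returns the logits for every step
# together with the state after the last step.
logits, final_state = snt.static_unroll(policy, observations, initial_state)
print(logits.shape)  # (10, 2, 5)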
This can be used similar to a gym @@ -40,7 +44,7 @@ def make_debugging_env( """ # load scenario from script - scenario = scenarios.load(scenario_name).Scenario() + scenario = scenarios.load(scenario_name).Scenario(recurrent_test=recurrent_test) if seed: scenario.seed(seed) diff --git a/mava/utils/debugging/scenarios/simple_spread.py b/mava/utils/debugging/scenarios/simple_spread.py index ec0ff3b25..c09992b1c 100644 --- a/mava/utils/debugging/scenarios/simple_spread.py +++ b/mava/utils/debugging/scenarios/simple_spread.py @@ -24,9 +24,10 @@ class Scenario(BaseScenario): - def __init__(self) -> None: + def __init__(self, recurrent_test) -> None: super().__init__() self.np_rnd = np.random.RandomState() + self.recurrent_test = recurrent_test def make_world(self, num_agents: int) -> World: world = World() @@ -102,7 +103,11 @@ def reward(self, agent: Agent, a_i: int, world: World) -> float: def observation(self, agent: Agent, a_i: int, world: World) -> np.array: # get the position of the agent's target landmark - target_landmark = world.landmarks[a_i].state.p_pos - agent.state.p_pos + if world.current_step < 5 or not self.recurrent_test: + target_landmark = world.landmarks[a_i].state.p_pos - agent.state.p_pos + else: + # Only provide the target landmark initially if testing agent memory. + target_landmark = [0.0, 0.0] # Get the other agent and landmark positions other_agents_pos = [] @@ -111,10 +116,14 @@ def observation(self, agent: Agent, a_i: int, world: World) -> np.array: if other_agent is agent: continue other_agents_pos.append(other_agent.state.p_pos - agent.state.p_pos) - other_landmarks_pos.append( - world.landmarks[i].state.p_pos - agent.state.p_pos - ) + if world.current_step == 0 or not self.recurrent_test: + other_landmarks_pos.append( + world.landmarks[i].state.p_pos - agent.state.p_pos + ) + else: + # Only provide the target landmark initially if testing agent memory. + other_landmarks_pos.append([0.0, 0.0]) return np.concatenate( [agent.state.p_vel] + [agent.state.p_pos] diff --git a/mava/utils/environments/debugging_utils.py b/mava/utils/environments/debugging_utils.py index 580998591..59288c0e7 100644 --- a/mava/utils/environments/debugging_utils.py +++ b/mava/utils/environments/debugging_utils.py @@ -34,6 +34,7 @@ def make_environment( render: bool = False, return_state_info: bool = False, random_seed: Optional[int] = None, + recurrent_test: bool = False, ) -> dm_env.Environment: assert action_space == "continuous" or action_space == "discrete" @@ -50,7 +51,9 @@ def make_environment( return environment_fn else: """Creates a MPE environment.""" - env_module = make_debugging_env(env_name, action_space, num_agents, random_seed) + env_module = make_debugging_env( + env_name, action_space, num_agents, recurrent_test, random_seed + ) environment = DebuggingEnvWrapper( env_module, return_state_info=return_state_info ) From a2bbc861559fb147fd6a7469781d1c1470f0f981 Mon Sep 17 00:00:00 2001 From: Dries Date: Sun, 3 Oct 2021 19:36:10 +0200 Subject: [PATCH 009/104] Add code to MAD4PG. 
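The recurrent_test flag added to the scenario above turns simple_spread into a memory probe: the relative landmark positions are only revealed during the first few steps (step 0, or steps below 5 for the agent's own target), after which the observation zeroes them out, so a purely feedforward policy has nothing left to act on. The idea in isolation (the observation layout is simplified to a single target):

import numpy as np


def observe_target(step: int, landmark_pos: np.ndarray, agent_pos: np.ndarray,
                   recurrent_test: bool, reveal_steps: int = 5) -> np.ndarray:
    # Relative target position, hidden after reveal_steps when testing memory.
    if step < reveal_steps or not recurrent_test:
        return landmark_pos - agent_pos
    return np.zeros_like(agent_pos)  # the agent must remember where the target was


landmark, agent = np.array([1.0, -0.5]), np.array([0.2, 0.2])
print(observe_target(2, landmark, agent, recurrent_test=True))  # target visible
print(observe_target(7, landmark, agent, recurrent_test=True))  # [0. 0.]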
--- .../recurrent/state_based/run_mappo.py | 2 +- mava/components/tf/architectures/__init__.py | 1 + .../tf/architectures/state_based.py | 94 ++++++++++++++++++- mava/systems/tf/mad4pg/__init__.py | 1 + mava/systems/tf/mad4pg/training.py | 59 ++++++++++++ mava/systems/tf/maddpg/__init__.py | 1 + mava/systems/tf/maddpg/training.py | 92 +++++++++++++++++- mava/systems/tf/mappo/system.py | 2 +- mava/systems/tf/mappo/training.py | 11 +-- 9 files changed, 249 insertions(+), 14 deletions(-) diff --git a/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py b/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py index 2b8e5715f..3a94bdaf6 100644 --- a/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py +++ b/examples/debugging/simple_spread/recurrent/state_based/run_mappo.py @@ -52,7 +52,7 @@ def main(_: Any) -> None: - recurrent_test = True + recurrent_test = False recurrent_ppo = True # Environment. diff --git a/mava/components/tf/architectures/__init__.py b/mava/components/tf/architectures/__init__.py index a84167fce..fe55a49f2 100644 --- a/mava/components/tf/architectures/__init__.py +++ b/mava/components/tf/architectures/__init__.py @@ -41,4 +41,5 @@ StateBasedValueActorCritic, StateBasedQValueActorCritic, StateBasedQValueCritic, + StateBasedQValueDecentralActionCritic, ) diff --git a/mava/components/tf/architectures/state_based.py b/mava/components/tf/architectures/state_based.py index 09f37e6d0..73db11f25 100644 --- a/mava/components/tf/architectures/state_based.py +++ b/mava/components/tf/architectures/state_based.py @@ -99,11 +99,27 @@ def _get_critic_specs( # Create one critic per agent. Each critic gets # absolute state information of the environment. - critic_state_shape = self._env_spec.get_extra_specs()["env_states"].shape - critic_obs_spec = tf.TensorSpec( - shape=critic_state_shape, - dtype=tf.dtypes.float32, - ) + critic_env_state_spec = self._env_spec.get_extra_specs()["env_states"] + if type(critic_env_state_spec) == dict: + critic_env_state_spec = list(critic_env_state_spec.values())[0] + + if type(critic_env_state_spec) != list: + critic_env_state_spec = [critic_env_state_spec] + + critic_obs_spec = [] + for spec in critic_env_state_spec: + critic_obs_spec.append( + tf.TensorSpec( + shape=spec.shape, + dtype=tf.dtypes.float32, + ) + ) + + # TODO (dries): Fix this + assert len(critic_obs_spec) == 1 + critic_obs_spec = critic_obs_spec[0] + + for agent_type, agents in agents_by_type.items(): critic_act_shape = list( copy.copy(self._agent_specs[agents[0]].actions.shape) @@ -201,3 +217,71 @@ def _get_critic_specs( # Get observation spec for critic. critic_obs_specs[agent_key] = critic_obs_spec return critic_obs_specs, None + +class StateBasedQValueDecentralActionCritic(DecentralisedQValueActorCritic): + """Multi-agent actor critic architecture with a critic using + environment state information. For this state-based critic + only one action gets fed in. 
This allows the critic + to only focus on the one agent's reward function, but + requires that the state information is + egocentric.""" + + def __init__( + self, + environment_spec: mava_specs.MAEnvironmentSpec, + observation_networks: Dict[str, snt.Module], + policy_networks: Dict[str, snt.Module], + critic_networks: Dict[str, snt.Module], + agent_net_keys: Dict[str, str], + ): + super().__init__( + environment_spec=environment_spec, + observation_networks=observation_networks, + policy_networks=policy_networks, + critic_networks=critic_networks, + agent_net_keys=agent_net_keys, + ) + + def _get_critic_specs( + self, + ) -> Tuple[Dict[str, acme_specs.Array], Dict[str, acme_specs.Array]]: + action_specs_per_type: Dict[str, acme_specs.Array] = {} + + agents_by_type = self._env_spec.get_agents_by_type() + + # Create one critic per agent. Each critic gets + # absolute state information of the environment. + critic_env_state_spec = self._env_spec.get_extra_specs()["env_states"] + if type(critic_env_state_spec) == dict: + critic_env_state_spec = list(critic_env_state_spec.values())[0] + + if type(critic_env_state_spec) != list: + critic_env_state_spec = [critic_env_state_spec] + + critic_obs_spec = [] + for spec in critic_env_state_spec: + critic_obs_spec.append( + tf.TensorSpec( + shape=spec.shape, + dtype=tf.dtypes.float32, + ) + ) + + for agent_type, agents in agents_by_type.items(): + # Only feed in the main agent's policy action. + critic_act_shape = list( + copy.copy(self._agent_specs[agents[0]].actions.shape) + ) + action_specs_per_type[agent_type] = tf.TensorSpec( + shape=critic_act_shape, + dtype=tf.dtypes.float32, + ) + + critic_obs_specs = {} + critic_act_specs = {} + for agent_key in self._agents: + agent_type = agent_key.split("_")[0] + # Get observation and action spec for critic. 
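This architecture gives each agent's critic the (egocentric) environment state plus only that agent's own action, instead of the joint action used by the other state-based critics. Sketched with plain tensors, with the state and action concatenated explicitly for illustration (how the critic network actually consumes the two feeds is up to the critic; dimensions are made up):

import sonnet as snt
import tensorflow as tf

B, state_dim, act_dim = 4, 24, 2

env_state = tf.random.normal([B, state_dim])   # egocentric global state
own_action = tf.random.normal([B, act_dim])    # this agent's action only

# Decentral-action critic feed: no other agents' actions are included.
critic_input = tf.concat([env_state, own_action], axis=-1)

critic = snt.nets.MLP([64, 1])
q_value = critic(critic_input)
print(q_value.shape)  # (4, 1)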
+ critic_obs_specs[agent_key] = critic_obs_spec + critic_act_specs[agent_key] = action_specs_per_type[agent_type] + return critic_obs_specs, critic_act_specs diff --git a/mava/systems/tf/mad4pg/__init__.py b/mava/systems/tf/mad4pg/__init__.py index dea188ed3..96f51039a 100644 --- a/mava/systems/tf/mad4pg/__init__.py +++ b/mava/systems/tf/mad4pg/__init__.py @@ -28,4 +28,5 @@ MAD4PGDecentralisedTrainer, MAD4PGStateBasedRecurrentTrainer, MAD4PGStateBasedTrainer, + MAD4PGStateBasedDecentralActionRecurrentTrainer, ) diff --git a/mava/systems/tf/mad4pg/training.py b/mava/systems/tf/mad4pg/training.py index 14ac62706..99cc79b9a 100644 --- a/mava/systems/tf/mad4pg/training.py +++ b/mava/systems/tf/mad4pg/training.py @@ -38,6 +38,7 @@ MADDPGDecentralisedTrainer, MADDPGStateBasedRecurrentTrainer, MADDPGStateBasedTrainer, + MADDPGStateBasedDecentralActionRecurrentTrainer, ) from mava.utils import training_utils as train_utils @@ -810,3 +811,61 @@ def __init__( checkpoint_subpath=checkpoint_subpath, bootstrap_n=bootstrap_n, ) + +class MAD4PGStateBasedDecentralActionRecurrentTrainer( + MAD4PGBaseRecurrentTrainer, MADDPGStateBasedDecentralActionRecurrentTrainer +): + """Recurrent MAD4PG trainer for a state-based architecture.""" + + def __init__( + self, + agents: List[str], + agent_types: List[str], + policy_networks: Dict[str, snt.Module], + critic_networks: Dict[str, snt.Module], + target_policy_networks: Dict[str, snt.Module], + target_critic_networks: Dict[str, snt.Module], + policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + critic_optimizer: snt.Optimizer, + discount: float, + target_averaging: bool, + target_update_period: int, + target_update_rate: float, + dataset: tf.data.Dataset, + observation_networks: Dict[str, snt.Module], + target_observation_networks: Dict[str, snt.Module], + agent_net_keys: Dict[str, str], + checkpoint_minute_interval: int, + max_gradient_norm: float = None, + counter: counting.Counter = None, + logger: loggers.Logger = None, + checkpoint: bool = True, + checkpoint_subpath: str = "~/mava/", + bootstrap_n: int = 10, + ): + + super().__init__( + agents=agents, + agent_types=agent_types, + policy_networks=policy_networks, + critic_networks=critic_networks, + target_policy_networks=target_policy_networks, + target_critic_networks=target_critic_networks, + discount=discount, + target_averaging=target_averaging, + target_update_period=target_update_period, + target_update_rate=target_update_rate, + dataset=dataset, + observation_networks=observation_networks, + target_observation_networks=target_observation_networks, + agent_net_keys=agent_net_keys, + checkpoint_minute_interval=checkpoint_minute_interval, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + max_gradient_norm=max_gradient_norm, + counter=counter, + logger=logger, + checkpoint=checkpoint, + checkpoint_subpath=checkpoint_subpath, + bootstrap_n=bootstrap_n, + ) \ No newline at end of file diff --git a/mava/systems/tf/maddpg/__init__.py b/mava/systems/tf/maddpg/__init__.py index 7399176f4..b853b41bc 100644 --- a/mava/systems/tf/maddpg/__init__.py +++ b/mava/systems/tf/maddpg/__init__.py @@ -29,4 +29,5 @@ MADDPGNetworkedTrainer, MADDPGStateBasedRecurrentTrainer, MADDPGStateBasedTrainer, + MADDPGStateBasedDecentralActionRecurrentTrainer ) diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 2055ddda1..39ae2ca3c 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -1123,7 +1123,7 @@ def 
_update_target_networks(self) -> None: if tf.math.mod(self._num_steps, self._target_update_period) == 0: for src, dest in zip(online_variables, target_variables): dest.assign(src) - self._num_steps.assign_add(1) + self._num_steps.assign_add(1) def _transform_observations( self, observations: Dict[str, np.ndarray] @@ -1746,6 +1746,11 @@ def _get_critic_feed( # State based obs_trans_feed = extras["env_states"] target_obs_trans_feed = extras["env_states"] + if type(obs_trans_feed) == dict: + obs_trans_feed = obs_trans_feed[agent] + target_obs_trans_feed = target_obs_trans_feed[agent] + + actions_feed = tf.stack([actions[agent] for agent in self._agents], -1) target_actions_feed = tf.stack( [target_actions[agent] for agent in self._agents], -1 @@ -1770,3 +1775,88 @@ def _get_dpg_feed( tf.stack([dpg_actions_feed[agent] for agent in self._agents], -1) ) return dpg_actions_feed + +class MADDPGStateBasedDecentralActionRecurrentTrainer(MADDPGBaseRecurrentTrainer): + """Recurrent MADDPG trainer for a state-based architecture. + This is the trainer component of a MADDPG system. IE it takes a dataset as input + and implements update functionality to learn from this dataset. + """ + + def __init__( + self, + agents: List[str], + agent_types: List[str], + policy_networks: Dict[str, snt.Module], + critic_networks: Dict[str, snt.Module], + target_policy_networks: Dict[str, snt.Module], + target_critic_networks: Dict[str, snt.Module], + policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + critic_optimizer: snt.Optimizer, + discount: float, + target_averaging: bool, + target_update_period: int, + target_update_rate: float, + dataset: tf.data.Dataset, + observation_networks: Dict[str, snt.Module], + target_observation_networks: Dict[str, snt.Module], + agent_net_keys: Dict[str, str], + checkpoint_minute_interval: int, + max_gradient_norm: float = None, + counter: counting.Counter = None, + logger: loggers.Logger = None, + checkpoint: bool = True, + checkpoint_subpath: str = "~/mava/", + bootstrap_n: int = 10, + ): + + super().__init__( + agents=agents, + agent_types=agent_types, + policy_networks=policy_networks, + critic_networks=critic_networks, + target_policy_networks=target_policy_networks, + target_critic_networks=target_critic_networks, + discount=discount, + target_averaging=target_averaging, + target_update_period=target_update_period, + target_update_rate=target_update_rate, + dataset=dataset, + observation_networks=observation_networks, + target_observation_networks=target_observation_networks, + agent_net_keys=agent_net_keys, + checkpoint_minute_interval=checkpoint_minute_interval, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, + max_gradient_norm=max_gradient_norm, + counter=counter, + logger=logger, + checkpoint=checkpoint, + checkpoint_subpath=checkpoint_subpath, + bootstrap_n=bootstrap_n, + ) + + def _get_critic_feed( + self, + obs_trans: Dict[str, np.ndarray], + target_obs_trans: Dict[str, np.ndarray], + actions: Dict[str, np.ndarray], + target_actions: Dict[str, np.ndarray], + extras: Dict[str, np.ndarray], + agent: str, + ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: + # State based + obs_trans_feed = extras["env_states"][agent] + target_obs_trans_feed = extras["env_states"][agent] + actions_feed = actions[agent] + target_actions_feed = target_actions[agent] + return obs_trans_feed, target_obs_trans_feed, actions_feed, target_actions_feed + + def _get_dpg_feed( + self, + target_actions: Dict[str, np.ndarray], + dpg_actions: np.ndarray, + 
agent: str, + ) -> tf.Tensor: + # Decentralised DPG + dpg_actions_feed = dpg_actions + return dpg_actions_feed \ No newline at end of file diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index f17408191..ee695f14b 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -70,7 +70,7 @@ def __init__( max_gradient_norm: Optional[float] = None, max_queue_size: int = 100000, batch_size: int = 32, - sequence_length: int = 50, + sequence_length: int = 20, sequence_period: int = 10, max_executor_steps: int = None, checkpoint: bool = True, diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index f32a36002..914b4c6db 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -270,8 +270,7 @@ def _step( # Do a forward and backwards pass using tf.function. return self.forward_backward(inputs) - print("Add this tf.function in again!") - # @tf.function + @tf.function def forward_backward(self, inputs: Any): self._forward_pass(inputs) self._backward_pass() @@ -415,10 +414,10 @@ def _forward_pass(self, inputs: Any) -> None: # Multiply by discounts to not train on padded data. loss_mask = discount > 0.0 # TODO (dries): Is multiplication maybe better here? As assignment might not work with tf.function? - pol_loss = policy_loss[loss_mask] - crit_loss = critic_loss[loss_mask] - policy_losses[agent] = tf.reduce_mean(pol_loss) - critic_losses[agent] = tf.reduce_mean(crit_loss) + policy_loss = policy_loss[loss_mask] + critic_loss = critic_loss[loss_mask] + policy_losses[agent] = tf.reduce_mean(policy_loss) + critic_losses[agent] = tf.reduce_mean(critic_loss) self.policy_losses = policy_losses self.critic_losses = critic_losses From 345fe1b2a8a603e47e622f01ff1e550f44e55907 Mon Sep 17 00:00:00 2001 From: Dries Date: Mon, 4 Oct 2021 08:39:25 +0200 Subject: [PATCH 010/104] Ready to run 2 vs 2 xray_attention. 
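The one-line indentation fixes in the surrounding MADDPG trainers move self._num_steps.assign_add(1) out of the copy branch, so the step counter advances on every trainer step and the online-to-target copy happens once every target_update_period steps. A toy version of the corrected update:

import tensorflow as tf

target_update_period = 100
num_steps = tf.Variable(0, dtype=tf.int32)

online_variables = [tf.Variable(tf.random.normal([3]))]
target_variables = [tf.Variable(tf.zeros([3]))]


def update_target_networks() -> None:
    if tf.math.mod(num_steps, target_update_period) == 0:
        for src, dest in zip(online_variables, target_variables):
            dest.assign(src)
    # Incremented unconditionally: this is what the indentation fix restores.
    num_steps.assign_add(1)


for _ in range(250):
    update_target_networks()
print(int(num_steps))  # 250; targets were copied at steps 0, 100 and 200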
--- mava/systems/tf/maddpg/training.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 39ae2ca3c..c4a473d29 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -203,8 +203,8 @@ def __init__( "target_policy": self._target_policy_networks[agent_key], "target_critic": self._target_critic_networks[agent_key], "target_observation": self._target_observation_networks[agent_key], - "policy_optimizer": self._policy_optimizers, - "critic_optimizer": self._critic_optimizers, + "policy_optimizer": self._policy_optimizers[agent_key], + "critic_optimizer": self._critic_optimizers[agent_key], "num_steps": self._num_steps, } @@ -248,7 +248,7 @@ def _update_target_networks(self) -> None: if tf.math.mod(self._num_steps, self._target_update_period) == 0: for src, dest in zip(online_variables, target_variables): dest.assign(src) - self._num_steps.assign_add(1) + self._num_steps.assign_add(1) def _transform_observations( self, obs: Dict[str, np.ndarray], next_obs: Dict[str, np.ndarray] @@ -1750,7 +1750,7 @@ def _get_critic_feed( obs_trans_feed = obs_trans_feed[agent] target_obs_trans_feed = target_obs_trans_feed[agent] - + actions_feed = tf.stack([actions[agent] for agent in self._agents], -1) target_actions_feed = tf.stack( [target_actions[agent] for agent in self._agents], -1 @@ -1859,4 +1859,4 @@ def _get_dpg_feed( ) -> tf.Tensor: # Decentralised DPG dpg_actions_feed = dpg_actions - return dpg_actions_feed \ No newline at end of file + return dpg_actions_feed From a25e97707bf65b38fb042295169137fd0ee32e9e Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 5 Oct 2021 09:54:23 +0200 Subject: [PATCH 011/104] Decrease queue size. --- mava/systems/tf/mappo/builder.py | 2 +- mava/systems/tf/mappo/system.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 7370ebb4c..288136c3c 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -79,7 +79,7 @@ class MAPPOConfig: sequence_period: int = 5 discount: float = 0.99 lambda_gae: float = 0.95 - max_queue_size: int = 100_000 + max_queue_size: int = 1000 executor_variable_update_period: int = 100 batch_size: int = 32 entropy_cost: float = 0.01 diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index ee695f14b..8eda70092 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -68,7 +68,7 @@ def __init__( entropy_cost: float = 0.01, baseline_cost: float = 0.5, max_gradient_norm: Optional[float] = None, - max_queue_size: int = 100000, + max_queue_size: int = 1000, batch_size: int = 32, sequence_length: int = 20, sequence_period: int = 10, From 792ea3c6c4e66381d52288f42a9871c6b2c8e644 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 29 Oct 2021 10:20:05 +0200 Subject: [PATCH 012/104] PPO seems to be training and running. 
--- .../tf/architectures/state_based.py | 2 +- mava/systems/tf/mad4pg/__init__.py | 1 - mava/systems/tf/mad4pg/training.py | 80 ---- mava/systems/tf/maddpg/__init__.py | 2 +- mava/systems/tf/maddpg/builder.py | 3 - mava/systems/tf/maddpg/execution.py | 22 +- mava/systems/tf/maddpg/system.py | 7 +- mava/systems/tf/maddpg/training.py | 20 - mava/systems/tf/mappo/__init__.py | 11 +- mava/systems/tf/mappo/builder.py | 329 +++++++++---- mava/systems/tf/mappo/execution.py | 125 ++++- mava/systems/tf/mappo/networks.py | 10 +- mava/systems/tf/mappo/system.py | 439 ++++++++++-------- mava/systems/tf/mappo/training.py | 115 ++--- .../debugging/scenarios/simple_spread.py | 2 +- 15 files changed, 681 insertions(+), 487 deletions(-) diff --git a/mava/components/tf/architectures/state_based.py b/mava/components/tf/architectures/state_based.py index 0f73c65f7..febdb452e 100644 --- a/mava/components/tf/architectures/state_based.py +++ b/mava/components/tf/architectures/state_based.py @@ -215,7 +215,7 @@ def _get_critic_specs( for agent_key in self._agents: # Get observation spec for critic. critic_obs_specs[agent_key] = critic_obs_spec - return critic_obs_specs, None + return critic_obs_specs, None # type: ignore class StateBasedQValueSingleActionCritic(DecentralisedQValueActorCritic): diff --git a/mava/systems/tf/mad4pg/__init__.py b/mava/systems/tf/mad4pg/__init__.py index e66950369..348474955 100644 --- a/mava/systems/tf/mad4pg/__init__.py +++ b/mava/systems/tf/mad4pg/__init__.py @@ -26,7 +26,6 @@ MAD4PGCentralisedTrainer, MAD4PGDecentralisedRecurrentTrainer, MAD4PGDecentralisedTrainer, - MAD4PGStateBasedDecentralActionRecurrentTrainer, MAD4PGStateBasedRecurrentTrainer, MAD4PGStateBasedSingleActionCriticRecurrentTrainer, MAD4PGStateBasedTrainer, diff --git a/mava/systems/tf/mad4pg/training.py b/mava/systems/tf/mad4pg/training.py index 4389e6004..aea5bdc53 100644 --- a/mava/systems/tf/mad4pg/training.py +++ b/mava/systems/tf/mad4pg/training.py @@ -36,7 +36,6 @@ MADDPGCentralisedTrainer, MADDPGDecentralisedRecurrentTrainer, MADDPGDecentralisedTrainer, - MADDPGStateBasedDecentralActionRecurrentTrainer, MADDPGStateBasedRecurrentTrainer, MADDPGStateBasedSingleActionCriticRecurrentTrainer, MADDPGStateBasedTrainer, @@ -56,7 +55,6 @@ class MAD4PGBaseTrainer(MADDPGBaseTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -79,7 +77,6 @@ def __init__( """Initialise MAD4PG trainer Args: agents: agent ids, e.g. "agent_0". - agent_types: agent types, e.g. "speaker" or "listener". policy_networks: policy networks for each agent in the system. 
critic_networks: critic network(s), shared or for @@ -113,7 +110,6 @@ def __init__( """Initialise the decentralised MADDPG trainer.""" super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -229,7 +225,6 @@ class MAD4PGDecentralisedTrainer(MAD4PGBaseTrainer, MADDPGDecentralisedTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -252,7 +247,6 @@ def __init__( """Initialise the decentralised MAD4PG trainer.""" super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -280,7 +274,6 @@ class MAD4PGCentralisedTrainer(MAD4PGBaseTrainer, MADDPGCentralisedTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -303,7 +296,6 @@ def __init__( """Initialise the centralised MAD4PG trainer.""" super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -331,7 +323,6 @@ class MAD4PGStateBasedTrainer(MAD4PGBaseTrainer, MADDPGStateBasedTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -354,7 +345,6 @@ def __init__( """Initialise the state-based MAD4PG trainer.""" super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -385,7 +375,6 @@ class MAD4PGBaseRecurrentTrainer(MADDPGBaseRecurrentTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -410,7 +399,6 @@ def __init__( Args: agents: agent ids, e.g. "agent_0". - agent_types: agent types, e.g. "speaker" or "listener". policy_networks: policy networks for each agent in the system. 
critic_networks: critic network(s), shared or for @@ -443,7 +431,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -613,7 +600,6 @@ class MAD4PGDecentralisedRecurrentTrainer( def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -637,7 +623,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -668,7 +653,6 @@ class MAD4PGCentralisedRecurrentTrainer( def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -692,7 +676,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -723,7 +706,6 @@ class MAD4PGStateBasedRecurrentTrainer( def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -747,7 +729,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -778,7 +759,6 @@ class MAD4PGStateBasedSingleActionCriticRecurrentTrainer( def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -802,7 +782,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -823,62 +802,3 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, ) - - -class MAD4PGStateBasedDecentralActionRecurrentTrainer( - MAD4PGBaseRecurrentTrainer, MADDPGStateBasedDecentralActionRecurrentTrainer -): - """Recurrent MAD4PG trainer for a state-based architecture.""" - - def __init__( - self, - agents: List[str], - agent_types: List[str], - policy_networks: Dict[str, snt.Module], - critic_networks: Dict[str, snt.Module], - target_policy_networks: Dict[str, snt.Module], - target_critic_networks: Dict[str, snt.Module], - policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], - critic_optimizer: snt.Optimizer, - discount: float, - target_averaging: bool, - target_update_period: int, - target_update_rate: float, - dataset: tf.data.Dataset, - observation_networks: Dict[str, snt.Module], - target_observation_networks: Dict[str, snt.Module], - agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, - max_gradient_norm: float = None, - counter: counting.Counter = None, - logger: loggers.Logger = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - bootstrap_n: int = 10, - ): - - super().__init__( - agents=agents, - agent_types=agent_types, - policy_networks=policy_networks, - critic_networks=critic_networks, - target_policy_networks=target_policy_networks, - target_critic_networks=target_critic_networks, - discount=discount, - target_averaging=target_averaging, - 
target_update_period=target_update_period, - target_update_rate=target_update_rate, - dataset=dataset, - observation_networks=observation_networks, - target_observation_networks=target_observation_networks, - agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, - policy_optimizer=policy_optimizer, - critic_optimizer=critic_optimizer, - max_gradient_norm=max_gradient_norm, - counter=counter, - logger=logger, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - bootstrap_n=bootstrap_n, - ) diff --git a/mava/systems/tf/maddpg/__init__.py b/mava/systems/tf/maddpg/__init__.py index 533ccbabf..aaab158ad 100644 --- a/mava/systems/tf/maddpg/__init__.py +++ b/mava/systems/tf/maddpg/__init__.py @@ -29,7 +29,7 @@ MADDPGDecentralisedRecurrentTrainer, MADDPGDecentralisedTrainer, MADDPGNetworkedTrainer, - MADDPGStateBasedDecentralActionRecurrentTrainer, MADDPGStateBasedRecurrentTrainer, + MADDPGStateBasedSingleActionCriticRecurrentTrainer, MADDPGStateBasedTrainer, ) diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index 0924ae70b..59e57fccc 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -147,7 +147,6 @@ def __init__( self._config = config self._extra_specs = extra_specs self._agents = self._config.environment_spec.get_agent_ids() - self._agent_types = self._config.environment_spec.get_agent_types() self._trainer_fn = trainer_fn self._executor_fn = executor_fn @@ -507,7 +506,6 @@ def make_trainer( executors to update the parameters of the agent networks in the system. """ # This assumes agents are sort_str_num in the other methods - agent_types = self._agent_types max_gradient_norm = self._config.max_gradient_norm discount = self._config.discount target_update_period = self._config.target_update_period @@ -561,7 +559,6 @@ def make_trainer( trainer_config: Dict[str, Any] = { "agents": trainer_agents, - "agent_types": agent_types, "policy_networks": networks["policies"], "critic_networks": networks["critics"], "observation_networks": networks["observations"], diff --git a/mava/systems/tf/maddpg/execution.py b/mava/systems/tf/maddpg/execution.py index 1be76ef7f..3083a4997 100644 --- a/mava/systems/tf/maddpg/execution.py +++ b/mava/systems/tf/maddpg/execution.py @@ -389,12 +389,11 @@ def observe_first( self._net_keys_to_ids, ) - if self._store_recurrent_state: - numpy_states = { - agent: tf2_utils.to_numpy_squeeze(_state) - for agent, _state in self._states.items() - } - extras.update({"core_states": numpy_states}) + numpy_states = { + agent: tf2_utils.to_numpy_squeeze(_state) + for agent, _state in self._states.items() + } + extras.update({"core_states": numpy_states}) extras["network_int_keys"] = self._network_int_keys_extras self._adder.add_first(timestep, extras) @@ -417,12 +416,11 @@ def observe( return _, policy = actions - if self._store_recurrent_state: - numpy_states = { - agent: tf2_utils.to_numpy_squeeze(_state) - for agent, _state in self._states.items() - } - next_extras.update({"core_states": numpy_states}) + numpy_states = { + agent: tf2_utils.to_numpy_squeeze(_state) + for agent, _state in self._states.items() + } + next_extras.update({"core_states": numpy_states}) next_extras["network_int_keys"] = self._network_int_keys_extras self._adder.add(policy, next_timestep, next_extras) # type: ignore diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index c0241b949..1625d7a53 100644 --- a/mava/systems/tf/maddpg/system.py +++ 
b/mava/systems/tf/maddpg/system.py @@ -38,14 +38,11 @@ from mava.environment_loop import ParallelEnvironmentLoop from mava.systems.tf import executors from mava.systems.tf.maddpg import builder, training -from mava.systems.tf.maddpg.execution import ( - MADDPGFeedForwardExecutor, - sample_new_agent_keys, -) +from mava.systems.tf.maddpg.execution import MADDPGFeedForwardExecutor from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource from mava.utils import enums from mava.utils.loggers import MavaLogger, logger_utils -from mava.utils.sort_utils import sort_str_num +from mava.utils.sort_utils import sample_new_agent_keys, sort_str_num from mava.wrappers import DetailedPerAgentStatistics diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 5cee8abe2..f0cad64d6 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -50,7 +50,6 @@ class MADDPGBaseTrainer(mava.Trainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -73,7 +72,6 @@ def __init__( """Initialise MADDPG trainer Args: agents: agent ids, e.g. "agent_0". - agent_types: agent types, e.g. "speaker" or "listener". policy_networks: policy networks for each agent in the system. critic_networks: critic network(s), shared or for @@ -515,7 +513,6 @@ class MADDPGDecentralisedTrainer(MADDPGBaseTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -538,7 +535,6 @@ def __init__( """Initialise the decentralised MADDPG trainer.""" super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -566,7 +562,6 @@ class MADDPGCentralisedTrainer(MADDPGBaseTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -589,7 +584,6 @@ def __init__( """Initialise the centralised MADDPG trainer.""" super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -656,7 +650,6 @@ class MADDPGNetworkedTrainer(MADDPGBaseTrainer): def __init__( self, agents: List[str], - agent_types: List[str], connection_spec: Dict[str, List[str]], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], @@ -680,7 +673,6 @@ def __init__( """Initialise the networked MADDPG trainer.""" super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -765,7 +757,6 @@ class MADDPGStateBasedTrainer(MADDPGBaseTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -788,7 +779,6 @@ def __init__( """Initialise the decentralised MADDPG trainer.""" super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -858,7 +848,6 @@ class 
MADDPGBaseRecurrentTrainer(mava.Trainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -882,7 +871,6 @@ def __init__( """Initialise Recurrent MADDPG trainer Args: agents: agent ids, e.g. "agent_0". - agent_types: agent types, e.g. "speaker" or "listener". policy_networks: policy networks for each agent in the system. critic_networks: critic network(s), shared or for @@ -1423,7 +1411,6 @@ class MADDPGDecentralisedRecurrentTrainer(MADDPGBaseRecurrentTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -1447,7 +1434,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -1479,7 +1465,6 @@ class MADDPGCentralisedRecurrentTrainer(MADDPGBaseRecurrentTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -1503,7 +1488,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -1575,7 +1559,6 @@ class MADDPGStateBasedRecurrentTrainer(MADDPGBaseRecurrentTrainer): def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -1599,7 +1582,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, @@ -1673,7 +1655,6 @@ class MADDPGStateBasedSingleActionCriticRecurrentTrainer(MADDPGBaseRecurrentTrai def __init__( self, agents: List[str], - agent_types: List[str], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], target_policy_networks: Dict[str, snt.Module], @@ -1697,7 +1678,6 @@ def __init__( super().__init__( agents=agents, - agent_types=agent_types, policy_networks=policy_networks, critic_networks=critic_networks, target_policy_networks=target_policy_networks, diff --git a/mava/systems/tf/mappo/__init__.py b/mava/systems/tf/mappo/__init__.py index af3ccdb05..b9bbe99fd 100644 --- a/mava/systems/tf/mappo/__init__.py +++ b/mava/systems/tf/mappo/__init__.py @@ -13,7 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from mava.systems.tf.mappo.execution import MAPPOFeedForwardExecutor, MAPPORecurrentExecutor +from mava.systems.tf.mappo.execution import ( + MAPPOFeedForwardExecutor, + MAPPORecurrentExecutor, +) from mava.systems.tf.mappo.networks import make_default_networks from mava.systems.tf.mappo.system import MAPPO -from mava.systems.tf.mappo.training import CentralisedMAPPOTrainer, MAPPOTrainer, StateBasedMAPPOTrainer +from mava.systems.tf.mappo.training import ( + CentralisedMAPPOTrainer, + MAPPOTrainer, + StateBasedMAPPOTrainer, +) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 288136c3c..daedc348d 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -22,16 +22,16 @@ import reverb import sonnet as snt import tensorflow as tf -from acme import datasets from acme.specs import EnvironmentSpec from acme.tf import utils as tf2_utils -from acme.tf import variable_utils -from acme.utils import counting from mava import adders, core, specs, types from mava.adders import reverb as reverb_adders +from mava.systems.tf import variable_utils from mava.systems.tf.mappo import execution, training -from mava.wrappers import DetailedTrainerStatistics +from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource +from mava.utils.sort_utils import sort_str_num +from mava.wrappers import NetworkStatisticsActorCritic, ScaledDetailedTrainerStatistics @dataclasses.dataclass @@ -68,12 +68,23 @@ class MAPPOConfig: checkpoint: boolean to indicate whether to checkpoint models. checkpoint_subpath: subdirectory specifying where to store checkpoints. replay_table_name: string indicating what name to give the replay table. + termination_condition: An optional terminal condition can be provided + that stops the program once the condition is satisfied. Available options + include specifying maximum values for trainer_steps, trainer_walltime, + evaluator_steps, evaluator_episodes, executor_episodes or executor_steps. + E.g. termination_condition = {'trainer_steps': 100000}. 
""" environment_spec: specs.EnvironmentSpec policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] critic_optimizer: snt.Optimizer + num_executors: int agent_net_keys: Dict[str, str] + trainer_networks: Dict[str, List] + table_network_config: Dict[str, List] + network_sampling_setup: List + net_keys_to_ids: Dict[str, int] + unique_net_keys: List[str] checkpoint_minute_interval: int sequence_length: int = 10 sequence_period: int = 5 @@ -89,6 +100,7 @@ class MAPPOConfig: checkpoint: bool = True checkpoint_subpath: str = "~/mava/" replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + termination_condition: Optional[Dict[str, int]] = None class MAPPOBuilder: @@ -116,10 +128,9 @@ def __init__( self._config = config self._agents = self._config.environment_spec.get_agent_ids() - self._agent_types = self._config.environment_spec.get_agent_types() self._trainer_fn = trainer_fn self._executor_fn = executor_fn - self._extras_specs = extra_specs + self._extra_specs = extra_specs def add_logits_to_spec( self, environment_spec: specs.MAEnvironmentSpec @@ -151,6 +162,21 @@ def add_logits_to_spec( ) return env_adder_spec + def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any]: + if type(spec) is not dict: + return spec + + agents = sort_str_num(self._config.agent_net_keys.keys())[:num_networks] + converted_spec: Dict[str, Any] = {} + if agents[0] in spec.keys(): + for agent in agents: + converted_spec[agent] = spec[agent] + else: + # For the extras + for key in spec.keys(): + converted_spec[key] = self.covert_specs(spec[key], num_networks) + return converted_spec + def make_replay_tables( self, environment_spec: specs.MAEnvironmentSpec, @@ -168,21 +194,43 @@ def make_replay_tables( # Create system architecture with target networks. adder_env_spec = self.add_logits_to_spec(environment_spec) - replay_table = reverb.Table.queue( - name=self._config.replay_table_name, - max_size=self._config.max_queue_size, - signature=reverb_adders.ParallelSequenceAdder.signature( - adder_env_spec, - sequence_length=self._config.sequence_length, - extras_spec=self._extras_specs, - ), - ) + # Create table per trainer + replay_tables = [] + for table_key in self._config.table_network_config.keys(): + # TODO (dries): Clean the below coverter code up. + # Convert a Mava spec + num_networks = len(self._config.table_network_config[table_key]) + env_spec = copy.deepcopy(adder_env_spec) + env_spec._specs = self.covert_specs(env_spec._specs, num_networks) + + env_spec._keys = list(sort_str_num(env_spec._specs.keys())) + if env_spec.extra_specs is not None: + env_spec.extra_specs = self.covert_specs( + env_spec.extra_specs, num_networks + ) + extra_specs = self.covert_specs( + self._extra_specs, + num_networks, + ) - return [replay_table] + replay_tables.append( + reverb.Table.queue( + name=table_key, + max_size=self._config.max_queue_size, + signature=reverb_adders.ParallelSequenceAdder.signature( + env_spec, + sequence_length=self._config.sequence_length, + extras_spec=extra_specs, + ), + ) + ) + + return replay_tables def make_dataset_iterator( self, replay_client: reverb.Client, + table_name: str, ) -> Iterator[reverb.ReplaySample]: """Create a dataset iterator to use for training/updating the system. 
@@ -205,7 +253,7 @@ def make_dataset_iterator( dataset = reverb.TrajectoryDataset.from_table_signature( server_address=replay_client.server_address, - table=self._config.replay_table_name, + table=table_name, max_in_flight_samples_per_worker=1, ).batch( self._config.batch_size, drop_remainder=True @@ -231,59 +279,138 @@ def make_adder( Optional[adders.ParallelAdder]: adder which sends data to a replay buffer. """ + # Create custom priority functons for the adder + priority_fns = { + table_key: lambda x: 1.0 + for table_key in self._config.table_network_config.keys() + } + return reverb_adders.ParallelSequenceAdder( + priority_fns=priority_fns, client=replay_client, - period=self._config.sequence_period, + net_ids_to_keys=self._config.unique_net_keys, + table_network_config=self._config.table_network_config, sequence_length=self._config.sequence_length, + period=self._config.sequence_period, use_next_extras=True, ) + def create_counter_variables( + self, variables: Dict[str, tf.Variable] + ) -> Dict[str, tf.Variable]: + """Create counter variables. + Args: + variables: dictionary with variable_source + variables in. + Returns: + variables: dictionary with variable_source + variables in. + """ + variables["trainer_steps"] = tf.Variable(0, dtype=tf.int32) + variables["trainer_walltime"] = tf.Variable(0, dtype=tf.float32) + variables["evaluator_steps"] = tf.Variable(0, dtype=tf.int32) + variables["evaluator_episodes"] = tf.Variable(0, dtype=tf.int32) + variables["executor_episodes"] = tf.Variable(0, dtype=tf.int32) + variables["executor_steps"] = tf.Variable(0, dtype=tf.int32) + return variables + + def make_variable_server( + self, + networks: Dict[str, Dict[str, snt.Module]], + ) -> MavaVariableSource: + """Create the variable server. + Args: + networks: dictionary with the + system's networks in. + Returns: + variable_source: A Mava variable source object. + """ + # Create variables + variables = {} + # Network variables + for net_type_key in networks.keys(): + for net_key in networks[net_type_key].keys(): + # Ensure obs and target networks are sonnet modules + variables[f"{net_key}_{net_type_key}"] = tf2_utils.to_sonnet_module( + networks[net_type_key][net_key] + ).variables + + variables = self.create_counter_variables(variables) + + # Create variable source + variable_source = MavaVariableSource( + variables, + self._config.checkpoint, + self._config.checkpoint_subpath, + self._config.checkpoint_minute_interval, + self._config.termination_condition, + ) + return variable_source + def make_executor( self, + # executor_id: str, + networks: Dict[str, snt.Module], policy_networks: Dict[str, snt.Module], adder: Optional[adders.ParallelAdder] = None, - variable_source: Optional[core.VariableSource] = None, + variable_source: Optional[MavaVariableSource] = None, ) -> core.Executor: - """Create an executor instance. - Args: - policy_networks (Dict[str, snt.Module]): policy networks for each agent in + networks: dictionary with the system's networks in. + policy_networks: policy networks for each agent in the system. - adder (Optional[adders.ParallelAdder], optional): adder to send data to + adder: adder to send data to a replay buffer. Defaults to None. - variable_source (Optional[core.VariableSource], optional): variables server. + variable_source: variables server. Defaults to None. - Returns: - core.Executor: system executor, a collection of agents making up the part + system executor, a collection of agents making up the part of the system generating data by interacting the environment. 
""" + # Create policy variables + variables = {} + get_keys = [] + for net_type_key in ["observations", "policies"]: + for net_key in networks[net_type_key].keys(): + var_key = f"{net_key}_{net_type_key}" + variables[var_key] = networks[net_type_key][net_key].variables + get_keys.append(var_key) + variables = self.create_counter_variables(variables) + + count_names = [ + "trainer_steps", + "trainer_walltime", + "evaluator_steps", + "evaluator_episodes", + "executor_episodes", + "executor_steps", + ] + get_keys.extend(count_names) + counts = {name: variables[name] for name in count_names} variable_client = None if variable_source: - # Create policy variables. - variables = { - network_key: policy_networks[network_key].variables - for network_key in policy_networks - } - - # Get new policy variables. - # TODO Should the variable update period be in the MAPPO config? + # Get new policy variables variable_client = variable_utils.VariableClient( client=variable_source, - variables={"policy": variables}, + variables=variables, + get_keys=get_keys, update_period=self._config.executor_variable_update_period, ) # Make sure not to use a random policy after checkpoint restoration by # assigning variables before running the environment loop. - variable_client.update_and_wait() + variable_client.get_and_wait() - # Create the executor which defines how agents take actions. + # Create the actor which defines how we take actions. return self._executor_fn( policy_networks=policy_networks, + counts=counts, + net_keys_to_ids=self._config.net_keys_to_ids, + agent_specs=self._config.environment_spec.get_agent_specs(), agent_net_keys=self._config.agent_net_keys, + network_sampling_setup=self._config.network_sampling_setup, variable_client=variable_client, adder=adder, ) @@ -292,60 +419,104 @@ def make_trainer( self, networks: Dict[str, Dict[str, snt.Module]], dataset: Iterator[reverb.ReplaySample], - can_sample: Any, - counter: Optional[counting.Counter] = None, + variable_source: MavaVariableSource, + trainer_networks: List[Any], + trainer_table_entry: List[Any], logger: Optional[types.NestedLogger] = None, + connection_spec: Dict[str, List[str]] = None, ) -> core.Trainer: """Create a trainer instance. - Args: - networks (Dict[str, Dict[str, snt.Module]]): system networks. - dataset (Iterator[reverb.ReplaySample]): dataset iterator to feed data to + networks: system networks. + dataset: dataset iterator to feed data to the trainer networks. - counter (Optional[counting.Counter], optional): a Counter which allows for - recording of counts, e.g. trainer steps. Defaults to None. - logger (Optional[types.NestedLogger], optional): Logger object for logging - metadata. Defaults to None. - + variable_source: Source with variables in. + trainer_networks: Set of unique network keys to train on.. + trainer_table_entry: List of networks per agent to train on. + logger: Logger object for logging metadata. + connection_spec: connection topology used + for networked system architectures. Defaults to None. Returns: - core.Trainer: system trainer, that uses the collected data from the + system trainer, that uses the collected data from the executors to update the parameters of the agent networks in the system. """ + # This assumes agents are sort_str_num in the other methods + max_gradient_norm = self._config.max_gradient_norm + + # Create variable client + variables = {} + set_keys = [] + get_keys = [] + # TODO (dries): Only add the networks this trainer is working with. + # Not all of them. 
+ for net_type_key in networks.keys(): + for net_key in networks[net_type_key].keys(): + variables[f"{net_key}_{net_type_key}"] = networks[net_type_key][ + net_key + ].variables + if net_key in set(trainer_networks): + set_keys.append(f"{net_key}_{net_type_key}") + else: + get_keys.append(f"{net_key}_{net_type_key}") + + variables = self.create_counter_variables(variables) + count_names = [ + "trainer_steps", + "trainer_walltime", + "evaluator_steps", + "evaluator_episodes", + "executor_episodes", + "executor_steps", + ] + get_keys.extend(count_names) + counts = {name: variables[name] for name in count_names} + + variable_client = variable_utils.VariableClient( + client=variable_source, + variables=variables, + get_keys=get_keys, + set_keys=set_keys, + update_period=10, + ) - agents = self._agents - agent_types = self._agent_types - agent_net_keys = self._config.agent_net_keys - - observation_networks = networks["observations"] - policy_networks = networks["policies"] - critic_networks = networks["critics"] + # Get all the initial variables + variable_client.get_all_and_wait() + + # Convert network keys for the trainer. + trainer_agents = self._agents[: len(trainer_table_entry)] + trainer_agent_net_keys = { + agent: trainer_table_entry[a_i] for a_i, agent in enumerate(trainer_agents) + } + + trainer_config: Dict[str, Any] = { + "agents": trainer_agents, + "policy_networks": networks["policies"], + "critic_networks": networks["critics"], + "observation_networks": networks["observations"], + "agent_net_keys": trainer_agent_net_keys, + "policy_optimizer": self._config.policy_optimizer, + "critic_optimizer": self._config.critic_optimizer, + "max_gradient_norm": max_gradient_norm, + "discount": self._config.discount, + "variable_client": variable_client, + "dataset": dataset, + "counts": counts, + "logger": logger, + "lambda_gae": self._config.lambda_gae, + "entropy_cost": self._config.entropy_cost, + "baseline_cost": self._config.baseline_cost, + "clipping_epsilon": self._config.clipping_epsilon, + } # The learner updates the parameters (and initializes them). - trainer = self._trainer_fn( - agents=agents, - agent_types=agent_types, - can_sample=can_sample, - observation_networks=observation_networks, - policy_networks=policy_networks, - critic_networks=critic_networks, - dataset=dataset, - agent_net_keys=agent_net_keys, - critic_optimizer=self._config.critic_optimizer, - policy_optimizer=self._config.policy_optimizer, - discount=self._config.discount, - lambda_gae=self._config.lambda_gae, - entropy_cost=self._config.entropy_cost, - baseline_cost=self._config.baseline_cost, - clipping_epsilon=self._config.clipping_epsilon, - max_gradient_norm=self._config.max_gradient_norm, - counter=counter, - logger=logger, - checkpoint_minute_interval=self._config.checkpoint_minute_interval, - checkpoint=self._config.checkpoint, - checkpoint_subpath=self._config.checkpoint_subpath, - ) + trainer = self._trainer_fn(**trainer_config) + + # NB If using both NetworkStatistics and TrainerStatistics, order is important. + # NetworkStatistics needs to appear before TrainerStatistics. 
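A note on the variable-client wiring in make_trainer above: variables for the networks a trainer is responsible for go into set_keys (pushed to the variable server), while all other network variables and the shared counters go into get_keys (only pulled). A minimal sketch of that partition with plain dictionaries; split_variable_keys and its inputs are hypothetical names used only for illustration.

# Illustrative partition of variable keys into "owned" (pushed to the variable
# server) and "fetched" (pulled from it), mirroring the set_keys/get_keys split.
from typing import Dict, List, Tuple


def split_variable_keys(
    networks: Dict[str, Dict[str, object]],
    trainer_networks: List[str],
    count_names: List[str],
) -> Tuple[List[str], List[str]]:
    set_keys, get_keys = [], []
    for net_type_key, nets in networks.items():
        for net_key in nets:
            var_key = f"{net_key}_{net_type_key}"
            if net_key in set(trainer_networks):
                set_keys.append(var_key)   # this trainer updates these weights
            else:
                get_keys.append(var_key)   # trained elsewhere, read-only here
    get_keys.extend(count_names)           # shared counters are always read-only
    return set_keys, get_keys


nets = {"policies": {"network_0": None, "network_1": None},
        "critics": {"network_0": None, "network_1": None}}
print(split_variable_keys(nets, ["network_0"], ["trainer_steps"]))
# (['network_0_policies', 'network_0_critics'],
#  ['network_1_policies', 'network_1_critics', 'trainer_steps'])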
+ # TODO(Kale-ab/Arnu): need to fix wrapper type issues + trainer = NetworkStatisticsActorCritic(trainer) # type: ignore - trainer = DetailedTrainerStatistics( # type: ignore + trainer = ScaledDetailedTrainerStatistics( # type: ignore trainer, metrics=["policy_loss", "critic_loss"] ) diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index e0eb12072..52628cd6b 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -15,7 +15,7 @@ """MAPPO system executor implementation.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple import dm_env import numpy as np @@ -27,8 +27,9 @@ from acme.tf import utils as tf2_utils from acme.tf import variable_utils as tf2_variable_utils -from mava import adders, core +from mava import adders from mava.systems.tf import executors +from mava.utils.sort_utils import sample_new_agent_keys, sort_str_num tfd = tfp.distributions @@ -41,8 +42,12 @@ class MAPPOFeedForwardExecutor(executors.FeedForwardExecutor): def __init__( self, policy_networks: Dict[str, snt.Module], + agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], + network_sampling_setup: List, + net_keys_to_ids: Dict[str, int], adder: Optional[adders.ParallelAdder] = None, + counts: Optional[Dict[str, Any]] = None, variable_client: Optional[tf2_variable_utils.VariableClient] = None, ): """Initialise the system executor @@ -56,6 +61,11 @@ def __init__( agent_net_keys: (dict, optional): specifies what network each agent uses. Defaults to {}. """ + self._agent_specs = agent_specs + self._network_sampling_setup = network_sampling_setup + self._counts = counts + self._network_int_keys_extras: Dict[str, np.array] = {} + self._net_keys_to_ids = net_keys_to_ids super().__init__( policy_networks=policy_networks, agent_net_keys=agent_net_keys, @@ -68,7 +78,7 @@ def _policy( self, agent: str, observation: types.NestedTensor, - ) -> Tuple[types.NestedTensor, types.NestedTensor, types.NestedTensor]: + ) -> Tuple[types.NestedTensor, types.NestedTensor]: """Agent specific policy function Args: @@ -149,11 +159,37 @@ def select_actions( actions[agent], logits[agent] = self.select_action(agent, observation) return actions, logits + def observe_first( + self, + timestep: dm_env.TimeStep, + extras: Dict[str, types.NestedArray] = {}, + ) -> None: + """record first observed timestep from the environment + Args: + timestep: data emitted by an environment at first step of + interaction. + extras: possible extra information + to record during the first step. Defaults to {}. + """ + if not self._adder: + return + + "Select new networks from the sampler at the start of each episode." 
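The executors' observe_first methods now re-sample which network each agent controls at the start of every episode (sample_new_agent_keys) and record the chosen network ids in the adder extras under "network_int_keys". A small self-contained sketch of that sampling idea, assuming network_sampling_setup is a list of candidate key lists sampled with replacement; sample_agent_networks below is a hypothetical illustration, not Mava's helper.

# Illustrative per-episode network sampling: pick one row from the sampling
# setup at random, assign its keys to agents in order, and also return the
# integer ids that would be stored in the adder extras.
import random
from typing import Dict, List, Tuple


def sample_agent_networks(
    agents: List[str],
    network_sampling_setup: List[List[str]],
    net_keys_to_ids: Dict[str, int],
) -> Tuple[Dict[str, int], Dict[str, str]]:
    agent_net_keys: Dict[str, str] = {}
    sample_length = len(network_sampling_setup[0])
    for i in range(0, len(agents), sample_length):
        sample = random.choice(network_sampling_setup)  # sampled with replacement
        for j, net_key in enumerate(sample):
            agent_net_keys[agents[i + j]] = net_key
    net_int_keys = {agent: net_keys_to_ids[key] for agent, key in agent_net_keys.items()}
    return net_int_keys, agent_net_keys


agents = ["agent_0", "agent_1"]
setup = [["network_0"], ["network_1"]]
ids = {"network_0": 0, "network_1": 1}
print(sample_agent_networks(agents, setup, ids))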
+ agents = sort_str_num(list(self._agent_net_keys.keys())) + self._network_int_keys_extras, self._agent_net_keys = sample_new_agent_keys( + agents, + self._network_sampling_setup, + self._net_keys_to_ids, + ) + + extras["network_int_keys"] = self._network_int_keys_extras + self._adder.add_first(timestep, extras) + def observe( self, actions: Dict[str, types.NestedArray], next_timestep: dm_env.TimeStep, - next_extras: Optional[Dict[str, types.NestedArray]] = {}, + next_extras: Dict[str, types.NestedArray] = {}, ) -> None: """record observed timestep from the environment @@ -167,16 +203,22 @@ def observe( if not self._adder: return - actions, logits = actions + actions, logits = actions # type: ignore - adder_actions = {} + adder_actions: Dict[str, Any] = {} for agent in actions.keys(): adder_actions[agent] = {} adder_actions[agent]["actions"] = actions[agent] - adder_actions[agent]["logits"] = logits[agent] + adder_actions[agent]["logits"] = logits[agent] # type: ignore + next_extras["network_int_keys"] = self._network_int_keys_extras self._adder.add(adder_actions, next_timestep, next_extras) # type: ignore + def update(self, wait: bool = False) -> None: + """Update the policy variables.""" + if self._variable_client: + self._variable_client.get_async() + class MAPPORecurrentExecutor(executors.RecurrentExecutor): """A recurrent executor for MAPPO. @@ -186,8 +228,12 @@ class MAPPORecurrentExecutor(executors.RecurrentExecutor): def __init__( self, policy_networks: Dict[str, snt.Module], + agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], + network_sampling_setup: List, + net_keys_to_ids: Dict[str, int], adder: Optional[adders.ParallelAdder] = None, + counts: Optional[Dict[str, Any]] = None, variable_client: Optional[tf2_variable_utils.VariableClient] = None, ): """Initialise the system executor @@ -201,6 +247,11 @@ def __init__( agent_net_keys: (dict, optional): specifies what network each agent uses. Defaults to {}. """ + self._agent_specs = agent_specs + self._network_sampling_setup = network_sampling_setup + self._counts = counts + self._net_keys_to_ids = net_keys_to_ids + self._network_int_keys_extras: Dict[str, np.array] = {} super().__init__( policy_networks=policy_networks, agent_net_keys=agent_net_keys, @@ -307,11 +358,50 @@ def select_actions( actions[agent], logits[agent] = self.select_action(agent, observation) return actions, logits + def observe_first( + self, + timestep: dm_env.TimeStep, + extras: Dict[str, types.NestedArray] = {}, + ) -> None: + """record first observed timestep from the environment + + Args: + timestep: data emitted by an environment at first step of + interaction. + extras: possible extra information + to record during the first step. + """ + + # Re-initialize the RNN state. + for agent, _ in timestep.observation.items(): + # index network either on agent type or on agent id + agent_key = self._agent_net_keys[agent] + self._states[agent] = self._policy_networks[agent_key].initial_state(1) + + if not self._adder: + return + + # Sample new agent_net_keys. 
+ agents = sort_str_num(list(self._agent_net_keys.keys())) + self._network_int_keys_extras, self._agent_net_keys = sample_new_agent_keys( + agents, + self._network_sampling_setup, + self._net_keys_to_ids, + ) + + numpy_states = { + agent: tf2_utils.to_numpy_squeeze(_state) + for agent, _state in self._states.items() + } + extras.update({"core_states": numpy_states}) + extras["network_int_keys"] = self._network_int_keys_extras + self._adder.add_first(timestep, extras) + def observe( self, actions: Dict[str, types.NestedArray], next_timestep: dm_env.TimeStep, - next_extras: Optional[Dict[str, types.NestedArray]] = {}, + next_extras: Dict[str, types.NestedArray] = {}, ) -> None: """record observed timestep from the environment @@ -325,20 +415,23 @@ def observe( if not self._adder: return - actions, logits = actions + actions, logits = actions # type: ignore - adder_actions = {} + adder_actions: Dict[str, Any] = {} for agent in actions.keys(): adder_actions[agent] = {} adder_actions[agent]["actions"] = actions[agent] - adder_actions[agent]["logits"] = logits[agent] + adder_actions[agent]["logits"] = logits[agent] # type: ignore numpy_states = { agent: tf2_utils.to_numpy_squeeze(_state) for agent, _state in self._states.items() } - if next_extras: - next_extras.update({"core_states": numpy_states}) - self._adder.add(adder_actions, next_timestep, next_extras) # type: ignore - else: - self._adder.add(adder_actions, next_timestep, numpy_states) # type: ignore + next_extras.update({"core_states": numpy_states}) + next_extras["network_int_keys"] = self._network_int_keys_extras + self._adder.add(adder_actions, next_timestep, next_extras) # type: ignore + + def update(self, wait: bool = False) -> None: + """Update the policy variables.""" + if self._variable_client: + self._variable_client.get_async() diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index f5fe78e78..95c99fad0 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -18,7 +18,6 @@ import numpy as np import sonnet as snt import tensorflow as tf -import tensorflow_probability as tfp from acme.tf import utils as tf2_utils from dm_env import specs @@ -35,13 +34,14 @@ def make_default_networks( environment_spec: mava_specs.MAEnvironmentSpec, agent_net_keys: Dict[str, str], + net_spec_keys: Dict[str, str] = {}, policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = ( 256, 256, 256, ), critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256), - archecture_type=ArchitectureType.feedforward, + archecture_type: ArchitectureType = ArchitectureType.feedforward, seed: Optional[int] = None, ) -> Dict[str, snt.Module]: """Default networks for mappo. @@ -51,6 +51,7 @@ def make_default_networks( observation spaces etc. for each agent in the system. agent_net_keys: (dict, optional): specifies what network each agent uses. Defaults to {}. + net_spec_keys: specifies the specs of each network. policy_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional): size of policy networks. Defaults to (256, 256, 256). critic_networks_layer_sizes (Union[Dict[str, Sequence], Sequence], optional): @@ -67,7 +68,10 @@ def make_default_networks( # Create agent_type specs. 
specs = environment_spec.get_agent_specs() - specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} + if not net_spec_keys: + specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} + else: + specs = {net_key: specs[value] for net_key, value in net_spec_keys.items()} # Set Policy function and layer size # Default size per arch type. diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 8eda70092..dc78e7839 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -16,7 +16,7 @@ """MAPPO system implementation.""" import functools -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import acme import dm_env @@ -25,24 +25,26 @@ import sonnet as snt from acme import specs as acme_specs from acme.tf import utils as tf2_utils -from acme.utils import counting, loggers +from acme.utils import loggers +from dm_env import specs import mava -from mava import core from mava import specs as mava_specs from mava.components.tf.architectures import DecentralisedValueActorCritic from mava.environment_loop import ParallelEnvironmentLoop -from mava.systems.tf import savers as tf2_savers +from mava.systems.tf import executors from mava.systems.tf.mappo import builder, execution, training -from mava.utils import lp_utils +from mava.systems.tf.variable_sources import VariableSource as MavaVariableSource +from mava.utils import enums from mava.utils.loggers import MavaLogger, logger_utils +from mava.utils.sort_utils import sample_new_agent_keys, sort_str_num from mava.wrappers import DetailedPerAgentStatistics class MAPPO: """MAPPO system.""" - def __init__( + def __init__( # noqa self, environment_factory: Callable[[bool], dm_env.Environment], network_factory: Callable[[acme_specs.BoundedArray], Dict[str, snt.Module]], @@ -51,12 +53,18 @@ def __init__( DecentralisedValueActorCritic ] = DecentralisedValueActorCritic, trainer_fn: Type[training.MAPPOTrainer] = training.MAPPOTrainer, - executor_fn: Type[core.Executor] = execution.MAPPOFeedForwardExecutor, + executor_fn: Type[ + Union[execution.MAPPOFeedForwardExecutor, execution.MAPPORecurrentExecutor] + ] = execution.MAPPOFeedForwardExecutor, num_executors: int = 1, - num_caches: int = 0, + trainer_networks: Union[ + Dict[str, List], enums.Trainer + ] = enums.Trainer.single_trainer, + network_sampling_setup: Union[ + List, enums.NetworkSampler + ] = enums.NetworkSampler.fixed_agent_networks, environment_spec: mava_specs.MAEnvironmentSpec = None, shared_weights: bool = True, - agent_net_keys: Dict[str, str] = {}, executor_variable_update_period: int = 100, policy_optimizer: Union[ snt.Optimizer, Dict[str, snt.Optimizer] @@ -72,7 +80,6 @@ def __init__( batch_size: int = 32, sequence_length: int = 20, sequence_period: int = 10, - max_executor_steps: int = None, checkpoint: bool = True, checkpoint_subpath: str = "~/mava/", checkpoint_minute_interval: int = 5, @@ -81,6 +88,7 @@ def __init__( eval_loop_fn: Callable = ParallelEnvironmentLoop, train_loop_fn_kwargs: Dict = {}, eval_loop_fn_kwargs: Dict = {}, + termination_condition: Optional[Dict[str, int]] = None, ): """Initialise the system @@ -155,6 +163,12 @@ def __init__( to the training loop. Defaults to {}. eval_loop_fn_kwargs (Dict, optional): possible keyword arguments to send to the evaluation loop. Defaults to {}. + termination_condition: An optional terminal condition can be + provided that stops the program once the condition is + satisfied. 
Available options include specifying maximum + values for trainer_steps, trainer_walltime, evaluator_steps, + evaluator_episodes, executor_episodes or executor_steps. + E.g. termination_condition = {'trainer_steps': 100000}. """ if not environment_spec: @@ -171,22 +185,123 @@ def __init__( time_delta=10, ) + # Setup agent networks and network sampling setup + agents = sort_str_num(environment_spec.get_agent_ids()) + self._network_sampling_setup = network_sampling_setup + + if type(network_sampling_setup) is not list: + if network_sampling_setup == enums.NetworkSampler.fixed_agent_networks: + # if no network_sampling_setup is fixed, use shared_weights to + # determine setup + self._agent_net_keys = { + agent: "network_0" if shared_weights else f"network_{i}" + for i, agent in enumerate(agents) + } + self._network_sampling_setup = [ + [ + self._agent_net_keys[key] + for key in sort_str_num(self._agent_net_keys.keys()) + ] + ] + elif network_sampling_setup == enums.NetworkSampler.random_agent_networks: + """Create N network policies, where N is the number of agents. Randomly + select policies from this sets for each agent at the start of a + episode. This sampling is done with replacement so the same policy + can be selected for more than one agent for a given episode.""" + if shared_weights: + raise ValueError( + "Shared weights cannot be used with random policy per agent" + ) + self._agent_net_keys = { + agents[i]: f"network_{i}" for i in range(len(agents)) + } + self._network_sampling_setup = [ + [ + [self._agent_net_keys[key]] + for key in sort_str_num(self._agent_net_keys.keys()) + ] + ] + else: + raise ValueError( + "network_sampling_setup must be a dict or fixed_agent_networks" + ) + + else: + # if a dictionary is provided, use network_sampling_setup to determine setup + _, self._agent_net_keys = sample_new_agent_keys( + agents, + self._network_sampling_setup, # type: ignore + ) + + # Check that the environment and agent_net_keys has the same amount of agents + sample_length = len(self._network_sampling_setup[0]) # type: ignore + assert len(environment_spec.get_agent_ids()) == len(self._agent_net_keys.keys()) + + # Check if the samples are of the same length and that they perfectly fit + # into the total number of agents + assert len(self._agent_net_keys.keys()) % sample_length == 0 + for i in range(1, len(self._network_sampling_setup)): # type: ignore + assert len(self._network_sampling_setup[i]) == sample_length # type: ignore + + # Get all the unique agent network keys + all_samples = [] + for sample in self._network_sampling_setup: # type: ignore + all_samples.extend(sample) + unique_net_keys = list(sort_str_num(list(set(all_samples)))) + + # Create mapping from ints to networks + net_keys_to_ids = {net_key: i for i, net_key in enumerate(unique_net_keys)} + + # Setup trainer_networks + if type(trainer_networks) is not dict: + if trainer_networks == enums.Trainer.single_trainer: + self._trainer_networks = {"trainer": unique_net_keys} + elif trainer_networks == enums.Trainer.one_trainer_per_network: + self._trainer_networks = { + f"trainer_{i}": [unique_net_keys[i]] + for i in range(len(unique_net_keys)) + } + else: + raise ValueError( + "trainer_networks does not support this enums setting." 
+ ) + else: + self._trainer_networks = trainer_networks # type: ignore + + # Get all the unique trainer network keys + all_trainer_net_keys = [] + for trainer_nets in self._trainer_networks.values(): + all_trainer_net_keys.extend(trainer_nets) + unique_trainer_net_keys = sort_str_num(list(set(all_trainer_net_keys))) + + # Check that all agent_net_keys are in trainer_networks + assert unique_net_keys == unique_trainer_net_keys + # Setup specs for each network + self._net_spec_keys = {} + for i in range(len(unique_net_keys)): + self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] + + # Setup table_network_config + table_network_config = {} + for trainer_key in self._trainer_networks.keys(): + most_matches = 0 + trainer_nets = self._trainer_networks[trainer_key] + for sample in self._network_sampling_setup: # type: ignore + matches = 0 + for entry in sample: + if entry in trainer_nets: + matches += 1 + if most_matches < matches: + matches = most_matches + table_network_config[trainer_key] = sample + + self._table_network_config = table_network_config self._architecture = architecture self._environment_factory = environment_factory self._network_factory = network_factory self._logger_factory = logger_factory self._environment_spec = environment_spec - # Setup agent networks - self._agent_net_keys = agent_net_keys - if not agent_net_keys: - agents = environment_spec.get_agent_ids() - self._agent_net_keys = { - agent: agent.split("_")[0] if shared_weights else agent - for agent in agents - } self._num_exectors = num_executors - self._num_caches = num_caches - self._max_executor_steps = max_executor_steps self._checkpoint_subpath = checkpoint_subpath self._checkpoint = checkpoint self._logger_config = logger_config @@ -194,12 +309,15 @@ def __init__( self._train_loop_fn_kwargs = train_loop_fn_kwargs self._eval_loop_fn = eval_loop_fn self._eval_loop_fn_kwargs = eval_loop_fn_kwargs - self._checkpoint_minute_interval = checkpoint_minute_interval - if issubclass(executor_fn, execution.MAPPORecurrentExecutor): + extra_specs = {} + if issubclass(executor_fn, executors.RecurrentExecutor): extra_specs = self._get_extra_specs() - else: - extra_specs = {} + + int_spec = specs.DiscreteArray(len(unique_net_keys)) + agents = environment_spec.get_agent_ids() + net_spec = {"network_keys": {agent: int_spec for agent in agents}} + extra_specs.update(net_spec) self._builder = builder.MAPPOBuilder( config=builder.MAPPOConfig( @@ -221,6 +339,13 @@ def __init__( critic_optimizer=critic_optimizer, checkpoint_subpath=checkpoint_subpath, checkpoint_minute_interval=checkpoint_minute_interval, + trainer_networks=self._trainer_networks, + table_network_config=table_network_config, + num_executors=num_executors, + network_sampling_setup=self._network_sampling_setup, # type: ignore + net_keys_to_ids=net_keys_to_ids, + unique_net_keys=unique_net_keys, + termination_condition=termination_condition, ), trainer_fn=trainer_fn, executor_fn=executor_fn, @@ -257,130 +382,66 @@ def replay(self) -> Any: return self._builder.make_replay_tables(self._environment_spec) - def counter(self, checkpoint: bool) -> Any: - """Step counter - - Args: - checkpoint (bool): whether to checkpoint the counter. - - Returns: - Any: step counter object. 
- """ - - if checkpoint: - return tf2_savers.CheckpointingRunner( - counting.Counter(), - time_delta_minutes=self._checkpoint_minute_interval, - directory=self._checkpoint_subpath, - subdirectory="counter", - ) - else: - return counting.Counter() - - def coordinator(self, counter: counting.Counter) -> Any: - """Coordination helper for a distributed program - - Args: - counter (counting.Counter): step counter object. - - Returns: - Any: step limiter object. - """ - - return lp_utils.StepsLimiter(counter, self._max_executor_steps) # type: ignore - - def trainer( + def create_system( self, - replay: reverb.Client, - counter: counting.Counter, - ) -> mava.core.Trainer: - """System trainer - - Args: - replay (reverb.Client): replay data table to pull data from. - counter (counting.Counter): step counter object. - - Returns: - mava.core.Trainer: system trainer. - """ + ) -> Tuple[Dict[str, Dict[str, snt.Module]], Dict[str, Dict[str, snt.Module]]]: + """Initialise the system variables from the network factory.""" # Create the networks to optimize (online) networks = self._network_factory( # type: ignore environment_spec=self._environment_spec, agent_net_keys=self._agent_net_keys, + net_spec_keys=self._net_spec_keys, ) - # Create system architecture with target networks. - system_networks = self._architecture( - environment_spec=self._environment_spec, - observation_networks=networks["observations"], - policy_networks=networks["policies"], - critic_networks=networks["critics"], - agent_net_keys=self._agent_net_keys, - ).create_system() - - # create logger - trainer_logger_config = {} - if self._logger_config: - if "trainer" in self._logger_config: - trainer_logger_config = self._logger_config["trainer"] - trainer_logger = self._logger_factory( # type: ignore - "trainer", **trainer_logger_config - ) - can_sample = None # lambda: replay.can_sample(self._builder._config.batch_size) - dataset = self._builder.make_dataset_iterator(replay) - counter = counting.Counter(counter, "trainer") - - return self._builder.make_trainer( - networks=system_networks, - dataset=dataset, - counter=counter, - logger=trainer_logger, - can_sample=can_sample, - ) + # architecture args + architecture_config = { + "environment_spec": self._environment_spec, + "observation_networks": networks["observations"], + "policy_networks": networks["policies"], + "critic_networks": networks["critics"], + "agent_net_keys": self._agent_net_keys, + } + + # net_spec_keys is only implemented for the Decentralised architectures + if self._architecture == DecentralisedValueActorCritic: + architecture_config["net_spec_keys"] = self._net_spec_keys + + # TODO (dries): Can net_spec_keys and network_spec be used as + # the same thing? Can we use use one of those two instead of both. + + system = self._architecture(**architecture_config) + networks = system.create_system() + behaviour_networks = system.create_behaviour_policy() + return behaviour_networks, networks + + def variable_server(self) -> MavaVariableSource: + """Create the variable server.""" + # Create the system + _, networks = self.create_system() + return self._builder.make_variable_server(networks) def executor( self, executor_id: str, replay: reverb.Client, variable_source: acme.VariableSource, - counter: counting.Counter, ) -> mava.ParallelEnvironmentLoop: """System executor - Args: - executor_id (str): id to identify the executor process for logging purposes. - replay (reverb.Client): replay data table to push data to. 
- variable_source (acme.VariableSource): variable server for updating + executor_id: id to identify the executor process for logging purposes. + replay: replay data table to push data to. + variable_source: variable server for updating network variables. - counter (counting.Counter): step counter object. - Returns: mava.ParallelEnvironmentLoop: environment-executor loop instance. """ - # Create the behavior policy. - networks = self._network_factory( # type: ignore - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - ) - - # Create system architecture with target networks. - system = self._architecture( - environment_spec=self._environment_spec, - observation_networks=networks["observations"], - policy_networks=networks["policies"], - critic_networks=networks["critics"], - agent_net_keys=self._agent_net_keys, - ) - - # create variables - _ = system.create_system() - - # behaviour policy networks (obs net + policy head) - behaviour_policy_networks = system.create_behaviour_policy() + # Create the system + behaviour_policy_networks, networks = self.create_system() # Create the executor. executor = self._builder.make_executor( + networks=networks, policy_networks=behaviour_policy_networks, adder=self._builder.make_adder(replay), variable_source=variable_source, @@ -390,14 +451,10 @@ def executor( # Create the environment. environment = self._environment_factory(evaluation=False) # type: ignore - # Create logger and counter; actors will not spam bigtable. - counter = counting.Counter(counter, "executor") - # Create executor logger executor_logger_config = {} - if self._logger_config: - if "executor" in self._logger_config: - executor_logger_config = self._logger_config["executor"] + if self._logger_config and "executor" in self._logger_config: + executor_logger_config = self._logger_config["executor"] exec_logger = self._logger_factory( # type: ignore f"executor_{executor_id}", **executor_logger_config ) @@ -406,7 +463,6 @@ def executor( train_loop = self._train_loop_fn( environment, executor, - counter=counter, logger=exec_logger, **self._train_loop_fn_kwargs, ) @@ -418,45 +474,25 @@ def executor( def evaluator( self, variable_source: acme.VariableSource, - counter: counting.Counter, logger: loggers.Logger = None, ) -> Any: """System evaluator (an executor process not connected to a dataset) - Args: - variable_source (acme.VariableSource): variable server for updating + variable_source: variable server for updating network variables. - counter (counting.Counter): step counter object. - logger (loggers.Logger, optional): logger object. Defaults to None. - + logger: logger object. Returns: - Any: environment-executor evaluation loop instance for evaluating the + environment-executor evaluation loop instance for evaluating the performance of a system. """ - # Create the behavior policy. - networks = self._network_factory( # type: ignore - environment_spec=self._environment_spec, - agent_net_keys=self._agent_net_keys, - ) - - # Create system architecture with target networks. 
- system = self._architecture( - environment_spec=self._environment_spec, - observation_networks=networks["observations"], - policy_networks=networks["policies"], - critic_networks=networks["critics"], - agent_net_keys=self._agent_net_keys, - ) - - # create variables - _ = system.create_system() - - # behaviour policy networks (obs net + policy head) - behaviour_policy_networks = system.create_behaviour_policy() + # Create the system + behaviour_policy_networks, networks = self.create_system() # Create the agent. executor = self._builder.make_executor( + # executor_id="evaluator", + networks=networks, policy_networks=behaviour_policy_networks, variable_source=variable_source, ) @@ -465,11 +501,9 @@ def evaluator( environment = self._environment_factory(evaluation=True) # type: ignore # Create logger and counter. - counter = counting.Counter(counter, "evaluator") evaluator_logger_config = {} - if self._logger_config: - if "evaluator" in self._logger_config: - evaluator_logger_config = self._logger_config["evaluator"] + if self._logger_config and "evaluator" in self._logger_config: + evaluator_logger_config = self._logger_config["evaluator"] eval_logger = self._logger_factory( # type: ignore "evaluator", **evaluator_logger_config ) @@ -479,7 +513,6 @@ def evaluator( eval_loop = self._eval_loop_fn( environment, executor, - counter=counter, logger=eval_logger, **self._eval_loop_fn_kwargs, ) @@ -487,54 +520,74 @@ def evaluator( eval_loop = DetailedPerAgentStatistics(eval_loop) return eval_loop - def build(self, name: str = "mappo") -> Any: - """Build the distributed system as a graph program. - + def trainer( + self, + trainer_id: str, + replay: reverb.Client, + variable_source: MavaVariableSource, + ) -> mava.core.Trainer: + """System trainer Args: - name (str, optional): system name. Defaults to "mappo". + trainer_id: Id of the trainer being created. + replay: replay data table to pull data from. + variable_source: variable server for updating + network variables. + Returns: + system trainer. + """ + + # create logger + trainer_logger_config = {} + if self._logger_config and "trainer" in self._logger_config: + trainer_logger_config = self._logger_config["trainer"] + trainer_logger = self._logger_factory( # type: ignore + trainer_id, **trainer_logger_config + ) + + # Create the system + _, networks = self.create_system() + + dataset = self._builder.make_dataset_iterator(replay, trainer_id) + + return self._builder.make_trainer( + networks=networks, + trainer_networks=self._trainer_networks[trainer_id], + trainer_table_entry=self._table_network_config[trainer_id], + dataset=dataset, + logger=trainer_logger, + variable_source=variable_source, + ) + def build(self, name: str = "maddpg") -> Any: + """Build the distributed system as a graph program. + Args: + name: system name. Returns: - Any: graph program for distributed system training. + graph program for distributed system training. 
""" program = lp.Program(name=name) with program.group("replay"): replay = program.add_node(lp.ReverbNode(self.replay)) - with program.group("counter"): - counter = program.add_node(lp.CourierNode(self.counter, self._checkpoint)) - - if self._max_executor_steps: - with program.group("coordinator"): - _ = program.add_node(lp.CourierNode(self.coordinator, counter)) + with program.group("variable_server"): + variable_server = program.add_node(lp.CourierNode(self.variable_server)) with program.group("trainer"): - trainer = program.add_node(lp.CourierNode(self.trainer, replay, counter)) + # Add executors which pull round-robin from our variable sources. + for trainer_id in self._trainer_networks.keys(): + program.add_node( + lp.CourierNode(self.trainer, trainer_id, replay, variable_server) + ) with program.group("evaluator"): - program.add_node(lp.CourierNode(self.evaluator, trainer, counter)) - - if not self._num_caches: - # Use the trainer as a single variable source. - sources = [trainer] - else: - with program.group("cacher"): - # Create a set of trainer caches. - sources = [] - for _ in range(self._num_caches): - cacher = program.add_node( - lp.CacherNode( - trainer, refresh_interval_ms=2000, stale_after_ms=4000 - ) - ) - sources.append(cacher) + program.add_node(lp.CourierNode(self.evaluator, variable_server)) with program.group("executor"): # Add executors which pull round-robin from our variable sources. for executor_id in range(self._num_exectors): - source = sources[executor_id % len(sources)] program.add_node( - lp.CourierNode(self.executor, executor_id, replay, source, counter) + lp.CourierNode(self.executor, executor_id, replay, variable_server) ) return program diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 09389f0ed..3f93a2948 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -16,7 +16,6 @@ """MAPPO system trainer implementation.""" import copy -import os import time from typing import Any, Dict, List, Optional, Sequence, Union @@ -31,7 +30,7 @@ import mava from mava.adders.reverb.base import Trajectory -from mava.systems.tf import savers as tf2_savers +from mava.systems.tf.variable_utils import VariableClient from mava.utils import training_utils as train_utils from mava.utils.sort_utils import sort_str_num @@ -49,16 +48,15 @@ class MAPPOTrainer(mava.Trainer): def __init__( self, agents: List[Any], - agent_types: List[str], - can_sample: Any, observation_networks: Dict[str, snt.Module], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], dataset: tf.data.Dataset, policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], critic_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + variable_client: VariableClient, + counts: Dict[str, Any], agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, discount: float = 0.999, lambda_gae: float = 1.0, # Question (dries): What is this used for? entropy_cost: float = 0.01, @@ -67,14 +65,11 @@ def __init__( max_gradient_norm: Optional[float] = None, counter: counting.Counter = None, logger: loggers.Logger = None, - checkpoint: bool = False, - checkpoint_subpath: str = "~/mava/", ): """Initialise MAPPO trainer Args: agents (List[str]): agent ids, e.g. "agent_0". - agent_types (List[str]): agent types, e.g. "speaker" or "listener". policy_networks (Dict[str, snt.Module]): policy networks for each agent in the system. 
critic_networks (Dict[str, snt.Module]): critic network(s), shared or for @@ -84,10 +79,10 @@ def __init__( optimizer(s) for updating policy networks. critic_optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]]): optimizer for updating critic networks. + variable_client: The client used to manage the variables. + counts: step counter object. agent_net_keys: (dict, optional): specifies what network each agent uses. Defaults to {}. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. discount (float, optional): discount factor for TD updates. Defaults to 0.99. lambda_gae (float, optional): scalar determining the mix of bootstrapping @@ -104,23 +99,20 @@ def __init__( counter (counting.Counter, optional): step counter object. Defaults to None. logger (loggers.Logger, optional): logger object for logging trainer statistics. Defaults to None. - checkpoint (bool, optional): whether to checkpoint networks. Defaults to - True. - checkpoint_subpath (str, optional): subdirectory for storing checkpoints. - Defaults to "~/mava/". """ # Store agents. self._agents = agents - self._agent_types = agent_types - self._checkpoint = checkpoint - - # Can sample - self._can_sample = can_sample # Store agent_net_keys. self._agent_net_keys = agent_net_keys + # Setup the variable client + self._variable_client = variable_client + + # Setup counts + self._counts = counts + # Store networks. self._observation_networks = observation_networks self._policy_networks = policy_networks @@ -181,28 +173,6 @@ def __init__( self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger("trainer") - # Create checkpointer - self._system_checkpointer = {} - if checkpoint: - for agent_key in self.unique_net_keys: - objects_to_save = { - "counter": self._counter, - "policy": self._policy_networks[agent_key], - "critic": self._critic_networks[agent_key], - "observation": self._observation_networks[agent_key], - "policy_optimizer": self._policy_optimizers, - "critic_optimizer": self._critic_optimizers, - } - - subdir = os.path.join("trainer", agent_key) - checkpointer = tf2_savers.Checkpointer( - time_delta_minutes=checkpoint_minute_interval, - directory=checkpoint_subpath, - objects_to_save=objects_to_save, - subdirectory=subdir, - ) - self._system_checkpointer[agent_key] = checkpointer - # Do not record timestamps until after the first learning step is done. # This is to avoid including the time it takes for actors to come online and # fill the replay buffer. @@ -211,6 +181,7 @@ def __init__( def _get_critic_feed( self, observations_trans: Dict[str, np.ndarray], + extras: Dict[str, np.ndarray], agent: str, ) -> tf.Tensor: """Get critic feed. @@ -218,6 +189,8 @@ def _get_critic_feed( Args: observations_trans (Dict[str, np.ndarray]): transformed (e.g. using observation network) raw agent observation. + extras: Extra information. E.g. the environment state can be included + here. agent (str): agent id. Returns: @@ -272,7 +245,7 @@ def _step( return self.forward_backward(inputs) @tf.function - def forward_backward(self, inputs: Any): + def forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: self._forward_pass(inputs) self._backward_pass() # Log losses per agent @@ -281,7 +254,8 @@ def forward_backward(self, inputs: Any): ) # Forward pass that calculates loss. - # This name is changed from _backward to make sure the trainer wrappers are not overwriting the _step function. 
+ # This name is changed from _backward to make sure the trainer + # wrappers are not overwriting the _step function. def _forward_pass(self, inputs: Any) -> None: """Trainer forward pass @@ -372,7 +346,9 @@ def _forward_pass(self, inputs: Any) -> None: td_loss, td_lambda_extra = trfl.td_lambda( state_values=state_values, rewards=reward, - pcontinues=discount, # Question (dries): Why is self._discount not used. Why is discount not 0.0/1.0 but actually has discount values. + pcontinues=discount, # Question (dries): Why is self._discount + # not used. Why is discount not 0.0/1.0 but actually has + # discount values. bootstrap_value=bootstrap_value, lambda_=self._lambda_gae, name="CriticLoss", @@ -396,8 +372,9 @@ def _forward_pass(self, inputs: Any) -> None: # Generalized Advantage Estimation gae = tf.stop_gradient(td_lambda_extra.temporal_differences) - # Note (dries): Maybe add this in again? But this might be breaking training. - # mean, variance = tf.nn.moments(gae, axes=[0, 1], keepdims=True) + # Note (dries): Maybe add this in again? But this might be breaking + # training. mean, variance = tf.nn.moments(gae, axes=[0, 1], + # keepdims=True) # normalized_gae = (gae - mean) / tf.sqrt(variance) policy_gradient_loss = -tf.minimum( @@ -414,7 +391,8 @@ def _forward_pass(self, inputs: Any) -> None: # Multiply by discounts to not train on padded data. loss_mask = discount > 0.0 - # TODO (dries): Is multiplication maybe better here? As assignment might not work with tf.function? + # TODO (dries): Is multiplication maybe better here? As assignment + # might not work with tf.function? policy_loss = policy_loss[loss_mask] critic_loss = critic_loss[loss_mask] policy_losses[agent] = tf.reduce_mean(policy_loss) @@ -425,7 +403,8 @@ def _forward_pass(self, inputs: Any) -> None: self.tape = tape # Backward pass that calculates gradients and updates network. - # This name is changed from _backward to make sure the trainer wrappers are not overwriting the _step function. + # This name is changed from _backward to make sure the trainer wrappers are not + # overwriting the _step function. def _backward_pass(self) -> None: """Trainer backward pass updating network parameters""" @@ -474,9 +453,17 @@ def step(self) -> None: elapsed_time = timestamp - self._timestamp if self._timestamp else 0 self._timestamp = timestamp + raise ValueError("This step should not be used. Use a trainer wrapper.") + # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - fetches.update(counts) + # TODO (dries): Can this be simplified? Only one set and one get? + self._variable_client.add_async( + ["trainer_steps", "trainer_walltime"], + {"trainer_steps": 1, "trainer_walltime": elapsed_time}, + ) + + # Update the variable source and the trainer + self._variable_client.set_and_get_async() # Checkpoint and attempt to write the logs. 
if self._checkpoint: @@ -512,16 +499,15 @@ class CentralisedMAPPOTrainer(MAPPOTrainer): def __init__( self, agents: List[Any], - agent_types: List[str], - can_sample: Any, observation_networks: Dict[str, snt.Module], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], dataset: tf.data.Dataset, policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], critic_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + variable_client: VariableClient, + counts: Dict[str, Any], agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, discount: float = 0.999, lambda_gae: float = 1.0, entropy_cost: float = 0.01, @@ -530,22 +516,19 @@ def __init__( max_gradient_norm: Optional[float] = None, counter: counting.Counter = None, logger: loggers.Logger = None, - checkpoint: bool = False, - checkpoint_subpath: str = "Checkpoints", ): super().__init__( agents=agents, - agent_types=agent_types, - can_sample=can_sample, policy_networks=policy_networks, critic_networks=critic_networks, observation_networks=observation_networks, dataset=dataset, agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, + variable_client=variable_client, + counts=counts, discount=discount, lambda_gae=lambda_gae, entropy_cost=entropy_cost, @@ -554,8 +537,6 @@ def __init__( max_gradient_norm=max_gradient_norm, counter=counter, logger=logger, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, ) def _get_critic_feed( @@ -579,16 +560,15 @@ class StateBasedMAPPOTrainer(MAPPOTrainer): def __init__( self, agents: List[Any], - agent_types: List[str], - can_sample: Any, observation_networks: Dict[str, snt.Module], policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], dataset: tf.data.Dataset, policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], critic_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + counts: Dict[str, Any], + variable_client: VariableClient, agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, discount: float = 0.999, lambda_gae: float = 1.0, entropy_cost: float = 0.01, @@ -597,22 +577,19 @@ def __init__( max_gradient_norm: Optional[float] = None, counter: counting.Counter = None, logger: loggers.Logger = None, - checkpoint: bool = False, - checkpoint_subpath: str = "Checkpoints", ): super().__init__( agents=agents, - agent_types=agent_types, - can_sample=can_sample, policy_networks=policy_networks, critic_networks=critic_networks, observation_networks=observation_networks, dataset=dataset, agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, policy_optimizer=policy_optimizer, critic_optimizer=critic_optimizer, + variable_client=variable_client, + counts=counts, discount=discount, lambda_gae=lambda_gae, entropy_cost=entropy_cost, @@ -621,8 +598,6 @@ def __init__( max_gradient_norm=max_gradient_norm, counter=counter, logger=logger, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, ) def _get_critic_feed( diff --git a/mava/utils/debugging/scenarios/simple_spread.py b/mava/utils/debugging/scenarios/simple_spread.py index c09992b1c..471c03169 100644 --- a/mava/utils/debugging/scenarios/simple_spread.py +++ b/mava/utils/debugging/scenarios/simple_spread.py @@ -24,7 +24,7 @@ class Scenario(BaseScenario): - def __init__(self, recurrent_test) -> None: + def __init__(self, recurrent_test: bool) -> None: super().__init__() self.np_rnd = 
np.random.RandomState() self.recurrent_test = recurrent_test From f1006c17905cc2de65b76ee8f419e2f942aba457 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 29 Oct 2021 10:39:24 +0200 Subject: [PATCH 013/104] Add multiple trainers PPO example. --- .../decentralised/run_mappo_scale_trainers.py | 112 ++++++++++++++++++ mava/systems/tf/mappo/networks.py | 2 +- 2 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 examples/debugging/simple_spread/feedforward/decentralised/run_mappo_scale_trainers.py diff --git a/examples/debugging/simple_spread/feedforward/decentralised/run_mappo_scale_trainers.py b/examples/debugging/simple_spread/feedforward/decentralised/run_mappo_scale_trainers.py new file mode 100644 index 000000000..11516b680 --- /dev/null +++ b/examples/debugging/simple_spread/feedforward/decentralised/run_mappo_scale_trainers.py @@ -0,0 +1,112 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running feedforward mappo on debug MPE environments. +NB: Using multiple trainers with non-shared weights is still in its + experimental phase of development. This feature will become faster and + more stable in future Mava updates.""" + +import functools +from datetime import datetime +from typing import Any + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.systems.tf import mappo +from mava.systems.tf.mappo import make_default_networks +from mava.utils import enums, lp_utils +from mava.utils.environments import debugging_utils +from mava.utils.loggers import logger_utils + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "env_name", + "simple_spread", + "Debugging environment name (str).", +) +flags.DEFINE_string( + "action_space", + "discrete", + "Environment action space type (str).", +) +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava/", "Base dir to store experiments.") + + +def main(_: Any) -> None: + + # environment + environment_factory = functools.partial( + debugging_utils.make_environment, + env_name=FLAGS.env_name, + action_space=FLAGS.action_space, + ) + + # networks + network_factory = lp_utils.partial_kwargs(make_default_networks) + + # Checkpointer appends "Checkpoints" to checkpoint_dir + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. + log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # distributed program + """NB: Using multiple trainers with non-shared weights is still in its + experimental phase of development. 
This feature will become faster and + more stable in future Mava updates.""" + program = mappo.MAPPO( + environment_factory=environment_factory, + network_factory=network_factory, + logger_factory=logger_factory, + num_executors=2, + shared_weights=False, + trainer_networks=enums.Trainer.one_trainer_per_network, + network_sampling_setup=enums.NetworkSampler.fixed_agent_networks, + policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + checkpoint_subpath=checkpoint_dir, + max_gradient_norm=40.0, + ).build() + + # Ensure only trainer runs on gpu, while other processes run on cpu. + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + + lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index 95c99fad0..a76371c9a 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -146,7 +146,7 @@ def make_default_networks( networks.MultivariateNormalDiagHead( num_dimensions=num_actions, w_init=tf.initializers.VarianceScaling(1e-4, seed=seed), - b_init=tf.initializers.Zeros(seed=seed), + b_init=tf.initializers.Zeros(), ), networks.TanhToSpec(specs[key].actions), ] From c33e32c4ba56f2a0bc5d1545b8642e9091ab0d9b Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 3 Dec 2021 13:07:05 +0200 Subject: [PATCH 014/104] Fix PPO example. --- docs/images/focus_fire.html | 2 +- docs/images/runaway.html | 2 +- mava/components/tf/architectures/centralised.py | 8 +++++--- mava/components/tf/architectures/decentralised.py | 3 +-- mava/systems/tf/mappo/execution.py | 7 +++++++ mava/systems/tf/mappo/training.py | 2 ++ mava/wrappers/meltingpot.py | 1 - 7 files changed, 17 insertions(+), 8 deletions(-) diff --git a/docs/images/focus_fire.html b/docs/images/focus_fire.html index b84a48bac..028d72abe 100644 --- a/docs/images/focus_fire.html +++ b/docs/images/focus_fire.html @@ -428,4 +428,4 @@ MTAw "> Your browser does not support the video tag. - \ No newline at end of file + diff --git a/docs/images/runaway.html b/docs/images/runaway.html index f52c7de2a..d92bb9189 100644 --- a/docs/images/runaway.html +++ b/docs/images/runaway.html @@ -1272,4 +1272,4 @@ dAAAACWpdG9vAAAAHWRhdGEAAAABAAAAAExhdmY1OC4yOS4xMDA= "> Your browser does not support the video tag. 
- \ No newline at end of file + diff --git a/mava/components/tf/architectures/centralised.py b/mava/components/tf/architectures/centralised.py index cb0ad817b..5f18729de 100644 --- a/mava/components/tf/architectures/centralised.py +++ b/mava/components/tf/architectures/centralised.py @@ -99,9 +99,10 @@ def _get_critic_specs( for agent_type, agents in agents_by_type.items(): agent_key = agents[0] - critic_obs_shape = list(copy.copy(self._embed_specs[agent_key].shape)) + net_key = self._agent_net_keys[agent_key] + critic_obs_shape = list(copy.copy(self._embed_specs[net_key].shape)) critic_obs_shape.insert(0, len(agents)) - obs_specs_per_type[agent_type] = tf.TensorSpec( + obs_specs_per_type[net_key] = tf.TensorSpec( shape=critic_obs_shape, dtype=tf.dtypes.float32, ) @@ -143,11 +144,12 @@ def _get_critic_specs( for agent_type, agents in agents_by_type.items(): agent_key = agents[0] + net_key = self._agent_net_keys[agent_key] # TODO (dries): Add a check to see if all # self._embed_specs[agent_key].shape are of the same shape - critic_obs_shape = list(copy.copy(self._embed_specs[agent_key].shape)) + critic_obs_shape = list(copy.copy(self._embed_specs[net_key].shape)) critic_obs_shape.insert(0, len(agents)) obs_specs_per_type[agent_type] = tf.TensorSpec( shape=critic_obs_shape, diff --git a/mava/components/tf/architectures/decentralised.py b/mava/components/tf/architectures/decentralised.py index 04bbfa019..26876785b 100644 --- a/mava/components/tf/architectures/decentralised.py +++ b/mava/components/tf/architectures/decentralised.py @@ -267,7 +267,7 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]: emb_spec = tf2_utils.create_variables( self._observation_networks[net_key], [obs_spec] ) - self._embed_specs[agent_key] = emb_spec + self._embed_specs[net_key] = emb_spec # Create variables. tf2_utils.create_variables(self._policy_networks[net_key], [emb_spec]) @@ -279,7 +279,6 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]: tf2_utils.create_variables( self._target_observation_networks[net_key], [obs_spec] ) - actor_networks: Dict[str, Dict[str, snt.Module]] = { "policies": self._policy_networks, "observations": self._observation_networks, diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index caab08e70..25fc9753c 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -248,6 +248,8 @@ def __init__( adder: Optional[adders.ParallelAdder] = None, counts: Optional[Dict[str, Any]] = None, variable_client: Optional[tf2_variable_utils.VariableClient] = None, + evaluator: bool = False, + interval: Optional[dict] = None, ): """Initialise the system executor Args: @@ -259,12 +261,17 @@ def __init__( client to copy weights from the trainer. Defaults to None. agent_net_keys: (dict, optional): specifies what network each agent uses. Defaults to {}. + evaluator (bool, optional): whether the executor will be used for + evaluation. Defaults to False. + interval: interval that evaluations are run at. 
""" self._agent_specs = agent_specs self._network_sampling_setup = network_sampling_setup self._counts = counts self._net_keys_to_ids = net_keys_to_ids self._network_int_keys_extras: Dict[str, np.array] = {} + self._evaluator = evaluator + self._interval = interval super().__init__( policy_networks=policy_networks, agent_net_keys=agent_net_keys, diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index b6386abb6..0e92fbe20 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -620,6 +620,7 @@ def __init__( max_gradient_norm: Optional[float] = None, counter: counting.Counter = None, logger: loggers.Logger = None, + learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): super().__init__( @@ -641,6 +642,7 @@ def __init__( max_gradient_norm=max_gradient_norm, counter=counter, logger=logger, + learning_rate_scheduler_fn=learning_rate_scheduler_fn, ) def _get_critic_feed( diff --git a/mava/wrappers/meltingpot.py b/mava/wrappers/meltingpot.py index 1f7d9c9ae..17338969f 100644 --- a/mava/wrappers/meltingpot.py +++ b/mava/wrappers/meltingpot.py @@ -27,7 +27,6 @@ try: import pygame # type: ignore - from meltingpot.python.scenario import Scenario # type: ignore from meltingpot.python.substrate import Substrate # type: ignore except ModuleNotFoundError: From ef7744142994a23e0d6cc0dd91beb94cee9a1ea0 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 3 Dec 2021 13:33:10 +0200 Subject: [PATCH 015/104] Fix embed_spec bug. --- mava/components/tf/architectures/decentralised.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mava/components/tf/architectures/decentralised.py b/mava/components/tf/architectures/decentralised.py index 26876785b..e259716d1 100644 --- a/mava/components/tf/architectures/decentralised.py +++ b/mava/components/tf/architectures/decentralised.py @@ -157,7 +157,7 @@ def create_actor_variables(self) -> Dict[str, Dict[str, snt.Module]]: emb_spec = tf2_utils.create_variables( self._observation_networks[agent_net_key], [obs_spec] ) - self._embed_specs[agent_key] = emb_spec + self._embed_specs[agent_net_key] = emb_spec # Create variables. tf2_utils.create_variables(self._policy_networks[agent_net_key], [emb_spec]) @@ -294,10 +294,8 @@ def create_critic_variables(self) -> Dict[str, Dict[str, snt.Module]]: # create critics for net_key in self._net_keys: - agent_key = self._net_spec_keys[net_key] - # get specs - emb_spec = embed_specs[agent_key] + emb_spec = embed_specs[net_key] # Create variables. tf2_utils.create_variables(self._critic_networks[net_key], [emb_spec]) @@ -371,7 +369,7 @@ def create_critic_variables(self) -> Dict[str, Dict[str, snt.Module]]: agent_key = self._net_spec_keys[net_key] # get specs - emb_spec = embed_specs[agent_key] + emb_spec = embed_specs[net_key] act_spec = act_specs[agent_key] # Create variables. From fb1c43c618780f3a1d9e461ce2c2e88290038832 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 3 Dec 2021 14:15:24 +0200 Subject: [PATCH 016/104] Fix mypy issues. 
--- mava/systems/tf/mappo/execution.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index 25fc9753c..06f031679 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple import dm_env -import numpy as np import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp @@ -69,7 +68,7 @@ def __init__( self._agent_specs = agent_specs self._network_sampling_setup = network_sampling_setup self._counts = counts - self._network_int_keys_extras: Dict[str, np.array] = {} + self._network_int_keys_extras: Dict[str, Any] = {} self._net_keys_to_ids = net_keys_to_ids super().__init__( policy_networks=policy_networks, @@ -269,7 +268,7 @@ def __init__( self._network_sampling_setup = network_sampling_setup self._counts = counts self._net_keys_to_ids = net_keys_to_ids - self._network_int_keys_extras: Dict[str, np.array] = {} + self._network_int_keys_extras: Dict[str, Any] = {} self._evaluator = evaluator self._interval = interval super().__init__( From 58b68f90c7ade006b75cded768005f7d5c042b17 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 3 Dec 2021 14:30:29 +0200 Subject: [PATCH 017/104] Fix mypy issue. --- mava/systems/tf/maddpg/training.py | 2 +- mava/systems/tf/mappo/training.py | 2 +- tests/conftest.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 0a62c728b..b64fe1c66 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -1698,7 +1698,7 @@ def _get_critic_feed( # State based obs_trans_feed = extras["env_states"] target_obs_trans_feed = extras["env_states"] - if type(obs_trans_feed) == dict: + if type(obs_trans_feed) == dict: # type: ignore obs_trans_feed = obs_trans_feed[agent] target_obs_trans_feed = target_obs_trans_feed[agent] diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 0e92fbe20..ab5649829 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -652,7 +652,7 @@ def _get_critic_feed( agent: str, ) -> tf.Tensor: # State based - if type(extras["env_states"]) == dict: + if type(extras["env_states"]) == dict: # type: ignore return extras["env_states"][agent] else: return extras["env_states"] diff --git a/tests/conftest.py b/tests/conftest.py index 0ff3e4845..02aa0a87c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,7 @@ from mava.wrappers.flatland import FlatlandEnvWrapper _has_flatland = True -except ModuleNotFoundError: +except (ModuleNotFoundError, ImportError): _has_flatland = False try: from pettingzoo.utils.env import AECEnv, ParallelEnv From e4a343e4435eb57dc2656a80837ddd71d3d56116 Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 14 Dec 2021 13:49:46 +0200 Subject: [PATCH 018/104] Address some of the PR comments. 
--- mava/systems/tf/mappo/builder.py | 5 ----- mava/systems/tf/mappo/execution.py | 3 --- 2 files changed, 8 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index c75bc660c..be335d649 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -245,11 +245,6 @@ def make_dataset_iterator( """ # Create tensorflow dataset to interface with reverb - # dataset = datasets.make_reverb_dataset( - # server_address=replay_client.server_address, - # batch_size=self._config.batch_size, - # ) - dataset = reverb.TrajectoryDataset.from_table_signature( server_address=replay_client.server_address, table=table_name, diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index 06f031679..e11f68896 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -117,11 +117,9 @@ def _policy( # Sample from the policy and compute the log likelihood. action = tfd.Categorical(logits).sample() - # action = tf2_utils.to_numpy_squeeze(action) # Cast for compatibility with reverb. # sample() returns a 'int32', which is a problem. - # if isinstance(policy, tfp.distributions.Categorical): action = tf.cast(action, "int64") return action, logits @@ -317,7 +315,6 @@ def _policy( # Cast for compatibility with reverb. # sample() returns a 'int32', which is a problem. - # if isinstance(policy, tfp.distributions.Categorical): action = tf.cast(action, "int64") return action, logits, new_state From d7f8ba553182cee9e8d92179cde90ce5938c070e Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 14 Dec 2021 14:01:19 +0200 Subject: [PATCH 019/104] Add termination condition to MA-PPO. --- mava/systems/tf/mappo/builder.py | 7 +++++++ mava/systems/tf/mappo/system.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index be335d649..13f02666c 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -67,6 +67,11 @@ class MAPPOConfig: norm during optimization. checkpoint: boolean to indicate whether to checkpoint models. checkpoint_subpath: subdirectory specifying where to store checkpoints. + termination_condition: An optional terminal condition can be provided + that stops the program once the condition is satisfied. Available options + include specifying maximum values for trainer_steps, trainer_walltime, + evaluator_steps, evaluator_episodes, executor_episodes or executor_steps. + E.g. termination_condition = {'trainer_steps': 100000}. replay_table_name: string indicating what name to give the replay table. evaluator_interval: intervals that evaluator are run at. 
learning_rate_scheduler_fn: function/class that takes in a trainer step t @@ -98,6 +103,7 @@ class MAPPOConfig: checkpoint: bool = True checkpoint_subpath: str = "~/mava/" replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE + termination_condition: Optional[Dict[str, int]] = None learning_rate_scheduler_fn: Optional[Any] = None evaluator_interval: Optional[dict] = None @@ -337,6 +343,7 @@ def make_variable_server( self._config.checkpoint, self._config.checkpoint_subpath, self._config.checkpoint_minute_interval, + self._config.termination_condition, ) return variable_source diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 230245acc..262ffd383 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -88,6 +88,7 @@ def __init__( # noqa eval_loop_fn: Callable = ParallelEnvironmentLoop, train_loop_fn_kwargs: Dict = {}, eval_loop_fn_kwargs: Dict = {}, + termination_condition: Optional[Dict[str, int]] = None, evaluator_interval: Optional[dict] = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): @@ -164,6 +165,12 @@ def __init__( # noqa to the training loop. Defaults to {}. eval_loop_fn_kwargs (Dict, optional): possible keyword arguments to send to the evaluation loop. Defaults to {}. + termination_condition: An optional terminal condition can be + provided that stops the program once the condition is + satisfied. Available options include specifying maximum + values for trainer_steps, trainer_walltime, evaluator_steps, + evaluator_episodes, executor_episodes or executor_steps. + E.g. termination_condition = {'trainer_steps': 100000}. learning_rate_scheduler_fn: dict with two functions/classes (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate, @@ -354,6 +361,7 @@ def __init__( # noqa network_sampling_setup=self._network_sampling_setup, # type: ignore net_keys_to_ids=net_keys_to_ids, unique_net_keys=unique_net_keys, + termination_condition=termination_condition, evaluator_interval=evaluator_interval, learning_rate_scheduler_fn=learning_rate_scheduler_fn, ), From aef20d27bf6d32412ae6155b080bc797640b43bf Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 12 Jan 2022 11:45:54 +0200 Subject: [PATCH 020/104] Address PR comments. 
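Note on [PATCH 019/104] above: termination_condition and learning_rate_scheduler_fn are plain Python values supplied by the caller. A minimal sketch of what they might look like, based only on the docstrings in that patch; the dict key names for the scheduler and the schedule shapes are assumptions, not part of the patch:

# Stop the program once the trainer has taken 100k steps.
termination_condition = {"trainer_steps": 100_000}

# One schedule for the policy optimizer and one for the critic optimizer, each
# mapping the trainer step t to the current learning rate. The "policy"/"critic"
# key names and the decay shapes are illustrative assumptions.
def policy_lr_schedule(t: int) -> float:
    return 5e-4 * max(0.0, 1.0 - t / 100_000)  # linear decay from the default 5e-4

def critic_lr_schedule(t: int) -> float:
    return 1e-5  # keep the default critic learning rate constant

learning_rate_scheduler_fn = {
    "policy": policy_lr_schedule,
    "critic": critic_lr_schedule,
}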
--- mava/systems/tf/mappo/builder.py | 2 +- mava/systems/tf/mappo/system.py | 4 ++-- mava/systems/tf/mappo/training.py | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 13f02666c..7ca0556a2 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -254,7 +254,7 @@ def make_dataset_iterator( dataset = reverb.TrajectoryDataset.from_table_signature( server_address=replay_client.server_address, table=table_name, - max_in_flight_samples_per_worker=1, + max_in_flight_samples_per_worker=2 * self._config.batch_size, ).batch( self._config.batch_size, drop_remainder=True ) # .prefetch(self._config.prefetch_size) diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 262ffd383..33a9ec190 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -71,7 +71,7 @@ def __init__( # noqa ] = snt.optimizers.Adam(learning_rate=5e-4), critic_optimizer: snt.Optimizer = snt.optimizers.Adam(learning_rate=1e-5), discount: float = 0.999, - lambda_gae: float = 1.0, + lambda_gae: float = 0.95, clipping_epsilon: float = 0.2, entropy_cost: float = 0.01, baseline_cost: float = 0.5, @@ -79,7 +79,7 @@ def __init__( # noqa max_queue_size: int = 1000, batch_size: int = 32, sequence_length: int = 20, - sequence_period: int = 10, + sequence_period: int = 19, checkpoint: bool = True, checkpoint_subpath: str = "~/mava/", checkpoint_minute_interval: int = 5, diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index ab5649829..b7d7856ab 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -59,7 +59,8 @@ def __init__( counts: Dict[str, Any], agent_net_keys: Dict[str, str], discount: float = 0.999, - lambda_gae: float = 1.0, # Question (dries): What is this used for? + lambda_gae: float = 0.95, # Note (dries): This might still be wrong + # e.g. should always be 1.0. entropy_cost: float = 0.01, baseline_cost: float = 0.5, clipping_epsilon: float = 0.2, @@ -550,7 +551,7 @@ def __init__( counts: Dict[str, Any], agent_net_keys: Dict[str, str], discount: float = 0.999, - lambda_gae: float = 1.0, + lambda_gae: float = 0.95, entropy_cost: float = 0.01, baseline_cost: float = 0.5, clipping_epsilon: float = 0.2, @@ -613,7 +614,7 @@ def __init__( variable_client: VariableClient, agent_net_keys: Dict[str, str], discount: float = 0.999, - lambda_gae: float = 1.0, + lambda_gae: float = 0.95, entropy_cost: float = 0.01, baseline_cost: float = 0.5, clipping_epsilon: float = 0.2, From 92abb30dd02f55d46a5b74ec6267b60bb5c8f425 Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 13 Jan 2022 09:45:27 +0200 Subject: [PATCH 021/104] Add the capability for MAPPO to use continuous action spaces. 
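Note on [PATCH 020/104] above: lambda_gae now defaults to 0.95 and the adder's sequence_period to sequence_length - 1, presumably so that consecutive sequences overlap by the single step used for bootstrapping. The sketch below is a generic generalised-advantage-estimation recursion, not the trfl.td_lambda call used by the trainer, included only to illustrate what the lambda parameter trades off; all numbers are made up:

import numpy as np

def gae(rewards, values, bootstrap_value, discount, lam):
    # Generic GAE recursion: lam = 1.0 approaches Monte-Carlo style returns,
    # lam = 0.0 reduces to a one-step TD error.
    values = np.append(values, bootstrap_value)
    advantages = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + discount * values[t + 1] - values[t]
        running = delta + discount * lam * running
        advantages[t] = running
    return advantages

adv = gae(rewards=np.array([1.0, 0.0, 1.0]),
          values=np.array([0.5, 0.4, 0.6]),
          bootstrap_value=0.3, discount=0.999, lam=0.95)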
--- .../feedforward/decentralised/run_mappo.py | 105 ++++++++++++++++++ mava/systems/tf/mappo/builder.py | 4 +- mava/systems/tf/mappo/execution.py | 38 ++++--- mava/systems/tf/mappo/training.py | 39 +++++-- 4 files changed, 157 insertions(+), 29 deletions(-) create mode 100644 examples/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py diff --git a/examples/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py b/examples/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py new file mode 100644 index 000000000..b9ba80ab4 --- /dev/null +++ b/examples/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py @@ -0,0 +1,105 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running continous MAPPO on pettinzoo SISL environments.""" + +import functools +from datetime import datetime +from typing import Any + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.systems.tf import mappo +from mava.utils import lp_utils +from mava.utils.environments import pettingzoo_utils +from mava.utils.loggers import logger_utils + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "env_class", + "sisl", + "Pettingzoo environment class, e.g. atari (str).", +) + +flags.DEFINE_string( + "env_name", + "multiwalker_v7", + "Pettingzoo environment name, e.g. pong (str).", +) +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") + + +def main(_: Any) -> None: + # Environment. + environment_factory = functools.partial( + pettingzoo_utils.make_environment, + env_class=FLAGS.env_class, + env_name=FLAGS.env_name, + remove_on_fall=False, + ) + + # Networks. + network_factory = lp_utils.partial_kwargs(mappo.make_default_networks) + + # Checkpointer appends "Checkpoints" to checkpoint_dir. + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. + log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # Distributed program. + program = mappo.MAPPO( + environment_factory=environment_factory, + network_factory=network_factory, + logger_factory=logger_factory, + num_executors=1, + policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + checkpoint_subpath=checkpoint_dir, + max_gradient_norm=40.0, + ).build() + + # Ensure only trainer runs on gpu, while other processes run on cpu. + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + + # Launch. 
+ lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 7ca0556a2..84311e7d2 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -155,9 +155,7 @@ def add_logits_to_spec( new_act_spec = {"actions": agent_spec.actions} # Make dummy logits - new_act_spec["logits"] = tf.ones( - shape=(agent_spec.actions.num_values), dtype=tf.float32 - ) + new_act_spec["log_probs"] = tf.ones(shape=(), dtype=tf.float32) env_adder_spec._specs[key] = EnvironmentSpec( observations=agent_spec.observations, diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index e11f68896..8f84d05f5 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -113,16 +113,22 @@ def _policy( agent_key = self._agent_net_keys[agent] # Compute the policy, conditioned on the observation. - logits = self._policy_networks[agent_key](batched_observation) - - # Sample from the policy and compute the log likelihood. - action = tfd.Categorical(logits).sample() - - # Cast for compatibility with reverb. - # sample() returns a 'int32', which is a problem. - action = tf.cast(action, "int64") - - return action, logits + policy = self._policy_networks[agent_key](batched_observation) + + if type(policy) == tfd.transformed_distribution.TransformedDistribution: + # Sample from the policy and compute the log likelihood. + action = policy.sample() + log_prob = policy.log_prob(action) + else: + # Sample from the policy and compute the log likelihood. + policy = tfd.Categorical(policy) + action = policy.sample() + log_prob = policy.log_prob(action) + + # Cast for compatibility with reverb. + # sample() returns a 'int32', which is a problem. + action = tf.cast(action, "int64") + return action, log_prob def select_action( self, agent: str, observation: types.NestedArray @@ -138,15 +144,15 @@ def select_action( types.NestedArray: action and policy. """ # Step the recurrent policy forward given the current observation and state. - action, logits = self._policy( + action, log_prob = self._policy( agent, observation.observation, ) # Return a numpy array with squeezed out batch dimension. action = tf2_utils.to_numpy_squeeze(action) - logits = tf2_utils.to_numpy_squeeze(logits) - return action, logits + log_prob = tf2_utils.to_numpy_squeeze(log_prob) + return action, log_prob def select_actions( self, observations: Dict[str, types.NestedArray] @@ -164,10 +170,10 @@ def select_actions( # TODO (dries): Add this to a function and add tf.function here. 
actions = {} - logits = {} + log_prob = {} for agent, observation in observations.items(): - actions[agent], logits[agent] = self.select_action(agent, observation) - return actions, logits + actions[agent], log_prob[agent] = self.select_action(agent, observation) + return actions, log_prob def observe_first( self, diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index b7d7856ab..91055342c 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -300,11 +300,11 @@ def _forward_pass(self, inputs: Any) -> None: with tf.GradientTape(persistent=True) as tape: for agent in self._agents: - action, reward, discount, behaviour_logits, actor_observation = ( + action, reward, discount, behaviour_log_prob, actor_observation = ( actions[agent]["actions"], rewards[agent], discounts[agent], - actions[agent]["logits"], + actions[agent]["log_probs"], observations_trans[agent], ) @@ -313,7 +313,7 @@ def _forward_pass(self, inputs: Any) -> None: ) # Chop off final timestep for bootstrapping value - action = action[:-1] + behaviour_log_prob = behaviour_log_prob[:-1] reward = reward[:-1] discount = discount[:-1] @@ -336,13 +336,28 @@ def _forward_pass(self, inputs: Any) -> None: else: # Pass observations through the feedforward policy actor_observation, dims = train_utils.combine_dim(actor_observation) - logits = train_utils.extract_dim( - policy_network(actor_observation), dims - ) + + logits = policy_network(actor_observation) + + if type(logits) != tfp.distributions.TransformedDistribution: + logits = train_utils.extract_dim(logits, dims) + else: + logits = tfp.distributions.BatchReshape(logits, dims) + # logits = logits.reshape(dims) # Compute importance sampling weights: current policy / behavior policy. - pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1]) - pi_target = tfd.Categorical(logits=logits[:-1]) + if type(logits) in [ + tfd.transformed_distribution.TransformedDistribution, + tfp.distributions.BatchReshape, + ]: + # Sample from the policy and compute the log likelihood. + pi_target = logits + pi_target_log_prob = pi_target.log_prob(action)[:-1] + else: + # Sample from the policy and compute the log likelihood. + pi_target = tfd.Categorical(logits=logits[:-1]) + action = action[:-1] + pi_target_log_prob = pi_target.log_prob(action) # Calculate critic values. dims = critic_observation.shape[:2] @@ -374,7 +389,7 @@ def _forward_pass(self, inputs: Any) -> None: ) # Compute importance sampling weights: current policy / behavior policy. - log_rhos = pi_target.log_prob(action) - pi_behaviour.log_prob(action) + log_rhos = pi_target_log_prob - behaviour_log_prob importance_ratio = tf.exp(log_rhos) clipped_importance_ratio = tf.clip_by_value( importance_ratio, @@ -397,7 +412,11 @@ def _forward_pass(self, inputs: Any) -> None: ) # Entropy regulariser. - entropy_loss = trfl.policy_entropy_loss(pi_target).loss + # Entropy regularization. Only implemented for categorical dist. + try: + entropy_loss = trfl.policy_entropy_loss(pi_target).loss + except NotImplementedError: + entropy_loss = tf.convert_to_tensor(0.0) # Combine weighted sum of actor & entropy regularization. policy_loss = policy_gradient_loss + entropy_loss From 276933ca1907c5dd3a6c389301b1c512563351bc Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 9 Mar 2022 14:21:23 +0200 Subject: [PATCH 022/104] fix: mypy. 
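Note on [PATCH 021/104] above: the executor now branches on whether the policy head returns raw logits (discrete actions) or a tfp TransformedDistribution (continuous actions, e.g. the diagonal Gaussian squashed by MultivariateNormalDiagHead + TanhToSpec in networks.py), and in both cases only the log-probability of the sampled action is written to replay as the "log_probs" extra. A standalone sketch of the two cases; the shapes and the exact bijector construction are illustrative assumptions:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Discrete head: the network outputs logits and the executor wraps them.
logits = tf.constant([[0.1, 0.5, -0.2]])
discrete_policy = tfd.Categorical(logits=logits)
action = discrete_policy.sample()
log_prob = discrete_policy.log_prob(action)  # stored as the "log_probs" extra

# Continuous head: the network already returns a distribution, e.g. a diagonal
# Gaussian squashed through a Tanh bijector (construction is illustrative).
base = tfd.MultivariateNormalDiag(loc=tf.zeros([1, 2]), scale_diag=tf.ones([1, 2]))
continuous_policy = tfd.TransformedDistribution(base, tfp.bijectors.Tanh())
action = continuous_policy.sample()
log_prob = continuous_policy.log_prob(action)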
--- examples/README.md | 16 +- .../feedforward/decentralised/run_mappo.py | 3 +- mava/adders/reverb/base.py | 7 +- .../components/tf/modules/mixing/monotonic.py | 92 ------ mava/systems/tf/mappo/execution.py | 8 +- mava/systems/tf/qmix/training.py | 305 ------------------ 6 files changed, 17 insertions(+), 414 deletions(-) rename examples/tf/{smac => }/sisl/multiwalker/feedforward/decentralised/run_mappo.py (95%) delete mode 100644 mava/components/tf/modules/mixing/monotonic.py delete mode 100644 mava/systems/tf/qmix/training.py diff --git a/examples/README.md b/examples/README.md index d528f4e75..fc748e6a9 100644 --- a/examples/README.md +++ b/examples/README.md @@ -56,7 +56,7 @@ We include a number of systems running on continuous control tasks. - **MAD4PG**: a MAD4PG system running on the RoboCup environment. - - *Recurrent* + - *Recurrent* - [state_based][robocup_mad4pg_ff_state_based]. ## Discrete control @@ -83,26 +83,26 @@ We also include a number of systems running on discrete action space environment - **VDN**: a VDN system running on the discrete action space simple_spread MPE environment. - - *Recurrent* + - *Recurrent* - [centralised][debug_vdn_rec_cen]. ### PettingZoo - Multi-Agent Atari - **MADQN**: a MADQN system running on the two-player competitive Atari Pong environment. - - *Recurrent* + - *Recurrent* - [decentralised][pz_madqn_pong_rec_dec]. ### PettingZoo - Multi-Agent Particle Environment - **MADDPG**: a MADDPG system running on the Simple Speaker Listener environment. - - *Feedforward* + - *Feedforward* - [decentralised][pz_maddpg_mpe_ssl_ff_dec]. - **MADDPG**: a MADDPG system running on the Simple Spread environment. - - *Feedforward* + - *Feedforward* - [decentralised][pz_maddpg_mpe_ss_ff_dec]. ### SMAC - StarCraft Multi-Agent Challenge @@ -116,19 +116,19 @@ We also include a number of systems running on discrete action space environment - **QMIX**: a QMIX system running on the SMAC environment. - - *Recurrent* + - *Recurrent* - [centralised][smac_qmix_rec_cen]. - **VDN**: a VDN system running on the SMAC environment. - - *Recurrent* + - *Recurrent* - [centralised][smac_vdn_rec_cen]. ### OpenSpiel - Tic Tac Toe - **MADQN**: a MADQN system running on the OpenSpiel environment. - - *Feedforward* + - *Feedforward* - [decentralised][openspiel_madqn_ff_dec]. 
diff --git a/examples/tf/smac/sisl/multiwalker/feedforward/decentralised/run_mappo.py b/examples/tf/sisl/multiwalker/feedforward/decentralised/run_mappo.py similarity index 95% rename from examples/tf/smac/sisl/multiwalker/feedforward/decentralised/run_mappo.py rename to examples/tf/sisl/multiwalker/feedforward/decentralised/run_mappo.py index b9ba80ab4..03acbced5 100644 --- a/examples/tf/smac/sisl/multiwalker/feedforward/decentralised/run_mappo.py +++ b/examples/tf/sisl/multiwalker/feedforward/decentralised/run_mappo.py @@ -81,8 +81,7 @@ def main(_: Any) -> None: network_factory=network_factory, logger_factory=logger_factory, num_executors=1, - policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4), - critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, max_gradient_norm=40.0, ).build() diff --git a/mava/adders/reverb/base.py b/mava/adders/reverb/base.py index 3c37bd12c..27ff6b2a8 100644 --- a/mava/adders/reverb/base.py +++ b/mava/adders/reverb/base.py @@ -16,7 +16,6 @@ """Adders that use Reverb (github.com/deepmind/reverb) as a backend.""" import copy -import time from typing import ( Any, Callable, @@ -39,8 +38,6 @@ from acme import types from acme.adders.reverb.base import ReverbAdder -_MIN_WRITER_LIFESPAN_SECONDS = 60 - from mava import types as mava_types from mava.adders.base import ParallelAdder from mava.utils.sort_utils import sort_str_num @@ -392,7 +389,7 @@ def add_first( ) self._add_first_called = True - def reset_here(self, timeout_ms: Optional[int] = None): + def reset_writer(self, timeout_ms: Optional[int] = None) -> None: """Resets the adder's buffer.""" if self._writer: # Flush all appended data and clear the buffers. @@ -444,4 +441,4 @@ def add( dummy_step = tree.map_structure(np.zeros_like, current_step) self._writer.append(dummy_step) self._write_last() - self.reset_here() + self.reset_writer() diff --git a/mava/components/tf/modules/mixing/monotonic.py b/mava/components/tf/modules/mixing/monotonic.py deleted file mode 100644 index 2a8ed1305..000000000 --- a/mava/components/tf/modules/mixing/monotonic.py +++ /dev/null @@ -1,92 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mixing for multi-agent RL systems""" -from typing import Dict, Union - -import sonnet as snt -import tensorflow as tf -from acme.tf import utils as tf2_utils - -from mava import specs as mava_specs -from mava.components.tf.architectures.decentralised import ( - DecentralisedValueActor, - DecentralisedValueActorCritic, -) -from mava.components.tf.modules.mixing import BaseMixingModule -from mava.components.tf.networks.monotonic import MonotonicMixingNetwork - - -class MonotonicMixing(BaseMixingModule): - """Multi-agent monotonic mixing architecture. - This is the component which can be used to add monotonic mixing to an underlying - agent architecture. 
It currently supports generalised monotonic mixing using - hypernetworks (1 or 2 layers) for control of decomposition parameters (QMix).""" - - def __init__( - self, - environment_spec: mava_specs.MAEnvironmentSpec, - architecture: Union[DecentralisedValueActor, DecentralisedValueActorCritic], - qmix_hidden_dim: int = 32, - num_hypernet_layers: int = 2, - hypernet_hidden_dim: int = 64, # Defaults to qmix_hidden_dim - ) -> None: - """Initializes the mixer. - Args: - architecture: the BaseArchitecture used. - """ - super(MonotonicMixing, self).__init__() - - assert hasattr( - architecture, "_n_agents" - ), "Architecture doesn't have _n_agents." - self._environment_spec = environment_spec - self._qmix_hidden_dim = qmix_hidden_dim - self._num_hypernet_layers = num_hypernet_layers - self._hypernet_hidden_dim = hypernet_hidden_dim - self._n_agents = architecture._n_agents - self._architecture = architecture - self._agent_networks: Dict[str, snt.Module] = {} - - def _create_mixing_layer(self, name: str = "mixing") -> snt.Module: - """Modify and return system architecture given mixing structure.""" - state_specs = self._environment_spec.get_extra_specs() - state_specs = state_specs["env_states"] - - q_value_dim = tf.TensorSpec(self._n_agents) - - # Implement method from base class - self._mixed_network = MonotonicMixingNetwork( - n_agents=self._n_agents, - qmix_hidden_dim=self._qmix_hidden_dim, - num_hypernet_layers=self._num_hypernet_layers, - hypernet_hidden_dim=self._hypernet_hidden_dim, - name=name, - ) - - tf2_utils.create_variables(self._mixed_network, [q_value_dim, state_specs]) - return self._mixed_network - - def create_system(self) -> Dict[str, Dict[str, snt.Module]]: - # Implement method from base class - self._agent_networks[ - "agent_networks" - ] = self._architecture.create_actor_variables() - self._agent_networks["mixing"] = self._create_mixing_layer(name="mixing") - self._agent_networks["target_mixing"] = self._create_mixing_layer( - name="target_mixing" - ) - - return self._agent_networks diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index eac37ba2d..f1240663a 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -29,7 +29,7 @@ from mava import adders from mava.systems.tf import executors -from mava.types import OLT, NestedArray +from mava.types import OLT from mava.utils.sort_utils import sample_new_agent_keys, sort_str_num from mava.utils.training_utils import action_mask_categorical_policies @@ -353,7 +353,11 @@ def _select_actions( self, observations: Dict[str, types.NestedArray], states: Dict[str, types.NestedArray], - ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + ) -> Tuple[ + Dict[str, types.NestedArray], + Dict[str, types.NestedArray], + Dict[str, types.NestedArray], + ]: """select the actions for all agents in the system Args: diff --git a/mava/systems/tf/qmix/training.py b/mava/systems/tf/qmix/training.py deleted file mode 100644 index a3db09e5c..000000000 --- a/mava/systems/tf/qmix/training.py +++ /dev/null @@ -1,305 +0,0 @@ -# python3 -# Copyright 2021 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""QMIX system trainer implementation.""" - -from typing import Any, Callable, Dict, List, Optional, Sequence - -import numpy as np -import reverb -import sonnet as snt -import tensorflow as tf -from acme.tf import utils as tf2_utils -from acme.utils import counting, loggers -from trfl.indexing_ops import batched_index - -from mava import types as mava_types -from mava.components.tf.modules.communication import BaseCommunicationModule -from mava.systems.tf.madqn.training import MADQNTrainer -from mava.utils import training_utils as train_utils - -train_utils.set_growing_gpu_memory() - - -class QMIXTrainer(MADQNTrainer): - """QMIX trainer. - This is the trainer component of a QMIX system. i.e. it takes a dataset as input - and implements update functionality to learn from this dataset. - """ - - def __init__( - self, - agents: List[str], - agent_types: List[str], - q_networks: Dict[str, snt.Module], - target_q_networks: Dict[str, snt.Module], - mixing_network: snt.Module, - target_mixing_network: snt.Module, - target_update_period: int, - dataset: tf.data.Dataset, - optimizer: snt.Optimizer, - discount: float, - agent_net_keys: Dict[str, str], - checkpoint_minute_interval: int, - communication_module: Optional[BaseCommunicationModule] = None, - max_gradient_norm: float = None, - counter: counting.Counter = None, - fingerprint: bool = False, - logger: loggers.Logger = None, - checkpoint: bool = True, - checkpoint_subpath: str = "~/mava/", - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, - ) -> None: - """Initialise QMIX trainer - - Args: - agents (List[str]): agent ids, e.g. "agent_0". - agent_types (List[str]): agent types, e.g. "speaker" or "listener". - q_networks (Dict[str, snt.Module]): q-value networks. - target_q_networks (Dict[str, snt.Module]): target q-value networks. - mixing_network (snt.Module): mixing networks learning factorised q-value - weights. - target_mixing_network (snt.Module): target mixing networks. - target_update_period (int): number of steps before updating target networks. - dataset (tf.data.Dataset): training dataset. - optimizer (Union[snt.Optimizer, Dict[str, snt.Optimizer]]): type of - optimizer for updating the parameters of the networks. - discount (float): discount factor for TD updates. - agent_net_keys: (dict, optional): specifies what network each agent uses. - Defaults to {}. - checkpoint_minute_interval (int): The number of minutes to wait between - checkpoints. - communication_module (BaseCommunicationModule): module for communication - between agents. Defaults to None. - max_gradient_norm (float, optional): maximum allowed norm for gradients - before clipping is applied. Defaults to None. - counter (counting.Counter, optional): step counter object. Defaults to None. - fingerprint (bool, optional): whether to apply replay stabilisation using - policy fingerprints. Defaults to False. - logger (loggers.Logger, optional): logger object for logging trainer - statistics. Defaults to None. - checkpoint (bool, optional): whether to checkpoint networks. Defaults to - True. 
- checkpoint_subpath (str, optional): subdirectory for storing checkpoints. - Defaults to "~/mava/". - learning_rate_scheduler_fn: function/class that takes in a trainer step t - and returns the current learning rate. - """ - - self._mixing_network = mixing_network - self._target_mixing_network = target_mixing_network - self._optimizer = optimizer - - super(QMIXTrainer, self).__init__( - agents=agents, - agent_types=agent_types, - q_networks=q_networks, - target_q_networks=target_q_networks, - target_update_period=target_update_period, - dataset=dataset, - optimizer=optimizer, - discount=discount, - agent_net_keys=agent_net_keys, - checkpoint_minute_interval=checkpoint_minute_interval, - communication_module=communication_module, - max_gradient_norm=max_gradient_norm, - counter=counter, - fingerprint=fingerprint, - logger=logger, - checkpoint=checkpoint, - checkpoint_subpath=checkpoint_subpath, - learning_rate_scheduler_fn=learning_rate_scheduler_fn, - ) - - # Checkpoint the mixing networks - # TODO add checkpointing for mixing networks - - def _update_target_networks(self) -> None: - """Sync the target network parameters with the latest online network - parameters""" - - # Update target networks (incl. mixing networks). - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for key in self.unique_net_keys: - online_variables = [ - *self._q_networks[key].variables, - ] - target_variables = [ - *self._target_q_networks[key].variables, - ] - - # Make online -> target network update ops. - for src, dest in zip(online_variables, target_variables): - dest.assign(src) - - # NOTE These shouldn't really be in the agent for loop. - online_variables = [*self._mixing_network.variables] - target_variables = [*self._target_mixing_network.variables] - - # Make online -> target network update ops. - for src, dest in zip(online_variables, target_variables): - dest.assign(src) - - self._num_steps.assign_add(1) - - @tf.function - def _step( - self, - ) -> Dict[str, Dict[str, Any]]: - """Trainer forward and backward passes.""" - - # Update the target networks - self._update_target_networks() - - # Get data from replay (dropping extras if any). Note there is no - # extra data here because we do not insert any into Reverb. - inputs = next(self._iterator) - - self._forward(inputs) - - self._backward() - - # Log losses per agent - return {agent: {"policy_loss": self.loss} for agent in self._agents} - - def _forward(self, inputs: reverb.ReplaySample) -> None: - """Trainer forward pass - - Args: - inputs (Any): input data from the data table (transitions) - """ - - # Unpack input data as follows: - # o_tm1 = dictionary of observations one for each agent - # a_tm1 = dictionary of actions taken from obs in o_tm1 - # e_tm1 [Optional] = extra data that the agents persist in replay. - # r_t = dictionary of rewards or rewards sequences - # (if using N step transitions) ensuing from actions a_tm1 - # d_t = environment discount ensuing from actions a_tm1. - # This discount is applied to future rewards after r_t. - # o_t = dictionary of next observations or next observation sequences - # e_t = [Optional] = extra data that the agents persist in replay. 
- trans = mava_types.Transition(*inputs.data) - - o_tm1, o_t, a_tm1, r_t, d_t, e_tm1, e_t = ( - trans.observations, - trans.next_observations, - trans.actions, - trans.rewards, - trans.discounts, - trans.extras, - trans.next_extras, - ) - - s_tm1 = e_tm1["env_states"] - s_t = e_t["env_states"] - - # Do forward passes through the networks and calculate the losses - with tf.GradientTape(persistent=True) as tape: - q_acts = [] # Q vals - q_targets = [] # Target Q vals - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - - o_tm1_feed, o_t_feed, a_tm1_feed = self._get_feed( - o_tm1, o_t, a_tm1, agent - ) - q_tm1 = self._q_networks[agent_key](o_tm1_feed) - q_t_value = self._target_q_networks[agent_key](o_t_feed) - q_t_selector = self._q_networks[agent_key](o_t_feed) - best_action = tf.argmax(q_t_selector, axis=1, output_type=tf.int32) - - # TODO Make use of q_t_selector for fingerprinting. Speak to Claude. - q_act = batched_index(q_tm1, a_tm1_feed, keepdims=True) # [B, 1] - q_target = batched_index( - q_t_value, best_action, keepdims=True - ) # [B, 1] - - q_acts.append(q_act) - q_targets.append(q_target) - - rewards = tf.concat( - [tf.reshape(val, (-1, 1)) for val in list(r_t.values())], axis=1 - ) - rewards = tf.reduce_mean(rewards, axis=1) # [B] - - pcont = tf.concat( - [tf.reshape(val, (-1, 1)) for val in list(d_t.values())], axis=1 - ) - pcont = tf.reduce_mean(pcont, axis=1) - discount = tf.cast(self._discount, list(d_t.values())[0].dtype) - pcont = discount * pcont # [B] - - q_acts = tf.concat(q_acts, axis=1) # [B, num_agents] - q_targets = tf.concat(q_targets, axis=1) # [B, num_agents] - - q_tot_mixed = self._mixing_network(q_acts, s_tm1) # [B, 1, 1] - q_tot_target_mixed = self._target_mixing_network( - q_targets, s_t - ) # [B, 1, 1] - - # Calculate Q loss. - targets = rewards + pcont * q_tot_target_mixed - td_error = targets - q_tot_mixed - - # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. 
- self.loss = 0.5 * tf.reduce_mean(tf.square(td_error)) - self.tape = tape - - def _backward(self) -> None: - """Trainer backward pass updating network parameters""" - - for agent in self._agents: - agent_key = self._agent_net_keys[agent] - # Update agent networks - variables = [*self._q_networks[agent_key].trainable_variables] - gradients = self.tape.gradient(self.loss, variables) - gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] - self._optimizers[agent_key].apply(gradients, variables) - - # Update mixing network - variables = [*self._mixing_network.trainable_variables] - gradients = self.tape.gradient(self.loss, variables) - - gradients = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] - self._optimizer.apply(gradients, variables) - - train_utils.safe_del(self, "tape") - - # TODO(Kale-ab): Ini _system_network_variables - def get_variables(self, names: Sequence[str]) -> Dict[str, Dict[str, np.ndarray]]: - """get network variables - - Args: - names (Sequence[str]): network names - - Returns: - Dict[str, Dict[str, np.ndarray]]: network variables - """ - - variables: Dict[str, Dict[str, np.ndarray]] = {} - variables = {} - for network_type in names: - if network_type == "mixing": - # Includes the hypernet variables - variables[network_type] = self._mixing_network.variables - else: # Collect variables for each agent network - variables[network_type] = { - key: tf2_utils.to_numpy( - self._system_network_variables[network_type][key] - ) - for key in self.unique_net_keys - } - return variables From 28c3217af4cf87c70d64a9cf21d060cedf4d210c Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Tue, 22 Mar 2022 23:29:16 +0200 Subject: [PATCH 023/104] temp obs gradient var --- mava/systems/tf/maddpg/training.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 8fb792fa6..6b0d1541b 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -71,6 +71,7 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise MADDPG trainer @@ -108,12 +109,14 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared observatipon network """ self._agents = agents self._agent_net_keys = agent_net_keys self._variable_client = variable_client self._learning_rate_scheduler_fn = learning_rate_scheduler_fn + self._update_obs_once = update_obs_once # Setup counts self._counts = counts @@ -936,6 +939,7 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise Recurrent MADDPG trainer @@ -973,6 +977,7 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. 
+ update_obs_once: bool = use only critic gradients to update shared observatipon network """ self._bootstrap_n = bootstrap_n @@ -980,6 +985,7 @@ def __init__( self._agent_net_keys = agent_net_keys self._variable_client = variable_client self._learning_rate_scheduler_fn = learning_rate_scheduler_fn + self._update_obs_once = update_obs_once # Setup counts self._counts = counts From d6595f70c8059f5ffcbe0a0e3a39206c37ddacec Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Tue, 22 Mar 2022 23:29:44 +0200 Subject: [PATCH 024/104] temp obs gradient var --- mava/systems/tf/maddpg/builder.py | 2 ++ mava/systems/tf/maddpg/system.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index c01550608..41a68c33c 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -130,6 +130,7 @@ class MADDPGConfig: termination_condition: Optional[Dict[str, int]] = None evaluator_interval: Optional[dict] = None learning_rate_scheduler_fn: Optional[Any] = None + update_obs_once: bool = False class MADDPGBuilder: @@ -604,6 +605,7 @@ def make_trainer( "counts": counts, "logger": logger, "learning_rate_scheduler_fn": self._config.learning_rate_scheduler_fn, + "update_obs_once": self._config.update_obs_once, } if connection_spec: trainer_config["connection_spec"] = connection_spec diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index baecaa884..c57b779e2 100644 --- a/mava/systems/tf/maddpg/system.py +++ b/mava/systems/tf/maddpg/system.py @@ -108,6 +108,7 @@ def __init__( # noqa termination_condition: Optional[Dict[str, int]] = None, evaluator_interval: Optional[dict] = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise the system @@ -205,6 +206,7 @@ def __init__( # noqa happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. + update_obs_once: bool = use only critic gradients to update shared observatipon network """ if not environment_spec: @@ -398,6 +400,7 @@ def __init__( # noqa termination_condition=termination_condition, evaluator_interval=evaluator_interval, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ), trainer_fn=trainer_fn, executor_fn=executor_fn, From 19671465e622cd02f4b012b0621bf7392168889b Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 23 Mar 2022 08:05:22 +0200 Subject: [PATCH 025/104] fix: Test small fix. --- mava/adders/reverb/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mava/adders/reverb/base.py b/mava/adders/reverb/base.py index 27ff6b2a8..1b1071a79 100644 --- a/mava/adders/reverb/base.py +++ b/mava/adders/reverb/base.py @@ -94,8 +94,7 @@ def get_trajectory_net_agents( trajectory: Union[Trajectory, mava_types.Transition], trajectory_net_keys: Dict[str, str], ) -> Tuple[List, Dict[str, List]]: - """Returns a dictionary that maps network_keys to a list of agents using that - specific network. + """Maps network_keys to a list of agents using that specific network. Args: trajectory: Episode experience recorded by @@ -441,4 +440,7 @@ def add( dummy_step = tree.map_structure(np.zeros_like, current_step) self._writer.append(dummy_step) self._write_last() - self.reset_writer() + + self.reset() + # TODO (dries): Why does reset writer not work? 
+ # self.reset_writer() From 1fcaf0c5d750d7dc80f4e4253f3e589bf235494d Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 23 Mar 2022 08:22:54 +0200 Subject: [PATCH 026/104] fix: Change writer back. --- mava/adders/reverb/base.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/mava/adders/reverb/base.py b/mava/adders/reverb/base.py index 1b1071a79..39c430b3c 100644 --- a/mava/adders/reverb/base.py +++ b/mava/adders/reverb/base.py @@ -388,13 +388,6 @@ def add_first( ) self._add_first_called = True - def reset_writer(self, timeout_ms: Optional[int] = None) -> None: - """Resets the adder's buffer.""" - if self._writer: - # Flush all appended data and clear the buffers. - self._writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms) - self._add_first_called = False - def add( self, actions: Dict[str, types.NestedArray], @@ -440,7 +433,4 @@ def add( dummy_step = tree.map_structure(np.zeros_like, current_step) self._writer.append(dummy_step) self._write_last() - self.reset() - # TODO (dries): Why does reset writer not work? - # self.reset_writer() From 75ba2748b34fed0872c240ca9fc0e2b52204b070 Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 23 Mar 2022 08:24:30 +0200 Subject: [PATCH 027/104] fix: Change back. --- mava/adders/reverb/base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mava/adders/reverb/base.py b/mava/adders/reverb/base.py index 39c430b3c..32ff4b17c 100644 --- a/mava/adders/reverb/base.py +++ b/mava/adders/reverb/base.py @@ -388,6 +388,13 @@ def add_first( ) self._add_first_called = True + def reset_writer(self, timeout_ms: Optional[int] = None) -> None: + """Resets the adder's buffer.""" + if self._writer: + # Flush all appended data and clear the buffers. + self._writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms) + self._add_first_called = False + def add( self, actions: Dict[str, types.NestedArray], @@ -433,4 +440,4 @@ def add( dummy_step = tree.map_structure(np.zeros_like, current_step) self._writer.append(dummy_step) self._write_last() - self.reset() + self.reset_writer() From ee366e7aa89195c1d74afef3e4af8a9cb6b05304 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Wed, 23 Mar 2022 11:37:43 +0200 Subject: [PATCH 028/104] temp obs gradient var --- mava/systems/tf/maddpg/training.py | 34 +++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 6b0d1541b..7a81c0aa7 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -460,10 +460,13 @@ def _backward(self) -> None: agent_key = self._agent_net_keys[agent] # Get trainable variables. - policy_variables = ( - self._observation_networks[agent_key].trainable_variables - + self._policy_networks[agent_key].trainable_variables - ) + if self._update_obs_once: + policy_variables = self._policy_networks[agent_key].trainable_variables + else: + policy_variables = ( + self._observation_networks[agent_key].trainable_variables + + self._policy_networks[agent_key].trainable_variables + ) critic_variables = ( # In this agent, the critic loss trains the observation network. 
self._observation_networks[agent_key].trainable_variables @@ -578,6 +581,7 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise the decentralised MADDPG trainer.""" super().__init__( @@ -602,6 +606,7 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) @@ -631,6 +636,7 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise the centralised MADDPG trainer.""" @@ -656,6 +662,7 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) def _get_critic_feed( @@ -727,6 +734,7 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise the networked MADDPG trainer.""" @@ -752,6 +760,7 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) self._connection_spec = connection_spec @@ -1413,10 +1422,15 @@ def _backward(self) -> None: agent_key = self._agent_net_keys[agent] # Get trainable variables. - policy_variables = ( - self._observation_networks[agent_key].trainable_variables - + self._policy_networks[agent_key].trainable_variables - ) + print("CHECK: ", self._update_obs_once) + if self._update_obs_once: + policy_variables = self._policy_networks[agent_key].trainable_variables + else: + policy_variables = ( + self._observation_networks[agent_key].trainable_variables + + self._policy_networks[agent_key].trainable_variables + ) + critic_variables = ( # In this agent, the critic loss trains the observation network. self._observation_networks[agent_key].trainable_variables @@ -1554,6 +1568,7 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): super().__init__( @@ -1579,6 +1594,7 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) @@ -1613,6 +1629,7 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): super().__init__( @@ -1638,6 +1655,7 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) def _get_critic_feed( From 21b376b6b66360e74f209b3ad3527448b85dcd98 Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 23 Mar 2022 12:21:23 +0200 Subject: [PATCH 029/104] fix: Distributional head inside networks. 
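Context for this patch: the distribution construction moves into the network definitions themselves. The policy module now ends in a distribution head (a CategoricalHead for discrete action specs, a MultivariateNormalDiagHead for bounded continuous ones), so the executors and the trainer can call sample(), log_prob() and entropy() on the network output directly instead of wrapping raw logits in tfd.Categorical at every call site. A minimal illustration of the pattern (the name CategoricalPolicy, the layer sizes and the dummy shapes are illustrative, not the Mava network factory):

    # A policy network whose final layer returns a tfp distribution, so callers
    # can sample actions and query log-probabilities directly.
    import sonnet as snt
    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    class CategoricalPolicy(snt.Module):
        """MLP torso followed by a categorical head over num_actions logits."""

        def __init__(self, num_actions: int, name: str = "categorical_policy"):
            super().__init__(name=name)
            self._torso = snt.nets.MLP([64, 64])
            self._logits = snt.Linear(num_actions)

        def __call__(self, observation: tf.Tensor) -> tfd.Distribution:
            logits = self._logits(self._torso(observation))
            return tfd.Categorical(logits=logits)

    # Executor-style usage: the network output is already a distribution.
    policy = CategoricalPolicy(num_actions=5)
    dist = policy(tf.zeros([1, 8]))  # dummy batched observation
    action = dist.sample()
    log_prob = dist.log_prob(action)
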
--- mava/systems/tf/mappo/execution.py | 35 ++++++---------- mava/systems/tf/mappo/networks.py | 65 ++++++++++++++---------------- mava/systems/tf/mappo/training.py | 4 -- 3 files changed, 43 insertions(+), 61 deletions(-) diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index f1240663a..fbcee12c4 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -115,18 +115,18 @@ def _policy( # Compute the policy, conditioned on the observation. policy = self._policy_networks[agent_key](batched_observation) - if type(policy) == tfd.transformed_distribution.TransformedDistribution: - # Sample from the policy and compute the log likelihood. - action = policy.sample() - else: - # Sample from the policy and compute the log likelihood. - policy = tfd.Categorical(policy) - action = policy.sample() - - # Cast for compatibility with reverb. - # sample() returns a 'int32', which is a problem. - action = tf.cast(action, "int64") + # Mask categorical policies using legal actions + if hasattr(observation_olt, "legal_actions") and isinstance( + policy, tfp.distributions.Categorical + ): + batched_legals = tf2_utils.add_batch_dim(observation_olt.legal_actions) + + policy = action_mask_categorical_policies( + policy=policy, batched_legal_actions=batched_legals + ) + # Sample from the policy and compute the log likelihood. + action = policy.sample() log_prob = policy.log_prob(action) return action, log_prob @@ -306,18 +306,6 @@ def _policy( # Compute the policy, conditioned on the observation. policy, new_state = self._policy_networks[agent_key](batched_observation, state) - if type(policy) == tfd.transformed_distribution.TransformedDistribution: - # Sample from the policy and compute the log likelihood. - action = policy.sample() - else: - # Sample from the policy and compute the log likelihood. - policy = tfd.Categorical(policy) - action = policy.sample() - - # Cast for compatibility with reverb. - # sample() returns a 'int32', which is a problem. - action = tf.cast(action, "int64") - # Mask categorical policies using legal actions if hasattr(observation_olt, "legal_actions") and isinstance( policy, tfp.distributions.Categorical @@ -329,6 +317,7 @@ def _policy( ) # Sample from the policy and compute the log likelihood. + action = policy.sample() log_prob = policy.log_prob(action) return action, log_prob, new_state diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index d55aefafa..4e8a76f7b 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -105,44 +105,39 @@ def make_default_networks( # Create the shared observation network; here simply a state-less operation. observation_network = tf2_utils.to_sonnet_module(tf.identity) - # Note: The discrete case must be placed first as it inherits from BoundedArray. 
- if isinstance(specs[key].actions, dm_env.specs.DiscreteArray): # discrete - num_actions = specs[key].actions.num_values - - if archecture_type == ArchitectureType.feedforward: - policy_layers = [ - networks.LayerNormMLP( - policy_networks_layer_sizes[key], - activate_final=False, - seed=seed, - ), - ] - elif archecture_type == ArchitectureType.recurrent: - policy_layers = [ - networks.LayerNormMLP( - policy_networks_layer_sizes[key][:-1], - activate_final=True, - seed=seed, - ), - snt.LSTM(policy_networks_layer_sizes[key][-1]), - networks.LayerNormMLP( - (num_actions,), - activate_final=False, - seed=seed, - ), - ] + discrete_actions = isinstance(specs[key].actions, dm_env.specs.DiscreteArray) + + policy_layers = [ + networks.LayerNormMLP( + policy_networks_layer_sizes[key], + activate_final=False, + seed=seed, + ) + ] - policy_network = policy_network_func(policy_layers) + # Add recurrence if the architecture is recurrent + if archecture_type == ArchitectureType.recurrent: + policy_layers.append( + snt.LSTM(policy_networks_layer_sizes[key][-1]), + ) + + # Get the number of actions + num_actions = ( + specs[key].actions.num_values + if discrete_actions + else np.prod(specs[key].actions.shape, dtype=int) + ) + # Note: The discrete case must be placed first as it inherits from BoundedArray. + if discrete_actions: # discrete + policy_layers.append( + networks.CategoricalHead( + num_values=num_actions, dtype=specs[key].actions.dtype + ), + ) elif isinstance(specs[key].actions, dm_env.specs.BoundedArray): # continuous - num_actions = np.prod(specs[key].actions.shape, dtype=int) - # No recurrent setting implemented yet. - assert archecture_type == ArchitectureType.feedforward - policy_network = snt.Sequential( + policy_layers.extend( [ - networks.LayerNormMLP( - policy_networks_layer_sizes[key], activate_final=True, seed=seed - ), networks.MultivariateNormalDiagHead( num_dimensions=num_actions, w_init=tf.initializers.VarianceScaling(1e-4, seed=seed), @@ -154,6 +149,8 @@ def make_default_networks( else: raise ValueError(f"Unknown action_spec type, got {specs[key].actions}.") + policy_network = policy_network_func(policy_layers) + critic_network = snt.Sequential( [ networks.LayerNormMLP( diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 11d100380..b82b4e9a4 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -386,10 +386,6 @@ def _forward_pass(self, inputs: Any) -> None: policy = policy_network(actor_observation) policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") - if type(policy) != tfd.transformed_distribution.TransformedDistribution: - # Sample from the policy and compute the log likelihood. - policy = tfd.Categorical(policy) - critic_observation = snt.merge_leading_dims( critic_observation, num_dims=2 ) From 0c49d7a3970673aa816b0e08caf2bbe07a1e1ee7 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Wed, 23 Mar 2022 16:30:38 +0200 Subject: [PATCH 030/104] feat: configurable obs network ddpg d4pg --- mava/systems/tf/mad4pg/networks.py | 3 +++ mava/systems/tf/maddpg/networks.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mava/systems/tf/mad4pg/networks.py b/mava/systems/tf/mad4pg/networks.py index 8b8daf5a7..18651d4c8 100644 --- a/mava/systems/tf/mad4pg/networks.py +++ b/mava/systems/tf/mad4pg/networks.py @@ -14,6 +14,7 @@ # limitations under the License. 
from typing import Dict, Mapping, Optional, Sequence, Union +import sonnet as snt from acme import types from dm_env import specs @@ -36,6 +37,7 @@ def make_default_networks( net_spec_keys: Dict[str, str] = {}, policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256), + observation_network: snt.Module = None, sigma: float = 0.3, architecture_type: ArchitectureType = ArchitectureType.feedforward, num_atoms: int = 51, @@ -76,6 +78,7 @@ def make_default_networks( net_spec_keys=net_spec_keys, policy_networks_layer_sizes=policy_networks_layer_sizes, critic_networks_layer_sizes=critic_networks_layer_sizes, + observation_network=observation_network, sigma=sigma, architecture_type=architecture_type, vmin=vmin, diff --git a/mava/systems/tf/maddpg/networks.py b/mava/systems/tf/maddpg/networks.py index 0e134872c..8d9738d86 100644 --- a/mava/systems/tf/maddpg/networks.py +++ b/mava/systems/tf/maddpg/networks.py @@ -37,6 +37,7 @@ def make_default_networks( net_spec_keys: Dict[str, str] = {}, policy_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256), + observation_network: snt.Module = None, sigma: float = 0.3, architecture_type: ArchitectureType = ArchitectureType.feedforward, vmin: Optional[float] = None, @@ -127,8 +128,11 @@ def make_default_networks( # Get total number of action dimensions from action spec. num_dimensions = np.prod(agent_act_spec.shape, dtype=int) - # An optional network to process observations - observation_network = tf2_utils.to_sonnet_module(tf.identity) + # Create the shared observation network; here simply a state-less operation. + if observation_network is None: + observation_network = tf2_utils.to_sonnet_module(tf.identity) + else: + observation_network = observation_network # Create the policy network. if architecture_type == ArchitectureType.feedforward: policy_network = [ From 315a1512af4e85383a800e331264377483781287 Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 24 Mar 2022 13:47:24 +0200 Subject: [PATCH 031/104] fix: Change static unroll function to a manual unroll function. --- mava/systems/tf/mappo/training.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 7528f8848..173c30540 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -374,15 +374,24 @@ def _forward_pass(self, inputs: Any) -> None: policy_network = self._policy_networks[agent_key] critic_network = self._critic_networks[agent_key] dims = actor_observation.shape[:2] + # Do policy forward pass. if "core_states" in extras: # Unroll current policy over actor_observation. agent_core_state = core_states[agent][0] - policy, updated_states = snt.static_unroll( - policy_network, - actor_observation, - agent_core_state, - ) + + # Manual perform unroll + action_prob = [] + policy_entropy = [] + for t in range(len(actor_observation)): + outputs, agent_core_state = policy_network( + actor_observation[t], agent_core_state + ) + action_prob.append(outputs.log_prob(action[t])) + policy_entropy.append(outputs.entropy()) + + action_prob = tf.stack(action_prob, axis=0) + policy_entropy = tf.stack(policy_entropy, axis=0) else: # Reshape inputs. 
actor_observation = snt.merge_leading_dims( @@ -390,6 +399,8 @@ def _forward_pass(self, inputs: Any) -> None: ) policy = policy_network(actor_observation) policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") + action_prob = policy.log_prob(action) + policy_entropy = policy.entropy() critic_observation = snt.merge_leading_dims( critic_observation, num_dims=2 @@ -439,7 +450,7 @@ def _forward_pass(self, inputs: Any) -> None: critic_loss = critic_loss * self._baseline_cost # Compute importance sampling weights: current policy / behavior policy. - log_rhos = policy.log_prob(action)[:-1] - behaviour_log_prob[:-1] + log_rhos = action_prob[:-1] - behaviour_log_prob[:-1] rhos = tf.exp(log_rhos) clipped_rhos = tf.clip_by_value( @@ -459,7 +470,7 @@ def _forward_pass(self, inputs: Any) -> None: # Entropy regulariser. # Entropy regularization. Only implemented for categorical dist. try: - masked_entropy_loss = policy.entropy()[:-1] * loss_mask[:-1] + masked_entropy_loss = policy_entropy[:-1] * loss_mask[:-1] entropy_loss = -tf.reduce_sum(masked_entropy_loss) / tf.reduce_sum( loss_mask[:-1] ) From c0e312749d9be20af6874b09d0501375a32167ed Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 25 Mar 2022 11:25:35 +0200 Subject: [PATCH 032/104] fix: Change setting in PPO sequence adder. Remove custom adder code. --- mava/adders/reverb/base.py | 9 +-------- mava/systems/tf/mappo/builder.py | 7 +++++-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/mava/adders/reverb/base.py b/mava/adders/reverb/base.py index 32ff4b17c..39c430b3c 100644 --- a/mava/adders/reverb/base.py +++ b/mava/adders/reverb/base.py @@ -388,13 +388,6 @@ def add_first( ) self._add_first_called = True - def reset_writer(self, timeout_ms: Optional[int] = None) -> None: - """Resets the adder's buffer.""" - if self._writer: - # Flush all appended data and clear the buffers. - self._writer.end_episode(clear_buffers=True, timeout_ms=timeout_ms) - self._add_first_called = False - def add( self, actions: Dict[str, types.NestedArray], @@ -440,4 +433,4 @@ def add( dummy_step = tree.map_structure(np.zeros_like, current_step) self._writer.append(dummy_step) self._write_last() - self.reset_writer() + self.reset() diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index a2a05207c..63de1e47f 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -22,7 +22,8 @@ import reverb import sonnet as snt import tensorflow as tf -from acme.adders.reverb.sequence import EndBehavior + +# from acme.adders.reverb.sequence import EndBehavior from acme.specs import EnvironmentSpec from acme.tf import utils as tf2_utils @@ -299,8 +300,10 @@ def make_adder( period=self._config.sequence_period, sequence_length=self._config.sequence_length, use_next_extras=False, - end_of_episode_behavior=EndBehavior.CONTINUE, + # end_of_episode_behavior=EndBehavior.CONTINUE, ) + # Note (dries) using end_of_episode_behavior=EndBehavior.CONTINUE seems + # to break things. def create_counter_variables( self, variables: Dict[str, tf.Variable] From 921717207afb0044e3bbb3f9ac5e6f5710708d43 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 25 Mar 2022 12:48:43 +0200 Subject: [PATCH 033/104] fix: Small comment updates. 
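Background for the adder settings touched in this and the previous patch: the sequence adder cuts each episode into windows of sequence_length steps, starting a new window every sequence_period steps, and end_of_episode_behavior only controls what is done with the final, possibly partial window (EndBehavior.CONTINUE is left commented out above because it was observed to break the adders when multiple trainers are used). A toy plain-Python sketch of the windowing, an assumed simplification of what the real ParallelSequenceAdder does (the real adder also writes extras and pads with dummy steps at episode boundaries):

    from typing import List

    def cut_sequences(episode: List[int], sequence_length: int, period: int) -> List[List[int]]:
        """Windows of sequence_length steps, a new one started every period steps."""
        windows = []
        start = 0
        while start + sequence_length <= len(episode):
            windows.append(episode[start:start + sequence_length])
            start += period
        return windows

    steps = list(range(10))  # a toy 10-step episode
    print(cut_sequences(steps, sequence_length=4, period=2))
    # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
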
--- mava/systems/tf/mappo/builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 63de1e47f..27624fca7 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -302,8 +302,8 @@ def make_adder( use_next_extras=False, # end_of_episode_behavior=EndBehavior.CONTINUE, ) - # Note (dries) using end_of_episode_behavior=EndBehavior.CONTINUE seems - # to break things. + # Note (dries): Using end_of_episode_behavior=EndBehavior.CONTINUE can + # break the adders when using multiple trainers. def create_counter_variables( self, variables: Dict[str, tf.Variable] From 66757a2f80e27f0810d5284b88c478a6fcb204d0 Mon Sep 17 00:00:00 2001 From: RuanJohn Date: Mon, 28 Mar 2022 16:02:28 +0200 Subject: [PATCH 034/104] bugfix: architecture type typo fix --- mava/systems/tf/mappo/networks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index 89691744b..4e5a0821f 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -41,7 +41,7 @@ def make_default_networks( 256, ), critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256), - archecture_type: ArchitectureType = ArchitectureType.feedforward, + architecture_type: ArchitectureType = ArchitectureType.feedforward, observation_network: snt.Module = None, seed: Optional[int] = None, ) -> Dict[str, snt.Module]: @@ -76,7 +76,7 @@ def make_default_networks( # Set Policy function and layer size # Default size per arch type. - if archecture_type == ArchitectureType.feedforward: + if architecture_type == ArchitectureType.feedforward: if not policy_networks_layer_sizes: policy_networks_layer_sizes = ( 256, @@ -84,7 +84,7 @@ def make_default_networks( 256, ) policy_network_func = snt.Sequential - elif archecture_type == ArchitectureType.recurrent: + elif architecture_type == ArchitectureType.recurrent: if not policy_networks_layer_sizes: policy_networks_layer_sizes = (128, 128) policy_network_func = snt.DeepRNN @@ -120,7 +120,7 @@ def make_default_networks( ] # Add recurrence if the architecture is recurrent - if archecture_type == ArchitectureType.recurrent: + if architecture_type == ArchitectureType.recurrent: policy_layers.append( snt.LSTM(policy_networks_layer_sizes[key][-1]), ) From b8ccfe05752ae8ec578c56bcefe2a99ff9a836aa Mon Sep 17 00:00:00 2001 From: Dries Date: Mon, 28 Mar 2022 16:53:36 +0200 Subject: [PATCH 035/104] fix: Baseline cost defualt. --- mava/systems/tf/mappo/system.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 82434efa1..8ed769291 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -76,7 +76,7 @@ def __init__( # noqa lambda_gae: float = 0.95, clipping_epsilon: float = 0.2, entropy_cost: float = 0.01, - baseline_cost: float = 1.0, + baseline_cost: float = 0.5, max_gradient_norm: Optional[float] = 0.5, max_queue_size: Optional[int] = None, batch_size: int = 512, From f0186b6b5d530fa4d74cb4791eca42d0036c1856 Mon Sep 17 00:00:00 2001 From: Dries Date: Mon, 28 Mar 2022 16:56:35 +0200 Subject: [PATCH 036/104] fix: Update defualt sequence length. 
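Context for the loss-coefficient defaults adjusted here and in the previous patch (clipping_epsilon=0.2, entropy_cost=0.01, baseline_cost now 0.5): in a PPO-style trainer these coefficients enter the objective roughly as sketched below. This is a generic sketch and not the Mava trainer itself, which additionally masks out padded sequence steps and computes GAE-based advantages (lambda_gae):

    import tensorflow as tf

    def ppo_losses(log_pi, log_mu, advantages, value_pred, value_target, entropy,
                   clipping_epsilon=0.2, baseline_cost=0.5, entropy_cost=0.01):
        """Clipped surrogate policy loss, weighted value loss and entropy bonus."""
        rhos = tf.exp(log_pi - log_mu)  # importance ratios: current / behaviour policy
        clipped_rhos = tf.clip_by_value(rhos, 1.0 - clipping_epsilon, 1.0 + clipping_epsilon)
        policy_loss = -tf.reduce_mean(tf.minimum(rhos * advantages, clipped_rhos * advantages))
        critic_loss = baseline_cost * tf.reduce_mean(tf.square(value_target - value_pred))
        entropy_loss = -entropy_cost * tf.reduce_mean(entropy)
        return policy_loss, critic_loss, entropy_loss
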
--- mava/systems/tf/mappo/system.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 8ed769291..b321c6bc8 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -82,7 +82,7 @@ def __init__( # noqa batch_size: int = 512, minibatch_size: int = None, num_epochs: int = 10, - sequence_length: int = 10, + sequence_length: int = 20, sequence_period: Optional[int] = None, max_executor_steps: int = None, checkpoint: bool = True, From 091d73a6700f25ddfb9ea1e131602818e2a04889 Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 29 Mar 2022 11:56:32 +0200 Subject: [PATCH 037/104] Fix entropy term in trainer for continuous action space environments. --- mava/systems/tf/mappo/training.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 173c30540..79a32926b 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -388,10 +388,15 @@ def _forward_pass(self, inputs: Any) -> None: actor_observation[t], agent_core_state ) action_prob.append(outputs.log_prob(action[t])) - policy_entropy.append(outputs.entropy()) + try: + policy_entropy.append(outputs.entropy()) + except NotImplementedError: + pass action_prob = tf.stack(action_prob, axis=0) - policy_entropy = tf.stack(policy_entropy, axis=0) + + if len(policy_entropy) > 0: + policy_entropy = tf.stack(policy_entropy, axis=0) else: # Reshape inputs. actor_observation = snt.merge_leading_dims( @@ -400,7 +405,11 @@ def _forward_pass(self, inputs: Any) -> None: policy = policy_network(actor_observation) policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") action_prob = policy.log_prob(action) - policy_entropy = policy.entropy() + + try: + policy_entropy = policy.entropy() + except NotImplementedError: + pass critic_observation = snt.merge_leading_dims( critic_observation, num_dims=2 @@ -469,13 +478,16 @@ def _forward_pass(self, inputs: Any) -> None: # Entropy regulariser. # Entropy regularization. Only implemented for categorical dist. - try: + # TODO (dries): Get this entropy term to work with univariate gaussian + # distributions as well. The clipping needs to be fixed in that case. + # (SAC paper, Appendix C) + if len(policy_entropy) == 0: masked_entropy_loss = policy_entropy[:-1] * loss_mask[:-1] entropy_loss = -tf.reduce_sum(masked_entropy_loss) / tf.reduce_sum( loss_mask[:-1] ) - except NotImplementedError: + else: entropy_loss = tf.convert_to_tensor(0.0) entropy_loss = self._entropy_cost * entropy_loss From 7ce71c7273e90c016b4b96ed97b8a78121303fb9 Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 29 Mar 2022 12:01:50 +0200 Subject: [PATCH 038/104] fix: Small fix to entropy loss. --- mava/systems/tf/mappo/training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 79a32926b..c4eb4f107 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -376,13 +376,13 @@ def _forward_pass(self, inputs: Any) -> None: dims = actor_observation.shape[:2] # Do policy forward pass. + policy_entropy = [] if "core_states" in extras: # Unroll current policy over actor_observation. 
agent_core_state = core_states[agent][0] # Manual perform unroll action_prob = [] - policy_entropy = [] for t in range(len(actor_observation)): outputs, agent_core_state = policy_network( actor_observation[t], agent_core_state @@ -481,7 +481,7 @@ def _forward_pass(self, inputs: Any) -> None: # TODO (dries): Get this entropy term to work with univariate gaussian # distributions as well. The clipping needs to be fixed in that case. # (SAC paper, Appendix C) - if len(policy_entropy) == 0: + if len(policy_entropy) > 0: masked_entropy_loss = policy_entropy[:-1] * loss_mask[:-1] entropy_loss = -tf.reduce_sum(masked_entropy_loss) / tf.reduce_sum( loss_mask[:-1] From 64cb86aeb390d953130d97fd68bff055d950bccf Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 29 Mar 2022 12:08:05 +0200 Subject: [PATCH 039/104] fix: Small fix to variable spelling. --- .../debugging/simple_spread/recurrent/state_based/run_mappo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py b/examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py index 3a94bdaf6..6e2a1c972 100644 --- a/examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py +++ b/examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py @@ -67,7 +67,7 @@ def main(_: Any) -> None: # Networks. network_factory = lp_utils.partial_kwargs( mappo.make_default_networks, - archecture_type=ArchitectureType.recurrent + architecture_type=ArchitectureType.recurrent if recurrent_ppo else ArchitectureType.feedforward, ) From d797241936d4ea44250b9e12841b3256a5a6e03f Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 31 Mar 2022 13:16:47 +0200 Subject: [PATCH 040/104] fix: Fix continuous action space PPO by creating custom clipped Gaussian head. --- mava/systems/tf/mappo/execution.py | 1 + mava/systems/tf/mappo/networks.py | 43 ++++++++++++++++++++++++++++-- mava/systems/tf/mappo/training.py | 33 +++++++---------------- 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index fbcee12c4..2d85339cf 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -127,6 +127,7 @@ def _policy( # Sample from the policy and compute the log likelihood. action = policy.sample() + log_prob = policy.log_prob(action) return action, log_prob diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index 4e5a0821f..b14c7d257 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -12,12 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Optional, Sequence, Union +from typing import Any, Dict, Optional, Sequence, Union import dm_env import numpy as np import sonnet as snt import tensorflow as tf +import tensorflow_probability as tfp from acme.tf import utils as tf2_utils from dm_env import specs @@ -28,6 +29,44 @@ Array = specs.Array BoundedArray = specs.BoundedArray DiscreteArray = specs.DiscreteArray +tfd = tfp.distributions + + +class ClippedGaussianDistribution: + def __init__(self, guassian_dist: Any, action_specs: Any): + self._guassian_dist = guassian_dist + self.clip_fn = networks.ClipToSpec(action_specs) + + def entropy(self) -> tf.Tensor: + # Note (dires): This calculates the approximate entropy of the + # clipped Gaussian distribution by setting it to the + # unclipped guassian entropy. + return self._guassian_dist.entropy() + + def sample(self) -> tf.Tensor: + # TODO: Add range specs clipping here. + return self.clip_fn(self._guassian_dist.sample()) + + def log_prob(self, action: tf.Tensor) -> tf.Tensor: + return self._guassian_dist.log_prob(action) + + def batch_reshape(self, dims: tf.Tensor, name: str = "reshaped") -> None: + self._guassian_dist = tfd.BatchReshape( + self._guassian_dist, batch_shape=dims, name=name + ) + + +class ClippedGaussianHead(snt.Module): + def __init__( + self, + action_specs: Any, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._action_specs = action_specs + + def __call__(self, x: Any) -> ClippedGaussianDistribution: + return ClippedGaussianDistribution(x, action_specs=self._action_specs) # TODO Update for recurrent version. @@ -147,7 +186,7 @@ def make_default_networks( w_init=tf.initializers.VarianceScaling(1e-4, seed=seed), b_init=tf.initializers.Zeros(), ), - networks.TanhToSpec(specs[key].actions), + ClippedGaussianHead(specs[key].actions), ] ) else: diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index c4eb4f107..1e91829d8 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -22,7 +22,6 @@ import numpy as np import sonnet as snt import tensorflow as tf -import tensorflow_probability as tfp import tree from acme.tf import utils as tf2_utils from acme.utils import counting, loggers @@ -35,8 +34,6 @@ train_utils.set_growing_gpu_memory() -tfd = tfp.distributions - class MAPPOTrainer(mava.Trainer): """MAPPO trainer. @@ -388,28 +385,21 @@ def _forward_pass(self, inputs: Any) -> None: actor_observation[t], agent_core_state ) action_prob.append(outputs.log_prob(action[t])) - try: - policy_entropy.append(outputs.entropy()) - except NotImplementedError: - pass + policy_entropy.append(outputs.entropy()) action_prob = tf.stack(action_prob, axis=0) - - if len(policy_entropy) > 0: - policy_entropy = tf.stack(policy_entropy, axis=0) + policy_entropy = tf.stack(policy_entropy, axis=0) else: # Reshape inputs. actor_observation = snt.merge_leading_dims( actor_observation, num_dims=2 ) policy = policy_network(actor_observation) - policy = tfd.BatchReshape(policy, batch_shape=dims, name="policy") + + policy.batch_reshape(dims, name="policy") action_prob = policy.log_prob(action) - try: - policy_entropy = policy.entropy() - except NotImplementedError: - pass + policy_entropy = policy.entropy() critic_observation = snt.merge_leading_dims( critic_observation, num_dims=2 @@ -481,15 +471,10 @@ def _forward_pass(self, inputs: Any) -> None: # TODO (dries): Get this entropy term to work with univariate gaussian # distributions as well. The clipping needs to be fixed in that case. 
# (SAC paper, Appendix C) - if len(policy_entropy) > 0: - masked_entropy_loss = policy_entropy[:-1] * loss_mask[:-1] - entropy_loss = -tf.reduce_sum(masked_entropy_loss) / tf.reduce_sum( - loss_mask[:-1] - ) - - else: - entropy_loss = tf.convert_to_tensor(0.0) - + masked_entropy_loss = policy_entropy[:-1] * loss_mask[:-1] + entropy_loss = -tf.reduce_sum(masked_entropy_loss) / tf.reduce_sum( + loss_mask[:-1] + ) entropy_loss = self._entropy_cost * entropy_loss # Combine weighted sum of actor & entropy regularization. From b34c707d36660b33d4edb3bbeb9f00d4473a1b4d Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 31 Mar 2022 13:22:12 +0200 Subject: [PATCH 041/104] fix: Small fixes. --- mava/systems/tf/mappo/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index b14c7d257..ead56102f 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -38,7 +38,7 @@ def __init__(self, guassian_dist: Any, action_specs: Any): self.clip_fn = networks.ClipToSpec(action_specs) def entropy(self) -> tf.Tensor: - # Note (dires): This calculates the approximate entropy of the + # Note (dries): This calculates the approximate entropy of the # clipped Gaussian distribution by setting it to the # unclipped guassian entropy. return self._guassian_dist.entropy() @@ -48,6 +48,7 @@ def sample(self) -> tf.Tensor: return self.clip_fn(self._guassian_dist.sample()) def log_prob(self, action: tf.Tensor) -> tf.Tensor: + # This calculates the approximate log prob of the action. return self._guassian_dist.log_prob(action) def batch_reshape(self, dims: tf.Tensor, name: str = "reshaped") -> None: From ec31d19eb5265c62fcef7c7e07077eeba91d8658 Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 31 Mar 2022 13:51:01 +0200 Subject: [PATCH 042/104] fix: Replace clip to spec with tanh to spec. --- mava/systems/tf/mappo/networks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index ead56102f..6f4defcb1 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -35,7 +35,7 @@ class ClippedGaussianDistribution: def __init__(self, guassian_dist: Any, action_specs: Any): self._guassian_dist = guassian_dist - self.clip_fn = networks.ClipToSpec(action_specs) + self.clip_fn = networks.TanhToSpec(action_specs) def entropy(self) -> tf.Tensor: # Note (dries): This calculates the approximate entropy of the @@ -45,11 +45,11 @@ def entropy(self) -> tf.Tensor: def sample(self) -> tf.Tensor: # TODO: Add range specs clipping here. - return self.clip_fn(self._guassian_dist.sample()) + return self.clip_fn(self._guassian_dist).sample() def log_prob(self, action: tf.Tensor) -> tf.Tensor: # This calculates the approximate log prob of the action. - return self._guassian_dist.log_prob(action) + return self.clip_fn(self._guassian_dist).log_prob(action) def batch_reshape(self, dims: tf.Tensor, name: str = "reshaped") -> None: self._guassian_dist = tfd.BatchReshape( From 527d8df45b8b906c657f6834e8c5111596b245cd Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 31 Mar 2022 13:54:00 +0200 Subject: [PATCH 043/104] fix: Remove comment. 
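Background on the clipped Gaussian head built up over patches 040 to 044: for bounded continuous actions the policy outputs a diagonal Gaussian that is squashed into the action range (TanhToSpec), log-probabilities are taken under the squashed distribution, and because the squashed distribution has no closed-form entropy the trainer approximates it with the entropy of the underlying unclipped Gaussian (hence the SAC appendix note above). A standalone tensorflow_probability sketch of the same idea, using an illustrative action range of (-2, 2) rather than an actual Mava action spec:

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions
    tfb = tfp.bijectors

    base = tfd.MultivariateNormalDiag(loc=tf.zeros(2), scale_diag=tf.ones(2))
    # Tanh squashes to (-1, 1); the Scale bijector then maps onto (-2, 2).
    squashed = tfd.TransformedDistribution(
        distribution=base, bijector=tfb.Chain([tfb.Scale(2.0), tfb.Tanh()])
    )

    action = squashed.sample()                # bounded action
    log_prob = squashed.log_prob(action)      # change-of-variables log-probability
    approx_entropy = base.entropy()           # closed-form entropy of the unclipped Gaussian
    # squashed.entropy() has no closed form and would raise NotImplementedError.
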
--- mava/systems/tf/mappo/networks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index 6f4defcb1..940ecb60f 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -48,7 +48,6 @@ def sample(self) -> tf.Tensor: return self.clip_fn(self._guassian_dist).sample() def log_prob(self, action: tf.Tensor) -> tf.Tensor: - # This calculates the approximate log prob of the action. return self.clip_fn(self._guassian_dist).log_prob(action) def batch_reshape(self, dims: tf.Tensor, name: str = "reshaped") -> None: From 40a430e28d169002145297d86c26aeb9b707c147 Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 31 Mar 2022 13:56:14 +0200 Subject: [PATCH 044/104] fix: Remove comment. --- mava/systems/tf/mappo/networks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index 940ecb60f..3915e9372 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -44,7 +44,6 @@ def entropy(self) -> tf.Tensor: return self._guassian_dist.entropy() def sample(self) -> tf.Tensor: - # TODO: Add range specs clipping here. return self.clip_fn(self._guassian_dist).sample() def log_prob(self, action: tf.Tensor) -> tf.Tensor: From c9a1902723e1fd2cfd27ab95133d0a88a1eead65 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Thu, 31 Mar 2022 16:39:37 +0200 Subject: [PATCH 045/104] debug: d4pg update_obs_once --- mava/systems/tf/mad4pg/system.py | 3 +++ mava/systems/tf/mad4pg/training.py | 34 ++++++++++++++++++++++++++++++ mava/systems/tf/maddpg/system.py | 2 +- mava/systems/tf/maddpg/training.py | 9 ++++++-- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/mava/systems/tf/mad4pg/system.py b/mava/systems/tf/mad4pg/system.py index 906acf5a7..0ad9358cb 100644 --- a/mava/systems/tf/mad4pg/system.py +++ b/mava/systems/tf/mad4pg/system.py @@ -89,6 +89,7 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, termination_condition: Optional[Dict[str, int]] = None, evaluator_interval: Optional[dict] = None, + update_obs_once: bool = False, ): """Initialise the system @@ -186,6 +187,7 @@ def __init__( happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. + update_obs_once: bool = use only critic gradients to update shared observation network """ super().__init__( @@ -228,4 +230,5 @@ def __init__( learning_rate_scheduler_fn=learning_rate_scheduler_fn, termination_condition=termination_condition, evaluator_interval=evaluator_interval, + update_obs_once=update_obs_once, ) diff --git a/mava/systems/tf/mad4pg/training.py b/mava/systems/tf/mad4pg/training.py index 3237fa245..edb4babec 100644 --- a/mava/systems/tf/mad4pg/training.py +++ b/mava/systems/tf/mad4pg/training.py @@ -74,6 +74,7 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise MAD4PG trainer @@ -109,6 +110,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. 
+ update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( @@ -133,6 +136,7 @@ def __init__( max_gradient_norm=max_gradient_norm, logger=logger, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) # Forward pass that calculates loss. @@ -250,6 +254,7 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise decentralised MAD4PG trainer @@ -285,6 +290,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( @@ -309,6 +316,7 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) @@ -338,6 +346,7 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise centralised MAD4PG trainer @@ -373,6 +382,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( @@ -397,6 +408,7 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) @@ -426,6 +438,7 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise state-based MAD4PG trainer @@ -461,6 +474,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( @@ -485,6 +500,7 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) @@ -519,6 +535,7 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise Recurrent MAD4PG trainer @@ -555,6 +572,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( @@ -580,6 +599,7 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) # Forward pass that calculates loss. 
@@ -751,6 +771,7 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise Recurrent MAD4PG trainer @@ -787,6 +808,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( @@ -812,6 +835,7 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) @@ -844,6 +868,7 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise Recurrent MAD4PG trainer @@ -880,6 +905,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( @@ -905,6 +932,7 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) @@ -937,6 +965,7 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, + update_obs_once: bool = False, ): """Initialise Recurrent MAD4PG trainer @@ -973,6 +1002,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( @@ -998,6 +1029,7 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, + update_obs_once=update_obs_once, ) @@ -1065,6 +1097,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index c57b779e2..85eba3675 100644 --- a/mava/systems/tf/maddpg/system.py +++ b/mava/systems/tf/maddpg/system.py @@ -206,7 +206,7 @@ def __init__( # noqa happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. - update_obs_once: bool = use only critic gradients to update shared observatipon network + update_obs_once: bool = use only critic gradients to update shared observation network """ if not environment_spec: diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 7a81c0aa7..1f4641f53 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -109,7 +109,7 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. 
- update_obs_once: bool = use only critic gradients to update shared observatipon network + update_obs_once: bool = use only critic gradients to update shared observation network """ self._agents = agents @@ -452,6 +452,10 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: def _backward(self) -> None: """Trainer backward pass updating network parameters""" + # print("CHECK: ", self._update_obs_once) + # tf.print("CHECK TF", self._update_obs_once) + # exit() + # Calculate the gradients and update the networks policy_losses = self.policy_losses critic_losses = self.critic_losses @@ -986,7 +990,7 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared observatipon network + update_obs_once: bool = use only critic gradients to update shared observation network """ self._bootstrap_n = bootstrap_n @@ -1423,6 +1427,7 @@ def _backward(self) -> None: # Get trainable variables. print("CHECK: ", self._update_obs_once) + exit() if self._update_obs_once: policy_variables = self._policy_networks[agent_key].trainable_variables else: From 771fc5a5f485fb04a87fb98a1e44875747d697e5 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Thu, 31 Mar 2022 16:51:33 +0200 Subject: [PATCH 046/104] remove debug code --- mava/systems/tf/mad4pg/system.py | 3 ++- mava/systems/tf/maddpg/training.py | 10 ++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mava/systems/tf/mad4pg/system.py b/mava/systems/tf/mad4pg/system.py index 0ad9358cb..c5ebfac0e 100644 --- a/mava/systems/tf/mad4pg/system.py +++ b/mava/systems/tf/mad4pg/system.py @@ -187,7 +187,8 @@ def __init__( happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. - update_obs_once: bool = use only critic gradients to update shared observation network + update_obs_once: bool = use only critic gradients to update shared + observation network """ super().__init__( diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 1f4641f53..27cd686da 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -109,7 +109,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared observation network + update_obs_once: bool = use only critic gradients to update shared + observation network """ self._agents = agents @@ -452,10 +453,6 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: def _backward(self) -> None: """Trainer backward pass updating network parameters""" - # print("CHECK: ", self._update_obs_once) - # tf.print("CHECK TF", self._update_obs_once) - # exit() - # Calculate the gradients and update the networks policy_losses = self.policy_losses critic_losses = self.critic_losses @@ -990,7 +987,8 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. 
- update_obs_once: bool = use only critic gradients to update shared observation network + update_obs_once: bool = use only critic gradients to update shared + observation network """ self._bootstrap_n = bootstrap_n From 50eaae7a9bda3e60fe8062366a66acbbe857a398 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Thu, 31 Mar 2022 16:53:17 +0200 Subject: [PATCH 047/104] remove debug code --- mava/systems/tf/maddpg/training.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 27cd686da..3a31a66b1 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -1424,8 +1424,7 @@ def _backward(self) -> None: agent_key = self._agent_net_keys[agent] # Get trainable variables. - print("CHECK: ", self._update_obs_once) - exit() + if self._update_obs_once: policy_variables = self._policy_networks[agent_key].trainable_variables else: From 86147a2c4835eda9fa190a77030300f094ef78d4 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 1 Apr 2022 07:32:48 +0200 Subject: [PATCH 048/104] feature: Update Gaussian head's settings control the possible distributions better. --- mava/systems/tf/mappo/networks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index 3915e9372..6c7f876ab 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -184,6 +184,9 @@ def make_default_networks( num_dimensions=num_actions, w_init=tf.initializers.VarianceScaling(1e-4, seed=seed), b_init=tf.initializers.Zeros(), + min_scale=1e-4, + tanh_mean=True, + use_tfd_independent=True, ), ClippedGaussianHead(specs[key].actions), ] From 91fc6a0f14b81e008229bbb39e3f7eb3ae55c484 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 1 Apr 2022 07:35:37 +0200 Subject: [PATCH 049/104] feature: Small update. --- mava/systems/tf/mappo/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index 6c7f876ab..2aad0317f 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -184,7 +184,7 @@ def make_default_networks( num_dimensions=num_actions, w_init=tf.initializers.VarianceScaling(1e-4, seed=seed), b_init=tf.initializers.Zeros(), - min_scale=1e-4, + min_scale=1e-3, tanh_mean=True, use_tfd_independent=True, ), From 3ca5feb441e2ef5cf894893be965a337d59b3582 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 1 Apr 2022 10:10:43 +0200 Subject: [PATCH 050/104] fix: Update black version. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ff712240f..d96824bba 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ "mypy==0.941", "pytest-xdist", "flake8==3.8.2", - "black==21.4b1", + "black==22.3.0", "pytest-cov", "interrogate", "pydocstyle", From 3852b0a443f8c7fbde18aeb2e49e2eed91a06f01 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Fri, 1 Apr 2022 10:39:07 +0200 Subject: [PATCH 051/104] chore: Ignore files generated in testing when formatting code. 
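The Gaussian head settings introduced in the networks.py change above (min_scale, tanh_mean, use_tfd_independent) control how the continuous policy distribution is constructed. A rough tensorflow_probability illustration of what each flag does, with assumed shapes; this is not the head's actual implementation, only a sketch of the effect of each setting:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

min_scale = 1e-3   # lower bound on the per-dimension stddev
tanh_mean = True   # squash the mean into (-1, 1) before building the policy

loc = tf.zeros([1, 2])        # assumed raw mean output, shape [batch, num_actions]
raw_scale = tf.zeros([1, 2])  # assumed raw (pre-softplus) scale output

scale = tf.nn.softplus(raw_scale) + min_scale  # keeps the stddev away from zero
if tanh_mean:
    loc = tf.tanh(loc)

# use_tfd_independent=True: treat the action dimensions as a single event, so
# log_prob returns one value per batch element rather than one per dimension.
policy = tfd.Independent(tfd.Normal(loc, scale), reinterpreted_batch_ndims=1)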
--- .github/workflows/ci.yaml | 8 ++++---- pyproject.toml | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b4d839035..997dabf54 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -18,11 +18,11 @@ jobs: steps: - name: Checkout mava uses: actions/checkout@v2 - - name: Run tests in docker - run: | - docker run --mount "type=bind,src=$(pwd),dst=/tmp/mava" \ - -w "/tmp/mava" --rm ${{ matrix.docker-image }} /bin/bash bash_scripts/tests.sh - name: Check format and types run: | docker run --mount "type=bind,src=$(pwd),dst=/tmp/mava" \ -w "/tmp/mava" --rm ${{ matrix.docker-image }} /bin/bash bash_scripts/check_format.sh + - name: Run tests in docker + run: | + docker run --mount "type=bind,src=$(pwd),dst=/tmp/mava" \ + -w "/tmp/mava" --rm ${{ matrix.docker-image }} /bin/bash bash_scripts/tests.sh diff --git a/pyproject.toml b/pyproject.toml index 2a0573e09..d159cef74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ exclude = ''' | dist | proto | docs + | mava_testing )/ ) ''' From 9da3b553877b7b5489177795743e2eace1c590b5 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 1 Apr 2022 10:43:12 +0200 Subject: [PATCH 052/104] fix: Update black version. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4872232f..48e8ee082 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: - id: check-yaml - repo: https://github.com/psf/black - rev: 21.4b1 + rev: 22.3.0 hooks: - id: black exclude_types: [image] From a76921aed7b738657846e38b3b04e2e006ca1c5e Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Fri, 1 Apr 2022 11:05:04 +0200 Subject: [PATCH 053/104] chore: Minor black changes. --- mava/wrappers/open_spiel.py | 2 +- mava/wrappers/pettingzoo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/wrappers/open_spiel.py b/mava/wrappers/open_spiel.py index 4b27e96c3..c2c36b16e 100644 --- a/mava/wrappers/open_spiel.py +++ b/mava/wrappers/open_spiel.py @@ -191,7 +191,7 @@ def env_done(self) -> bool: """ return not self.agents - def agent_iter(self, max_iter: int = 2 ** 63) -> Iterator: + def agent_iter(self, max_iter: int = 2**63) -> Iterator: """Agent iterator to loop through agents. Args: diff --git a/mava/wrappers/pettingzoo.py b/mava/wrappers/pettingzoo.py index d0d92abf6..b7e566c67 100644 --- a/mava/wrappers/pettingzoo.py +++ b/mava/wrappers/pettingzoo.py @@ -141,7 +141,7 @@ def env_done(self) -> bool: """ return not self.agents - def agent_iter(self, max_iter: int = 2 ** 63) -> Iterator: + def agent_iter(self, max_iter: int = 2**63) -> Iterator: """Agent iterator to loop through agents. Args: From 4184fa7eef653385b0c1b0c88c174fa1138814cb Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 12 Apr 2022 11:36:29 +0200 Subject: [PATCH 054/104] feat: Add capability to fix the sampler for MADDPG. 
--- .../run_maddpg_scale_trainers_fixed_nets.py | 118 ++++++++++++++++++ mava/systems/tf/mad4pg/execution.py | 8 ++ mava/systems/tf/maddpg/builder.py | 4 + mava/systems/tf/maddpg/execution.py | 10 ++ mava/systems/tf/maddpg/system.py | 5 + mava/utils/sort_utils.py | 11 +- 6 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 examples/tf/debugging/simple_spread/feedforward/decentralised/run_maddpg_scale_trainers_fixed_nets.py diff --git a/examples/tf/debugging/simple_spread/feedforward/decentralised/run_maddpg_scale_trainers_fixed_nets.py b/examples/tf/debugging/simple_spread/feedforward/decentralised/run_maddpg_scale_trainers_fixed_nets.py new file mode 100644 index 000000000..ef9c5ba34 --- /dev/null +++ b/examples/tf/debugging/simple_spread/feedforward/decentralised/run_maddpg_scale_trainers_fixed_nets.py @@ -0,0 +1,118 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running feedforward MADDPG on debug MPE environments. \ + NB: Using multiple trainers with non-shared weights is still in its \ + experimental phase of development. This feature will become faster and \ + more stable in future Mava updates.""" + +import functools +from datetime import datetime +from typing import Any + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.systems.tf import maddpg +from mava.systems.tf.maddpg import make_default_networks +from mava.utils import enums, lp_utils +from mava.utils.environments import debugging_utils +from mava.utils.loggers import logger_utils + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "env_name", + "simple_spread", + "Debugging environment name (str).", +) +flags.DEFINE_string( + "action_space", + "continuous", + "Environment action space type (str).", +) +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava/", "Base dir to store experiments.") + + +def main(_: Any) -> None: + """Run main script + + Args: + _ : _ + """ + + # environment + environment_factory = functools.partial( + debugging_utils.make_environment, + env_name=FLAGS.env_name, + action_space=FLAGS.action_space, + ) + + # networks + network_factory = lp_utils.partial_kwargs(make_default_networks) + + # Checkpointer appends "Checkpoints" to checkpoint_dir + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. + log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # distributed program + """NB: Using multiple trainers with non-shared weights is still in its + experimental phase of development. 
This feature will become faster and + more stable in future Mava updates.""" + program = maddpg.MADDPG( + environment_factory=environment_factory, + network_factory=network_factory, + logger_factory=logger_factory, + num_executors=2, + shared_weights=False, + trainer_networks=enums.Trainer.one_trainer_per_network, + network_sampling_setup=[["network_agent_0"], ["network_agent_1"]], + fix_sampler=[0, 0, 1], + policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + checkpoint_subpath=checkpoint_dir, + max_gradient_norm=40.0, + ).build() + + # Ensure only trainer runs on gpu, while other processes run on cpu. + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + + lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/mava/systems/tf/mad4pg/execution.py b/mava/systems/tf/mad4pg/execution.py index e27057029..0e595a5a3 100644 --- a/mava/systems/tf/mad4pg/execution.py +++ b/mava/systems/tf/mad4pg/execution.py @@ -39,6 +39,7 @@ def __init__( agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], network_sampling_setup: List, + fix_sampler: Optional[List], net_keys_to_ids: Dict[str, int], evaluator: bool = False, adder: Optional[adders.ReverbParallelAdder] = None, @@ -57,6 +58,8 @@ def __init__( agent_net_keys: specifies what network each agent uses. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. net_keys_to_ids: Specifies a mapping from network keys to their integer id. adder: adder which sends data to a replay buffer. Defaults to None. @@ -76,6 +79,7 @@ def __init__( counts=counts, agent_net_keys=agent_net_keys, network_sampling_setup=network_sampling_setup, + fix_sampler=fix_sampler, net_keys_to_ids=net_keys_to_ids, evaluator=evaluator, interval=interval, @@ -94,6 +98,7 @@ def __init__( agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], network_sampling_setup: List, + fix_sampler: Optional[List], net_keys_to_ids: Dict[str, int], evaluator: bool = False, adder: Optional[adders.ReverbParallelAdder] = None, @@ -111,6 +116,8 @@ def __init__( agent_net_keys: specifies what network each agent uses. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. net_keys_to_ids: Specifies a mapping from network keys to their integer id. adder: adder which sends data to a replay buffer. Defaults to None. @@ -132,6 +139,7 @@ def __init__( counts=counts, agent_net_keys=agent_net_keys, network_sampling_setup=network_sampling_setup, + fix_sampler=fix_sampler, net_keys_to_ids=net_keys_to_ids, evaluator=evaluator, interval=interval, diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index c01550608..47911d711 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -56,6 +56,8 @@ class MADDPGConfig: table_network_config: Networks each table (trainer) expects. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. 
+ fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. net_keys_to_ids: mapping from net_key to network id. unique_net_keys: list of unique net_keys. checkpoint_minute_interval: The number of minutes to wait between @@ -105,6 +107,7 @@ class MADDPGConfig: trainer_networks: Dict[str, List] table_network_config: Dict[str, List] network_sampling_setup: List + fix_sampler: Optional[List] net_keys_to_ids: Dict[str, int] unique_net_keys: List[str] checkpoint_minute_interval: int @@ -498,6 +501,7 @@ def make_executor( agent_specs=self._config.environment_spec.get_agent_specs(), agent_net_keys=self._config.agent_net_keys, network_sampling_setup=self._config.network_sampling_setup, + fix_sampler=self._config.fix_sampler, variable_client=variable_client, adder=adder, evaluator=evaluator, diff --git a/mava/systems/tf/maddpg/execution.py b/mava/systems/tf/maddpg/execution.py index 8d43f8351..9af582115 100644 --- a/mava/systems/tf/maddpg/execution.py +++ b/mava/systems/tf/maddpg/execution.py @@ -51,6 +51,7 @@ def __init__( agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], network_sampling_setup: List, + fix_sampler: Optional[List], net_keys_to_ids: Dict[str, int], evaluator: bool = False, adder: Optional[adders.ReverbParallelAdder] = None, @@ -68,6 +69,8 @@ def __init__( agent_net_keys: specifies what network each agent uses. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. net_keys_to_ids: Specifies a mapping from network keys to their integer id. adder: adder which sends data to a replay buffer. Defaults to None. @@ -82,6 +85,7 @@ def __init__( # Store these for later use. self._agent_specs = agent_specs self._network_sampling_setup = network_sampling_setup + self._fix_sampler = fix_sampler self._counts = counts self._network_int_keys_extras: Dict[str, np.ndarray] = {} self._net_keys_to_ids = net_keys_to_ids @@ -199,6 +203,7 @@ def observe_first( agents, self._network_sampling_setup, self._net_keys_to_ids, + self._fix_sampler, ) extras["network_int_keys"] = self._network_int_keys_extras @@ -244,6 +249,7 @@ def __init__( agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], network_sampling_setup: List, + fix_sampler: Optional[List], net_keys_to_ids: Dict[str, int], evaluator: bool = False, adder: Optional[adders.ReverbParallelAdder] = None, @@ -261,6 +267,8 @@ def __init__( agent_net_keys: specifies what network each agent uses. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. net_keys_to_ids: Specifies a mapping from network keys to their integer id. adder: adder which sends data to a replay buffer. Defaults to None. @@ -277,6 +285,7 @@ def __init__( # Store these for later use. 
self._agent_specs = agent_specs self._network_sampling_setup = network_sampling_setup + self._fix_sampler = fix_sampler self._counts = counts self._net_keys_to_ids = net_keys_to_ids self._network_int_keys_extras: Dict[str, np.ndarray] = {} @@ -423,6 +432,7 @@ def observe_first( agents, self._network_sampling_setup, self._net_keys_to_ids, + self._fix_sampler, ) if self._store_recurrent_state: diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index baecaa884..842544e9d 100644 --- a/mava/systems/tf/maddpg/system.py +++ b/mava/systems/tf/maddpg/system.py @@ -75,6 +75,7 @@ def __init__( # noqa network_sampling_setup: Union[ List, enums.NetworkSampler ] = enums.NetworkSampler.fixed_agent_networks, + fix_sampler: Optional[List] = None, shared_weights: bool = True, environment_spec: mava_specs.MAEnvironmentSpec = None, discount: float = 0.99, @@ -142,6 +143,8 @@ def __init__( # noqa Custom list: Alternatively one can specify a custom nested list, with network keys in, that will be used by the executors at the start of each episode to sample networks for each agent. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. shared_weights: whether agents should share weights or not. When network_sampling_setup are provided the value of shared_weights is ignored. @@ -269,6 +272,7 @@ def __init__( # noqa _, self._agent_net_keys = sample_new_agent_keys( agents, self._network_sampling_setup, # type: ignore + fix_sampler=fix_sampler, ) # Check that the environment and agent_net_keys has the same amount of agents @@ -373,6 +377,7 @@ def __init__( # noqa table_network_config=table_network_config, num_executors=num_executors, network_sampling_setup=self._network_sampling_setup, # type: ignore + fix_sampler=fix_sampler, net_keys_to_ids=net_keys_to_ids, unique_net_keys=unique_net_keys, discount=discount, diff --git a/mava/utils/sort_utils.py b/mava/utils/sort_utils.py index 38f4fd55a..2fbe24d2e 100644 --- a/mava/utils/sort_utils.py +++ b/mava/utils/sort_utils.py @@ -1,6 +1,6 @@ import copy import re -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np from numpy.random import randint @@ -27,6 +27,7 @@ def sample_new_agent_keys( agents: List, network_sampling_setup: List, net_keys_to_ids: Dict[str, int] = None, + fix_sampler: Optional[List] = None, ) -> Tuple[Dict[str, np.ndarray], Dict[str, str]]: """ Samples new agent networks using the network sampling setup. @@ -42,8 +43,13 @@ def sample_new_agent_keys( save_net_keys = {} agent_net_keys = {} agent_slots = copy.copy(agents) + loop_i = 0 while len(agent_slots) > 0: - sample = network_sampling_setup[randint(len(network_sampling_setup))] + if not fix_sampler: + sample_i = randint(len(network_sampling_setup)) + else: + sample_i = fix_sampler[loop_i] + sample = network_sampling_setup[sample_i] for net_key in sample: agent = agent_slots.pop(0) agent_net_keys[agent] = net_key @@ -51,5 +57,6 @@ def sample_new_agent_keys( save_net_keys[agent] = np.array( net_keys_to_ids[net_key], dtype=np.int32 ) + loop_i += 1 return save_net_keys, agent_net_keys From 24dfe329f907421714fe6dbef8a6e1b82cb1793e Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 12 Apr 2022 11:42:04 +0200 Subject: [PATCH 055/104] fix: Small fix to MAD4PG. 
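With the sort_utils change above, supplying fix_sampler makes the network-to-agent assignment deterministic instead of being randomly sampled at the start of each episode. A small usage sketch; the agent and network names mirror the simple_spread example added earlier:

from mava.utils.sort_utils import sample_new_agent_keys

agents = ["agent_0", "agent_1", "agent_2"]
network_sampling_setup = [["network_agent_0"], ["network_agent_1"]]
net_keys_to_ids = {"network_agent_0": 0, "network_agent_1": 1}

# fix_sampler selects entries 0, 0 and then 1 of network_sampling_setup, so
# the assignment below is always the same.
save_net_keys, agent_net_keys = sample_new_agent_keys(
    agents,
    network_sampling_setup,
    net_keys_to_ids=net_keys_to_ids,
    fix_sampler=[0, 0, 1],
)
# agent_net_keys == {"agent_0": "network_agent_0",
#                    "agent_1": "network_agent_0",
#                    "agent_2": "network_agent_1"}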
--- mava/systems/tf/mad4pg/system.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mava/systems/tf/mad4pg/system.py b/mava/systems/tf/mad4pg/system.py index 906acf5a7..40f0ec46d 100644 --- a/mava/systems/tf/mad4pg/system.py +++ b/mava/systems/tf/mad4pg/system.py @@ -58,6 +58,7 @@ def __init__( network_sampling_setup: Union[ List, enums.NetworkSampler ] = enums.NetworkSampler.fixed_agent_networks, + fix_sampler: Optional[List] = None, shared_weights: bool = True, discount: float = 0.99, batch_size: int = 256, @@ -122,6 +123,8 @@ def __init__( Custom list: Alternatively one can specify a custom nested list, with network keys in, that will be used by the executors at the start of each episode to sample networks for each agent. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. shared_weights: whether agents should share weights or not. When network_sampling_setup are provided the value of shared_weights is ignored. @@ -199,6 +202,7 @@ def __init__( environment_spec=environment_spec, trainer_networks=trainer_networks, network_sampling_setup=network_sampling_setup, + fix_sampler=fix_sampler, shared_weights=shared_weights, discount=discount, batch_size=batch_size, From ae33e1ce7aed985f23f52c7c78d2d581d8c96560 Mon Sep 17 00:00:00 2001 From: Edan Toledo Date: Tue, 12 Apr 2022 13:23:13 +0200 Subject: [PATCH 056/104] fix: multiple agents and multiple trainers using separate tables and separate networks --- mava/systems/tf/maddpg/builder.py | 31 ++++++++++++++++++++++--------- mava/systems/tf/maddpg/system.py | 7 +++++++ 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index 47911d711..8bc62ee4e 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -58,6 +58,7 @@ class MADDPGConfig: sampled from by the executors at the start of an environment run. fix_sampler: Optional list that can fix the executor sampler to sample in a specific way. + TODO: net_spec_keys net_keys_to_ids: mapping from net_key to network id. unique_net_keys: list of unique net_keys. 
checkpoint_minute_interval: The number of minutes to wait between @@ -108,6 +109,7 @@ class MADDPGConfig: table_network_config: Dict[str, List] network_sampling_setup: List fix_sampler: Optional[List] + net_spec_keys: Dict[str, str] net_keys_to_ids: Dict[str, int] unique_net_keys: List[str] checkpoint_minute_interval: int @@ -203,19 +205,26 @@ def convert_discrete_to_bounded( ) return env_adder_spec - def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any]: + def covert_specs( + self, spec: Dict[str, Any], list_of_networks: List + ) -> Dict[str, Any]: if type(spec) is not dict: return spec - agents = sort_str_num(self._config.agent_net_keys.keys())[:num_networks] + agents = [] + for network in list_of_networks: + agents.append(self._config.net_spec_keys[network]) + + agents = sort_str_num(agents) converted_spec: Dict[str, Any] = {} + if agents[0] in spec.keys(): - for agent in agents: - converted_spec[agent] = spec[agent] + for i, agent in enumerate(agents): + converted_spec[self._agents[i]] = spec[agent] else: # For the extras for key in spec.keys(): - converted_spec[key] = self.covert_specs(spec[key], num_networks) + converted_spec[key] = self.covert_specs(spec[key], list_of_networks) return converted_spec def make_replay_tables( @@ -276,21 +285,24 @@ def limiter_fn() -> reverb.rate_limiters: # Create table per trainer replay_tables = [] + for table_key in self._config.table_network_config.keys(): + # TODO (dries): Clean the below coverter code up. # Convert a Mava spec - num_networks = len(self._config.table_network_config[table_key]) + + list_of_networks = self._config.table_network_config[table_key] env_spec = copy.deepcopy(env_adder_spec) - env_spec._specs = self.covert_specs(env_spec._specs, num_networks) + env_spec._specs = self.covert_specs(env_spec._specs, list_of_networks) env_spec._keys = list(sort_str_num(env_spec._specs.keys())) if env_spec.extra_specs is not None: env_spec.extra_specs = self.covert_specs( - env_spec.extra_specs, num_networks + env_spec.extra_specs, list_of_networks ) extra_specs = self.covert_specs( self._extra_specs, - num_networks, + list_of_networks, ) replay_tables.append( @@ -303,6 +315,7 @@ def limiter_fn() -> reverb.rate_limiters: signature=adder_sig_fn(env_spec, extra_specs), ) ) + return replay_tables def make_dataset_iterator( diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index 842544e9d..e6a9a3a90 100644 --- a/mava/systems/tf/maddpg/system.py +++ b/mava/systems/tf/maddpg/system.py @@ -323,6 +323,12 @@ def __init__( # noqa for i in range(len(unique_net_keys)): self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] + #### TODO MAKE NET SPEC KEYS AN ARGUMENT + self._net_spec_keys = { + "network_agent_0": "agent_0", + "network_agent_1": f"agent_{5}", + } + # Setup table_network_config table_network_config = {} for trainer_key in self._trainer_networks.keys(): @@ -378,6 +384,7 @@ def __init__( # noqa num_executors=num_executors, network_sampling_setup=self._network_sampling_setup, # type: ignore fix_sampler=fix_sampler, + net_spec_keys=self._net_spec_keys, net_keys_to_ids=net_keys_to_ids, unique_net_keys=unique_net_keys, discount=discount, From bdcdfabfe63801ed90a56bbc56e731cdd52a12c8 Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 12 Apr 2022 14:24:42 +0200 Subject: [PATCH 057/104] fix: Small fixes to the net_spec_keys setup code for MADDPG. 
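net_spec_keys simply tells the builder which agent's environment spec to use when constructing each network, which matters once networks are no longer tied one-to-one to agents. A minimal illustration of the default mapping and of passing an explicit one; the values are hypothetical:

agents = ["agent_0", "agent_1", "agent_2"]
unique_net_keys = ["network_agent_0", "network_agent_1"]

# Default built when no mapping is supplied: cycle through the agents.
net_spec_keys = {
    unique_net_keys[i]: agents[i % len(agents)]
    for i in range(len(unique_net_keys))
}
assert net_spec_keys == {
    "network_agent_0": "agent_0",
    "network_agent_1": "agent_1",
}

# With the new argument, a full mapping can instead be passed to the system,
# e.g. maddpg.MADDPG(..., net_spec_keys={"network_agent_0": "agent_0",
# "network_agent_1": "agent_2"}, ...), so that network_agent_1 is built
# against agent_2's spec.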
--- mava/systems/tf/mad4pg/system.py | 4 ++++ mava/systems/tf/maddpg/builder.py | 3 ++- mava/systems/tf/maddpg/system.py | 17 ++++++++--------- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mava/systems/tf/mad4pg/system.py b/mava/systems/tf/mad4pg/system.py index 40f0ec46d..6a23de0ca 100644 --- a/mava/systems/tf/mad4pg/system.py +++ b/mava/systems/tf/mad4pg/system.py @@ -59,6 +59,7 @@ def __init__( List, enums.NetworkSampler ] = enums.NetworkSampler.fixed_agent_networks, fix_sampler: Optional[List] = None, + net_spec_keys: Dict = {}, shared_weights: bool = True, discount: float = 0.99, batch_size: int = 256, @@ -125,6 +126,8 @@ def __init__( the start of each episode to sample networks for each agent. fix_sampler: Optional list that can fix the executor sampler to sample in a specific way. + net_spec_keys: Optional network to agent mapping used to get the environment + specs for each network. shared_weights: whether agents should share weights or not. When network_sampling_setup are provided the value of shared_weights is ignored. @@ -203,6 +206,7 @@ def __init__( trainer_networks=trainer_networks, network_sampling_setup=network_sampling_setup, fix_sampler=fix_sampler, + net_spec_keys=net_spec_keys, shared_weights=shared_weights, discount=discount, batch_size=batch_size, diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index 8bc62ee4e..371bcdb43 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -58,7 +58,8 @@ class MADDPGConfig: sampled from by the executors at the start of an environment run. fix_sampler: Optional list that can fix the executor sampler to sample in a specific way. - TODO: net_spec_keys + net_spec_keys: Optional network to agent mapping used to get the environment + specs for each network. net_keys_to_ids: mapping from net_key to network id. unique_net_keys: list of unique net_keys. checkpoint_minute_interval: The number of minutes to wait between diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index e6a9a3a90..1dab57abe 100644 --- a/mava/systems/tf/maddpg/system.py +++ b/mava/systems/tf/maddpg/system.py @@ -76,6 +76,7 @@ def __init__( # noqa List, enums.NetworkSampler ] = enums.NetworkSampler.fixed_agent_networks, fix_sampler: Optional[List] = None, + net_spec_keys: Dict = {}, shared_weights: bool = True, environment_spec: mava_specs.MAEnvironmentSpec = None, discount: float = 0.99, @@ -145,6 +146,8 @@ def __init__( # noqa the start of each episode to sample networks for each agent. fix_sampler: Optional list that can fix the executor sampler to sample in a specific way. + net_spec_keys: Optional network to agent mapping used to get the environment + specs for each network. shared_weights: whether agents should share weights or not. When network_sampling_setup are provided the value of shared_weights is ignored. 
@@ -318,16 +321,12 @@ def __init__( # noqa # Check that all agent_net_keys are in trainer_networks assert unique_net_keys == unique_trainer_net_keys + # Setup specs for each network - self._net_spec_keys = {} - for i in range(len(unique_net_keys)): - self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] - - #### TODO MAKE NET SPEC KEYS AN ARGUMENT - self._net_spec_keys = { - "network_agent_0": "agent_0", - "network_agent_1": f"agent_{5}", - } + self._net_spec_keys = net_spec_keys + if not self._net_spec_keys: + for i in range(len(unique_net_keys)): + self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] # Setup table_network_config table_network_config = {} From 5774a57eccfc1d1ba79023a61b42a52d50cf9e69 Mon Sep 17 00:00:00 2001 From: Edan Toledo Date: Tue, 12 Apr 2022 15:15:36 +0200 Subject: [PATCH 058/104] feat: add multiple trainer, multiple agent, multiple agent architecture support for madqn --- mava/systems/tf/madqn/builder.py | 31 +++++++++++++++++++++--------- mava/systems/tf/madqn/execution.py | 10 ++++++++++ mava/systems/tf/madqn/system.py | 11 ++++++++++- 3 files changed, 42 insertions(+), 10 deletions(-) diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index a4f10fdd1..8623c1872 100644 --- a/mava/systems/tf/madqn/builder.py +++ b/mava/systems/tf/madqn/builder.py @@ -60,6 +60,10 @@ class MADQNConfig: table_network_config: Networks each table (trainer) expects. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. + net_spec_keys: Optional network to agent mapping used to get the environment + specs for each network. net_keys_to_ids: mapping from net_key to network id. unique_net_keys: list of unique net_keys. checkpoint_minute_interval: The number of minutes to wait between @@ -108,6 +112,8 @@ class MADQNConfig: trainer_networks: Dict[str, List] table_network_config: Dict[str, List] network_sampling_setup: List + fix_sampler: Optional[List] + net_spec_keys: Dict[str, str] net_keys_to_ids: Dict[str, int] unique_net_keys: List[str] checkpoint_minute_interval: int @@ -165,7 +171,9 @@ def __init__( self._trainer_fn = trainer_fn self._executor_fn = executor_fn - def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any]: + def covert_specs( + self, spec: Dict[str, Any], list_of_networks: List + ) -> Dict[str, Any]: """Convert specs. Args: @@ -178,15 +186,19 @@ def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any if type(spec) is not dict: return spec - agents = sort_str_num(self._config.agent_net_keys.keys())[:num_networks] + agents = [] + for network in list_of_networks: + agents.append(self._config.net_spec_keys[network]) + + agents = sort_str_num(agents) converted_spec: Dict[str, Any] = {} if agents[0] in spec.keys(): - for agent in agents: - converted_spec[agent] = spec[agent] + for i, agent in enumerate(agents): + converted_spec[self._agents[i]] = spec[agent] else: # For the extras for key in spec.keys(): - converted_spec[key] = self.covert_specs(spec[key], num_networks) + converted_spec[key] = self.covert_specs(spec[key], list_of_networks) return converted_spec def make_replay_tables( @@ -250,18 +262,18 @@ def limiter_fn() -> reverb.rate_limiters: for table_key in self._config.table_network_config.keys(): # TODO (dries): Clean the below coverter code up. 
# Convert a Mava spec - num_networks = len(self._config.table_network_config[table_key]) + list_of_networks = self._config.table_network_config[table_key] env_spec = copy.deepcopy(environment_spec) - env_spec._specs = self.covert_specs(env_spec._specs, num_networks) + env_spec._specs = self.covert_specs(env_spec._specs, list_of_networks) env_spec._keys = list(sort_str_num(env_spec._specs.keys())) if env_spec.extra_specs is not None: env_spec.extra_specs = self.covert_specs( - env_spec.extra_specs, num_networks + env_spec.extra_specs, list_of_networks ) extra_specs = self.covert_specs( self._extra_specs, - num_networks, + list_of_networks, ) replay_tables.append( @@ -504,6 +516,7 @@ def make_executor( agent_specs=self._config.environment_spec.get_agent_specs(), agent_net_keys=self._config.agent_net_keys, network_sampling_setup=self._config.network_sampling_setup, + fix_sampler=self._config.fix_sampler, variable_client=variable_client, adder=adder, evaluator=evaluator, diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index addac6859..ae2ba6c5f 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -115,6 +115,7 @@ def __init__( agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], network_sampling_setup: List, + fix_sampler: Optional[List], net_keys_to_ids: Dict[str, int], evaluator: bool = False, adder: Optional[adders.ReverbParallelAdder] = None, @@ -135,6 +136,8 @@ def __init__( agent_net_keys: specifies what network each agent uses. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. net_keys_to_ids: Specifies a mapping from network keys to their integer id. adder: adder which sends data to a replay buffer. Defaults to None. @@ -148,6 +151,7 @@ def __init__( # Store these for later use. self._agent_specs = agent_specs self._network_sampling_setup = network_sampling_setup + self._fix_sampler = fix_sampler self._counts = counts self._network_int_keys_extras: Dict[str, np.ndarray] = {} self._net_keys_to_ids = net_keys_to_ids @@ -257,6 +261,7 @@ def observe_first( agents, self._network_sampling_setup, self._net_keys_to_ids, + self._fix_sampler, ) extras["network_int_keys"] = self._network_int_keys_extras @@ -308,6 +313,7 @@ def __init__( agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], network_sampling_setup: List, + fix_sampler: Optional[List], net_keys_to_ids: Dict[str, int], evaluator: bool = False, adder: Optional[adders.ReverbParallelAdder] = None, @@ -330,6 +336,8 @@ def __init__( agent_net_keys: specifies what network each agent uses. network_sampling_setup: List of networks that are randomly sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. net_keys_to_ids: Specifies a mapping from network keys to their integer id. adder: adder which sends data to a replay buffer. Defaults to None. @@ -346,6 +354,7 @@ def __init__( # Store these for later use. 
self._agent_specs = agent_specs self._network_sampling_setup = network_sampling_setup + self._fix_sampler = fix_sampler self._counts = counts self._net_keys_to_ids = net_keys_to_ids self._network_int_keys_extras: Dict[str, np.ndarray] = {} @@ -483,6 +492,7 @@ def observe_first( agents, self._network_sampling_setup, self._net_keys_to_ids, + self._fix_sampler ) if self._store_recurrent_state: diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index cacf23595..81519c8e0 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -72,6 +72,8 @@ def __init__( # noqa network_sampling_setup: Union[ List, enums.NetworkSampler ] = enums.NetworkSampler.fixed_agent_networks, + fix_sampler: Optional[List] = None, + net_spec_keys: Dict = {}, shared_weights: bool = True, environment_spec: mava_specs.MAEnvironmentSpec = None, discount: float = 0.99, @@ -134,6 +136,10 @@ def __init__( # noqa Custom list: Alternatively one can specify a custom nested list, with network keys in, that will be used by the executors at the start of each episode to sample networks for each agent. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. + net_spec_keys: Optional network to agent mapping used to get the environment + specs for each network. shared_weights: whether agents should share weights or not. When network_sampling_setup are provided the value of shared_weights is ignored. @@ -257,6 +263,7 @@ def __init__( # noqa _, self._agent_net_keys = sample_new_agent_keys( agents, self._network_sampling_setup, # type: ignore + fix_sampler=fix_sampler ) # Check that the environment and agent_net_keys has the same amount of agents sample_length = len(self._network_sampling_setup[0]) # type: ignore @@ -302,7 +309,7 @@ def __init__( # noqa # Check that all agent_net_keys are in trainer_networks assert unique_net_keys == unique_trainer_net_keys # Setup specs for each network - self._net_spec_keys = {} + self._net_spec_keys = net_spec_keys for i in range(len(unique_net_keys)): self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] @@ -384,6 +391,8 @@ def __init__( # noqa table_network_config=table_network_config, num_executors=num_executors, network_sampling_setup=self._network_sampling_setup, # type: ignore + fix_sampler=fix_sampler, + net_spec_keys=self._net_spec_keys, net_keys_to_ids=net_keys_to_ids, unique_net_keys=unique_net_keys, discount=discount, From 1a84f2f795e9ca4f55d5d7f8c73556131cb413f3 Mon Sep 17 00:00:00 2001 From: Edan Toledo Date: Tue, 12 Apr 2022 15:20:35 +0200 Subject: [PATCH 059/104] fix: allow net_spec keys to be passed into create_default_networks --- mava/systems/tf/madqn/networks.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mava/systems/tf/madqn/networks.py b/mava/systems/tf/madqn/networks.py index 84912c858..b6b5fc04f 100644 --- a/mava/systems/tf/madqn/networks.py +++ b/mava/systems/tf/madqn/networks.py @@ -36,6 +36,7 @@ def make_default_networks( environment_spec: mava_specs.MAEnvironmentSpec, agent_net_keys: Dict[str, str], + net_spec_keys: Dict[str,str] = {}, value_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, architecture_type: ArchitectureType = ArchitectureType.feedforward, atari_torso_observation_network: bool = False, @@ -76,9 +77,12 @@ def make_default_networks( assert value_network_func is not None specs = environment_spec.get_agent_specs() - + # Create agent_type specs - specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} + 
if not net_spec_keys: + specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} + else: + specs = {net_key: specs[value] for net_key, value in net_spec_keys.items()} if isinstance(value_networks_layer_sizes, Sequence): value_networks_layer_sizes = { From df45468b349b571eda49efdab223f2b264935d19 Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 12 Apr 2022 16:15:36 +0200 Subject: [PATCH 060/104] fix: Remove redundant statement. --- mava/systems/tf/mappo/training.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 1e91829d8..b8c0fb342 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -322,10 +322,6 @@ def _forward_pass(self, inputs: Any) -> None: Args: inputs: input data from the data table (transitions) """ - - # Convert to sequence data - data = tf2_utils.batch_to_sequence(inputs) - # Unpack input data as follows: data = tf2_utils.batch_to_sequence(inputs) observations, actions, rewards, discounts, extras = ( From 49f8534da647fc7107e0d162a0ffd61035cd157f Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 12 Apr 2022 16:18:01 +0200 Subject: [PATCH 061/104] fix: Remove redundant statement. --- mava/systems/tf/mappo/training.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index b8c0fb342..d1420a359 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -335,9 +335,6 @@ def _forward_pass(self, inputs: Any) -> None: if "core_states" in extras: core_states = tree.map_structure(lambda s: s[0], extras["core_states"]) - # transform observation using observation networks - observations_trans = self._transform_observations(observations) - # Store losses. policy_losses: Dict[str, Any] = {} critic_losses: Dict[str, Any] = {} From 34c538cfa39580cdcf94f47bbd1fabbcf122f1e8 Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 12 Apr 2022 16:42:41 +0200 Subject: [PATCH 062/104] feat: Small improvement. --- mava/systems/tf/mappo/training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index d1420a359..989fc7864 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -279,11 +279,11 @@ def _step( } # Get data from replay. 
inputs = next(self._iterator) - # Split for possible minibatches - batch_size = inputs.data.observations[self._agents[0]].observation.shape[0] - dataset = tf.data.Dataset.from_tensor_slices(inputs.data) - dataset = dataset.shuffle(batch_size).batch(self._minibatch_size) for _ in range(self._num_epochs): + # Split for possible minibatches + batch_size = inputs.data.observations[self._agents[0]].observation.shape[0] + dataset = tf.data.Dataset.from_tensor_slices(inputs.data) + dataset = dataset.shuffle(batch_size).batch(self._minibatch_size) for minibatch_data in dataset: loss = self._minibatch_update(minibatch_data) From 18d420f0d86f6575f2d4193a381cfb45763ab9ea Mon Sep 17 00:00:00 2001 From: Edan Toledo Date: Tue, 12 Apr 2022 17:19:05 +0200 Subject: [PATCH 063/104] fix: linting --- mava/systems/tf/madqn/execution.py | 2 +- mava/systems/tf/madqn/networks.py | 4 ++-- mava/systems/tf/madqn/system.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mava/systems/tf/madqn/execution.py b/mava/systems/tf/madqn/execution.py index ae2ba6c5f..0f1b00eff 100644 --- a/mava/systems/tf/madqn/execution.py +++ b/mava/systems/tf/madqn/execution.py @@ -492,7 +492,7 @@ def observe_first( agents, self._network_sampling_setup, self._net_keys_to_ids, - self._fix_sampler + self._fix_sampler, ) if self._store_recurrent_state: diff --git a/mava/systems/tf/madqn/networks.py b/mava/systems/tf/madqn/networks.py index b6b5fc04f..fd79b81c1 100644 --- a/mava/systems/tf/madqn/networks.py +++ b/mava/systems/tf/madqn/networks.py @@ -36,7 +36,7 @@ def make_default_networks( environment_spec: mava_specs.MAEnvironmentSpec, agent_net_keys: Dict[str, str], - net_spec_keys: Dict[str,str] = {}, + net_spec_keys: Dict[str, str] = {}, value_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = None, architecture_type: ArchitectureType = ArchitectureType.feedforward, atari_torso_observation_network: bool = False, @@ -77,7 +77,7 @@ def make_default_networks( assert value_network_func is not None specs = environment_spec.get_agent_specs() - + # Create agent_type specs if not net_spec_keys: specs = {agent_net_keys[key]: specs[key] for key in specs.keys()} diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 81519c8e0..1819f91f4 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -263,7 +263,7 @@ def __init__( # noqa _, self._agent_net_keys = sample_new_agent_keys( agents, self._network_sampling_setup, # type: ignore - fix_sampler=fix_sampler + fix_sampler=fix_sampler, ) # Check that the environment and agent_net_keys has the same amount of agents sample_length = len(self._network_sampling_setup[0]) # type: ignore From 7d411dcc06b93d76a373b33489663725723b30b9 Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 13 Apr 2022 06:00:33 +0200 Subject: [PATCH 064/104] fix: PPO training for networks with Categorical heads. 
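For reference, the epoch and minibatch handling in the trainer change above reduces to a standard tf.data pattern. A stripped-down sketch, assuming a dict of tensors with a leading batch dimension stands in for the replay sample:

import tensorflow as tf

batch_size = 512
minibatch_size = 128
num_epochs = 5

# Stand-in for the sampled replay data.
batch = {
    "observations": tf.random.normal([batch_size, 10]),
    "actions": tf.random.normal([batch_size, 2]),
}

dataset = tf.data.Dataset.from_tensor_slices(batch)
for _ in range(num_epochs):
    # Re-shuffle each epoch so every epoch sees different minibatches.
    for minibatch in dataset.shuffle(batch_size).batch(minibatch_size):
        observations = minibatch["observations"]  # [minibatch_size, 10]
        actions = minibatch["actions"]            # [minibatch_size, 2]
        # loss = self._minibatch_update(minibatch) in the actual trainer.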
--- mava/systems/tf/mappo/training.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 989fc7864..6a4e43cbb 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -22,6 +22,7 @@ import numpy as np import sonnet as snt import tensorflow as tf +import tensorflow_probability as tfp import tree from acme.tf import utils as tf2_utils from acme.utils import counting, loggers @@ -32,6 +33,7 @@ from mava.utils import training_utils as train_utils from mava.utils.sort_utils import sort_str_num +tfd = tfp.distributions train_utils.set_growing_gpu_memory() @@ -389,7 +391,15 @@ def _forward_pass(self, inputs: Any) -> None: ) policy = policy_network(actor_observation) - policy.batch_reshape(dims, name="policy") + if isinstance(policy, tfp.distributions.Distribution): + # Tensorflow probability. + policy = tfd.BatchReshape( + policy, batch_shape=dims, name="policy" + ) + else: + # Custom distribution function. + policy.batch_reshape(dims, name="policy") + action_prob = policy.log_prob(action) policy_entropy = policy.entropy() From 25f3ea1015e624dba55474bd8b3ca674b9cc2c45 Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 13 Apr 2022 06:22:54 +0200 Subject: [PATCH 065/104] fix: Small fix to dataset shuffler. --- mava/systems/tf/mappo/training.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 6a4e43cbb..591cdb6ae 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -281,12 +281,13 @@ def _step( } # Get data from replay. inputs = next(self._iterator) + batch_size = inputs.data.observations[self._agents[0]].observation.shape[0] + dataset = tf.data.Dataset.from_tensor_slices(inputs.data) for _ in range(self._num_epochs): # Split for possible minibatches - batch_size = inputs.data.observations[self._agents[0]].observation.shape[0] - dataset = tf.data.Dataset.from_tensor_slices(inputs.data) - dataset = dataset.shuffle(batch_size).batch(self._minibatch_size) - for minibatch_data in dataset: + dataset = dataset.shuffle(batch_size) + minibatch_dataset = dataset.batch(self._minibatch_size) + for minibatch_data in minibatch_dataset: loss = self._minibatch_update(minibatch_data) # Logging sum of losses @@ -354,6 +355,8 @@ def _forward_pass(self, inputs: Any) -> None: observations_trans[agent], ) + print("termination: ", termination.shape) + loss_mask = tf.concat( (tf.ones((1, termination.shape[1])), termination[:-1]), 0 ) From dbb5797799766f083d08e920bbf4ce22ae24b61d Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 13 Apr 2022 06:31:53 +0200 Subject: [PATCH 066/104] fix: Remove print statement. --- mava/systems/tf/mappo/training.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 591cdb6ae..998f27e4b 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -355,8 +355,6 @@ def _forward_pass(self, inputs: Any) -> None: observations_trans[agent], ) - print("termination: ", termination.shape) - loss_mask = tf.concat( (tf.ones((1, termination.shape[1])), termination[:-1]), 0 ) From a2578194f206ebfe37144bad3d77130e1349f94e Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 13 Apr 2022 08:08:44 +0200 Subject: [PATCH 067/104] Small fixes to trainer variable client and Hyperparameter settings. 
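The Categorical-head fix above is needed because tensorflow_probability distributions cannot be reshaped in place the way the custom distribution heads can, so they are wrapped with tfd.BatchReshape instead. A small illustration with assumed [sequence, batch] dimensions:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

T, B, num_actions = 20, 32, 5

# Policies are computed on a flattened [T * B] batch and then reshaped back.
logits = tf.random.normal([T * B, num_actions])
policy = tfd.Categorical(logits=logits)
policy = tfd.BatchReshape(policy, batch_shape=(T, B), name="policy")

actions = tf.random.uniform([T, B], maxval=num_actions, dtype=tf.int32)
log_probs = policy.log_prob(actions)  # shape [T, B]
entropy = policy.entropy()            # shape [T, B]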
--- .../feedforward/decentralised/run_mappo.py | 1 - mava/systems/tf/mappo/builder.py | 4 ++-- mava/systems/tf/mappo/system.py | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/examples/tf/debugging/simple_spread/feedforward/decentralised/run_mappo.py b/examples/tf/debugging/simple_spread/feedforward/decentralised/run_mappo.py index 60225a278..eb475f72d 100644 --- a/examples/tf/debugging/simple_spread/feedforward/decentralised/run_mappo.py +++ b/examples/tf/debugging/simple_spread/feedforward/decentralised/run_mappo.py @@ -85,7 +85,6 @@ def main(_: Any) -> None: logger_factory=logger_factory, num_executors=1, checkpoint_subpath=checkpoint_dir, - num_epochs=15, ).build() # Ensure only trainer runs on gpu, while other processes run on cpu. diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 27624fca7..7f8a29990 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -99,7 +99,7 @@ class MAPPOConfig: lambda_gae: float = 0.95 max_queue_size: Optional[int] = 1000 executor_variable_update_period: int = 100 - batch_size: int = 32 + batch_size: int = 512 minibatch_size: Optional[int] = None num_epochs: int = 10 entropy_cost: float = 0.01 @@ -496,7 +496,7 @@ def make_trainer( variables=variables, get_keys=get_keys, set_keys=set_keys, - update_period=10, + update_period=1, ) # Get all the initial variables diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index b321c6bc8..0c6aeb9b3 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -70,7 +70,7 @@ def __init__( # noqa snt.Optimizer, Dict[str, snt.Optimizer] ] = snt.optimizers.Adam(learning_rate=5e-4), critic_optimizer: Optional[snt.Optimizer] = snt.optimizers.Adam( - learning_rate=5e-4 + learning_rate=1e-3 ), discount: float = 0.99, lambda_gae: float = 0.95, @@ -80,10 +80,10 @@ def __init__( # noqa max_gradient_norm: Optional[float] = 0.5, max_queue_size: Optional[int] = None, batch_size: int = 512, - minibatch_size: int = None, - num_epochs: int = 10, + minibatch_size: int = 128, + num_epochs: int = 5, sequence_length: int = 20, - sequence_period: Optional[int] = None, + sequence_period: Optional[int] = 10, max_executor_steps: int = None, checkpoint: bool = True, checkpoint_subpath: str = "~/mava/", @@ -191,6 +191,14 @@ def __init__( # noqa normalize_advantage: whether to normalize the advantage estimate. This can hurt peformance when shared weights are used. """ + + if num_epochs > 5: + raise ValueError( + "Epoch must be smaller than 5. For epoch values large than 5" + + " the client update should be moved to inside the epoch loop" + + "instead of inside the step function." + ) + # minibatch size defaults to train batch size if minibatch_size: self._minibatch_size = minibatch_size From 8e25670fb95a8bea7f75f02a257ea8ddf21ffe31 Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Wed, 13 Apr 2022 10:26:08 +0200 Subject: [PATCH 068/104] feat: added multiple network fix --- mava/systems/tf/mappo/builder.py | 56 +++++++++++++++++++++++------- mava/systems/tf/mappo/execution.py | 18 +++++++++- mava/systems/tf/mappo/system.py | 30 ++++++++++++++-- mava/utils/sort_utils.py | 11 ++++-- 4 files changed, 96 insertions(+), 19 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 7f8a29990..3e6e2b78b 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -48,6 +48,16 @@ class MAPPOConfig: used if using single optim. 
agent_net_keys: (dict, optional): specifies what network each agent uses. Defaults to {}. + trainer_networks: networks each trainer trains on. + table_network_config: Networks each table (trainer) expects. + network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. + net_spec_keys: Optional network to agent mapping used to get the environment + specs for each network. + net_keys_to_ids: mapping from net_key to network id. + unique_net_keys: list of unique net_keys. checkpoint_minute_interval (int): The number of minutes to wait between checkpoints. sequence_length: recurrent sequence rollout length. @@ -90,6 +100,8 @@ class MAPPOConfig: trainer_networks: Dict[str, List] table_network_config: Dict[str, List] network_sampling_setup: List + fix_sampler: Optional[List] + net_spec_keys: Dict[str, str] net_keys_to_ids: Dict[str, int] unique_net_keys: List[str] checkpoint_minute_interval: int @@ -171,19 +183,34 @@ def add_log_prob_to_spec( ) return env_adder_spec - def covert_specs(self, spec: Dict[str, Any], num_networks: int) -> Dict[str, Any]: + def get_nets_specific_specs( + self, spec: Dict[str, Any], network_names: List + ) -> Dict[str, Any]: + """Convert specs. + Args: + spec: agent specs + network_names: names of the networks in the desired reverb table (to get specs for) + Returns: + distilled version of agent specs containing only specs related to `networks_names` + """ if type(spec) is not dict: return spec - agents = sort_str_num(self._config.agent_net_keys.keys())[:num_networks] + agents = [] + for network in network_names: + agents.append(self._config.net_spec_keys[network]) + + agents = sort_str_num(agents) converted_spec: Dict[str, Any] = {} if agents[0] in spec.keys(): - for agent in agents: - converted_spec[agent] = spec[agent] + for i, agent in enumerate(agents): + converted_spec[self._agents[i]] = spec[agent] else: # For the extras for key in spec.keys(): - converted_spec[key] = self.covert_specs(spec[key], num_networks) + converted_spec[key] = self.get_nets_specific_specs( + spec[key], network_names + ) return converted_spec def make_replay_tables( @@ -203,24 +230,26 @@ def make_replay_tables( # Create system architecture with target networks. adder_env_spec = self.add_log_prob_to_spec(environment_spec) - # Create table per trainer replay_tables = [] for table_key in self._config.table_network_config.keys(): - # TODO (dries): Clean the below coverter code up. + # TODO (dries): Clean the below converter code up. 
# Convert a Mava spec - num_networks = len(self._config.table_network_config[table_key]) + tables_network_names = self._config.table_network_config[table_key] env_spec = copy.deepcopy(adder_env_spec) - env_spec._specs = self.covert_specs(env_spec._specs, num_networks) + env_spec._specs = self.get_nets_specific_specs( + env_spec._specs, tables_network_names + ) env_spec._keys = list(sort_str_num(env_spec._specs.keys())) if env_spec.extra_specs is not None: - env_spec.extra_specs = self.covert_specs( - env_spec.extra_specs, num_networks + env_spec.extra_specs = self.get_nets_specific_specs( + env_spec.extra_specs, tables_network_names ) - extra_specs = self.covert_specs( + extra_specs = self.get_nets_specific_specs( self._extra_specs, - num_networks, + tables_network_names, ) + signature = reverb_adders.ParallelSequenceAdder.signature( env_spec, sequence_length=self._config.sequence_length, @@ -429,6 +458,7 @@ def make_executor( agent_specs=self._config.environment_spec.get_agent_specs(), agent_net_keys=self._config.agent_net_keys, network_sampling_setup=self._config.network_sampling_setup, + fix_sampler=self._config.fix_sampler, variable_client=variable_client, adder=adder, evaluator=evaluator, diff --git a/mava/systems/tf/mappo/execution.py b/mava/systems/tf/mappo/execution.py index 2d85339cf..65ad556fe 100644 --- a/mava/systems/tf/mappo/execution.py +++ b/mava/systems/tf/mappo/execution.py @@ -48,6 +48,7 @@ def __init__( agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], network_sampling_setup: List, + fix_sampler: Optional[List], net_keys_to_ids: Dict[str, int], adder: Optional[adders.ReverbParallelAdder] = None, counts: Optional[Dict[str, Any]] = None, @@ -62,12 +63,18 @@ def __init__( adder: adder which sends data to a replay buffer. variable_client: client to copy weights from the trainer. agent_net_keys: specifies what network each agent uses. + network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. + net_keys_to_ids: Specifies a mapping from network keys to their integer id. evaluator: whether the executor will be used for evaluation. Defaults to False. interval: interval that evaluations are run at. """ self._agent_specs = agent_specs self._network_sampling_setup = network_sampling_setup + self._fix_sampler = fix_sampler self._counts = counts self._network_int_keys_extras: Dict[str, Any] = {} self._net_keys_to_ids = net_keys_to_ids @@ -192,6 +199,7 @@ def observe_first( agents, self._network_sampling_setup, self._net_keys_to_ids, + self._fix_sampler, ) extras["network_int_keys"] = self._network_int_keys_extras self._adder.add_first(timestep, extras) @@ -242,6 +250,7 @@ def __init__( agent_specs: Dict[str, EnvironmentSpec], agent_net_keys: Dict[str, str], network_sampling_setup: List, + fix_sampler: Optional[List], net_keys_to_ids: Dict[str, int], adder: Optional[adders.ReverbParallelAdder] = None, counts: Optional[Dict[str, Any]] = None, @@ -253,17 +262,23 @@ def __init__( Args: policy_networks: policy networks for each agent in the system. + agent_net_keys: specifies what network each agent uses. + network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. + net_keys_to_ids: Specifies a mapping from network keys to their integer id. 
adder: adder which sends data to a replay buffer. Defaults to None. variable_client (Optional[tf2_variable_utils.VariableClient], optional): client to copy weights from the trainer. Defaults to None. - agent_net_keys: specifies what network each agent uses. evaluator (bool, optional): whether the executor will be used for evaluation. Defaults to False. interval: interval that evaluations are run at. """ self._agent_specs = agent_specs self._network_sampling_setup = network_sampling_setup + self._fix_sampler = fix_sampler self._counts = counts self._net_keys_to_ids = net_keys_to_ids self._network_int_keys_extras: Dict[str, Any] = {} @@ -397,6 +412,7 @@ def observe_first( agents, self._network_sampling_setup, self._net_keys_to_ids, + self._fix_sampler, ) numpy_states = { diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 0c6aeb9b3..5b10b6194 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -63,6 +63,8 @@ def __init__( # noqa network_sampling_setup: Union[ List, enums.NetworkSampler ] = enums.NetworkSampler.fixed_agent_networks, + fix_sampler: Optional[List] = None, + net_spec_keys: Dict = {}, environment_spec: mava_specs.MAEnvironmentSpec = None, shared_weights: bool = True, executor_variable_update_period: int = 100, @@ -116,6 +118,24 @@ def __init__( # noqa or recurrent. Defaults to execution.MAPPOFeedForwardExecutor. num_executors : number of executor processes to run in parallel. Defaults to 1. + trainer_networks: networks each trainer trains on. + network_sampling_setup: List of networks that are randomly + sampled from by the executors at the start of an environment run. + enums.NetworkSampler settings: + fixed_agent_networks: Keeps the networks + used by each agent fixed throughout training. + random_agent_networks: Creates N network policies, where N is the + number of agents. Randomly select policies from this sets for each + agent at the start of a episode. This sampling is done with + replacement so the same policy can be selected for more than one + agent for a given episode. + Custom list: Alternatively one can specify a custom nested list, + with network keys in, that will be used by the executors at + the start of each episode to sample networks for each agent. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. + net_spec_keys: Optional network to agent mapping used to get the environment + specs for each network. num_caches : number of trainer node caches. Defaults to 0. environment_spec : description of the action, observation spaces etc. for each agent in the system. 
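For readers new to the sampling arguments documented above, a minimal illustration with hypothetical network and agent keys (none of these values come from the diff itself):

network_sampling_setup = [
    ["network_agent_0", "network_agent_1"],  # each agent keeps its own network
    ["network_agent_0", "network_agent_0"],  # both agents run network_agent_0
]

# fix_sampler pins the executor to fixed indices into network_sampling_setup
# instead of drawing them at random, which makes runs reproducible.
fix_sampler = [0]

# net_spec_keys maps each network key to the agent whose environment spec is
# used when that network is built.
net_spec_keys = {"network_agent_0": "agent_0", "network_agent_1": "agent_1"}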
@@ -294,6 +314,7 @@ def __init__( # noqa _, self._agent_net_keys = sample_new_agent_keys( agents, self._network_sampling_setup, # type: ignore + fix_sampler=fix_sampler, ) # Check that the environment and agent_net_keys has the same amount of agents @@ -340,9 +361,10 @@ def __init__( # noqa # Check that all agent_net_keys are in trainer_networks assert unique_net_keys == unique_trainer_net_keys # Setup specs for each network - self._net_spec_keys = {} - for i in range(len(unique_net_keys)): - self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] + self._net_spec_keys = net_spec_keys + if not net_spec_keys: + for i in range(len(unique_net_keys)): + self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] # Setup table_network_config table_network_config = {} @@ -407,6 +429,8 @@ def __init__( # noqa trainer_networks=self._trainer_networks, table_network_config=table_network_config, network_sampling_setup=self._network_sampling_setup, # type: ignore + fix_sampler=fix_sampler, + net_spec_keys=self._net_spec_keys, net_keys_to_ids=net_keys_to_ids, unique_net_keys=unique_net_keys, termination_condition=termination_condition, diff --git a/mava/utils/sort_utils.py b/mava/utils/sort_utils.py index 38f4fd55a..2fbe24d2e 100644 --- a/mava/utils/sort_utils.py +++ b/mava/utils/sort_utils.py @@ -1,6 +1,6 @@ import copy import re -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np from numpy.random import randint @@ -27,6 +27,7 @@ def sample_new_agent_keys( agents: List, network_sampling_setup: List, net_keys_to_ids: Dict[str, int] = None, + fix_sampler: Optional[List] = None, ) -> Tuple[Dict[str, np.ndarray], Dict[str, str]]: """ Samples new agent networks using the network sampling setup. @@ -42,8 +43,13 @@ def sample_new_agent_keys( save_net_keys = {} agent_net_keys = {} agent_slots = copy.copy(agents) + loop_i = 0 while len(agent_slots) > 0: - sample = network_sampling_setup[randint(len(network_sampling_setup))] + if not fix_sampler: + sample_i = randint(len(network_sampling_setup)) + else: + sample_i = fix_sampler[loop_i] + sample = network_sampling_setup[sample_i] for net_key in sample: agent = agent_slots.pop(0) agent_net_keys[agent] = net_key @@ -51,5 +57,6 @@ def sample_new_agent_keys( save_net_keys[agent] = np.array( net_keys_to_ids[net_key], dtype=np.int32 ) + loop_i += 1 return save_net_keys, agent_net_keys From a88de00e1e0a386fbf2b8949b10d65ed04ae0ad1 Mon Sep 17 00:00:00 2001 From: Dries Date: Wed, 13 Apr 2022 10:49:25 +0200 Subject: [PATCH 069/104] feat: Small updates to hyperparameters. Moving system closer to develop's setup. 
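As a reading aid, the sampler changed above can be summarised by the following standalone sketch; it is a simplification (plain keys, no net_keys_to_ids handling) and not the library function itself:

import copy

from numpy.random import randint


def sample_keys(agents, network_sampling_setup, fix_sampler=None):
    """Assign a network key to every agent slot, one sampled entry at a time."""
    agent_net_keys = {}
    agent_slots = copy.copy(agents)
    loop_i = 0
    while agent_slots:
        if fix_sampler is None:
            sample_i = randint(len(network_sampling_setup))  # random entry
        else:
            sample_i = fix_sampler[loop_i]  # deterministic entry for this step
        for net_key in network_sampling_setup[sample_i]:
            agent_net_keys[agent_slots.pop(0)] = net_key
        loop_i += 1
    return agent_net_keys


# With fix_sampler=[1], both agents always end up on "net_0":
# sample_keys(["agent_0", "agent_1"], [["net_0", "net_1"], ["net_0", "net_0"]], [1])
# -> {"agent_0": "net_0", "agent_1": "net_0"}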
--- mava/systems/tf/mappo/builder.py | 4 ++-- mava/systems/tf/mappo/system.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 7f8a29990..77488e2e8 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -93,8 +93,8 @@ class MAPPOConfig: net_keys_to_ids: Dict[str, int] unique_net_keys: List[str] checkpoint_minute_interval: int - sequence_length: int = 10 - sequence_period: int = 9 + sequence_length: int = 21 + sequence_period: int = 20 discount: float = 0.99 lambda_gae: float = 0.95 max_queue_size: Optional[int] = 1000 diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 0c6aeb9b3..0bae0442a 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -70,7 +70,7 @@ def __init__( # noqa snt.Optimizer, Dict[str, snt.Optimizer] ] = snt.optimizers.Adam(learning_rate=5e-4), critic_optimizer: Optional[snt.Optimizer] = snt.optimizers.Adam( - learning_rate=1e-3 + learning_rate=5e-4 ), discount: float = 0.99, lambda_gae: float = 0.95, @@ -80,7 +80,7 @@ def __init__( # noqa max_gradient_norm: Optional[float] = 0.5, max_queue_size: Optional[int] = None, batch_size: int = 512, - minibatch_size: int = 128, + minibatch_size: int = None, num_epochs: int = 5, sequence_length: int = 20, sequence_period: Optional[int] = 10, From 10185f6e5b4ca5bb17e8ec42dc9078e2f906754f Mon Sep 17 00:00:00 2001 From: Edan Toledo Date: Wed, 13 Apr 2022 14:48:07 +0200 Subject: [PATCH 070/104] fix: add if statement to check if default net spec keys is given and refactor --- mava/systems/tf/madqn/builder.py | 12 ++++++------ mava/systems/tf/madqn/system.py | 9 ++++++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/mava/systems/tf/madqn/builder.py b/mava/systems/tf/madqn/builder.py index 8623c1872..29401707c 100644 --- a/mava/systems/tf/madqn/builder.py +++ b/mava/systems/tf/madqn/builder.py @@ -171,14 +171,14 @@ def __init__( self._trainer_fn = trainer_fn self._executor_fn = executor_fn - def covert_specs( + def convert_specs( self, spec: Dict[str, Any], list_of_networks: List ) -> Dict[str, Any]: """Convert specs. Args: spec: agent specs - num_networks: the number of networks + list_of_networks: the list of networks each trainer uses. 
Returns: converted specs @@ -198,7 +198,7 @@ def covert_specs( else: # For the extras for key in spec.keys(): - converted_spec[key] = self.covert_specs(spec[key], list_of_networks) + converted_spec[key] = self.convert_specs(spec[key], list_of_networks) return converted_spec def make_replay_tables( @@ -264,14 +264,14 @@ def limiter_fn() -> reverb.rate_limiters: # Convert a Mava spec list_of_networks = self._config.table_network_config[table_key] env_spec = copy.deepcopy(environment_spec) - env_spec._specs = self.covert_specs(env_spec._specs, list_of_networks) + env_spec._specs = self.convert_specs(env_spec._specs, list_of_networks) env_spec._keys = list(sort_str_num(env_spec._specs.keys())) if env_spec.extra_specs is not None: - env_spec.extra_specs = self.covert_specs( + env_spec.extra_specs = self.convert_specs( env_spec.extra_specs, list_of_networks ) - extra_specs = self.covert_specs( + extra_specs = self.convert_specs( self._extra_specs, list_of_networks, ) diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 1819f91f4..5a32736f7 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -309,9 +309,12 @@ def __init__( # noqa # Check that all agent_net_keys are in trainer_networks assert unique_net_keys == unique_trainer_net_keys # Setup specs for each network - self._net_spec_keys = net_spec_keys - for i in range(len(unique_net_keys)): - self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] + if net_spec_keys: + self._net_spec_keys = net_spec_keys + else: + self._net_spec_keys={} + for i in range(len(unique_net_keys)): + self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] # Setup table_network_config table_network_config = {} From 80cf51de6362ac32b0460723e33500f46b44ec54 Mon Sep 17 00:00:00 2001 From: Edan Toledo Date: Wed, 13 Apr 2022 14:48:47 +0200 Subject: [PATCH 071/104] fix: linting issue --- mava/systems/tf/madqn/system.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index 5a32736f7..c9eac11f2 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -312,7 +312,7 @@ def __init__( # noqa if net_spec_keys: self._net_spec_keys = net_spec_keys else: - self._net_spec_keys={} + self._net_spec_keys = {} for i in range(len(unique_net_keys)): self._net_spec_keys[unique_net_keys[i]] = agents[i % len(agents)] From 2851235a413523646f5ec1d82f3de74ec0de1d96 Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 14 Apr 2022 09:20:11 +0200 Subject: [PATCH 072/104] fix: Small training fixes. 
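The default net_spec_keys construction above is a simple round-robin mapping; a small sketch with hypothetical keys:

unique_net_keys = ["network_0", "network_1", "network_2"]
agents = ["agent_0", "agent_1"]

# Each unique network key is assigned an agent whose environment spec will be
# used to build that network, wrapping around when there are more networks
# than agents.
net_spec_keys = {
    unique_net_keys[i]: agents[i % len(agents)] for i in range(len(unique_net_keys))
}
# -> {"network_0": "agent_0", "network_1": "agent_1", "network_2": "agent_0"}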
--- mava/systems/tf/mappo/builder.py | 1 - mava/systems/tf/mappo/system.py | 8 -------- mava/systems/tf/mappo/training.py | 5 ++++- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 22f926ab7..faa0297fc 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -330,7 +330,6 @@ def make_adder( table_network_config=self._config.table_network_config, period=self._config.sequence_period, sequence_length=self._config.sequence_length, - use_next_extras=False, # end_of_episode_behavior=EndBehavior.CONTINUE, ) # Note (dries): Using end_of_episode_behavior=EndBehavior.CONTINUE can diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 27968a649..7a93b069e 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -211,14 +211,6 @@ def __init__( # noqa normalize_advantage: whether to normalize the advantage estimate. This can hurt peformance when shared weights are used. """ - - if num_epochs > 5: - raise ValueError( - "Epoch must be smaller than 5. For epoch values large than 5" - + " the client update should be moved to inside the epoch loop" - + "instead of inside the step function." - ) - # minibatch size defaults to train batch size if minibatch_size: self._minibatch_size = minibatch_size diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 998f27e4b..e804018d2 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -284,6 +284,9 @@ def _step( batch_size = inputs.data.observations[self._agents[0]].observation.shape[0] dataset = tf.data.Dataset.from_tensor_slices(inputs.data) for _ in range(self._num_epochs): + # Update the variable source and the trainer + self._variable_client.set_and_get_async() + # Split for possible minibatches dataset = dataset.shuffle(batch_size) minibatch_dataset = dataset.batch(self._minibatch_size) @@ -633,7 +636,7 @@ def __init__( agent_net_keys: Dict[str, str], minibatch_size: Optional[int] = None, num_epochs: int = 10, - discount: float = 0.999, + discount: float = 0.99, lambda_gae: float = 0.95, entropy_cost: float = 0.01, baseline_cost: float = 0.5, From 167415438a28b169463d263c4b5e7da534acfe98 Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 14 Apr 2022 10:36:54 +0200 Subject: [PATCH 073/104] fix: Big bugfix in MAPPO trainer code setup. 
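The epoch/minibatch loop in _step above follows the usual PPO recipe; a standalone sketch with dummy data, not the trainer code itself (shapes are illustrative only):

import tensorflow as tf

batch_size, minibatch_size, num_epochs = 512, 128, 5
batch = tf.random.normal((batch_size, 8))  # stand-in for the sampled trajectories

dataset = tf.data.Dataset.from_tensor_slices(batch)
for _ in range(num_epochs):
    # Reshuffle every epoch, then cut the full batch into minibatches and take
    # one optimisation step per minibatch.
    for minibatch in dataset.shuffle(batch_size).batch(minibatch_size):
        pass  # gradient step would go here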
--- mava/systems/tf/mappo/builder.py | 7 ++++--- mava/systems/tf/mappo/system.py | 4 ++-- mava/systems/tf/mappo/training.py | 6 +++++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index faa0297fc..935d2418e 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -105,8 +105,8 @@ class MAPPOConfig: net_keys_to_ids: Dict[str, int] unique_net_keys: List[str] checkpoint_minute_interval: int - sequence_length: int = 21 - sequence_period: int = 20 + sequence_length: int = 10 + sequence_period: int = 9 discount: float = 0.99 lambda_gae: float = 0.95 max_queue_size: Optional[int] = 1000 @@ -555,7 +555,8 @@ def make_trainer( "dataset": dataset, "counts": counts, "logger": logger, - "normalize_advantage": self._config.lambda_gae, + "lambda_gae": self._config.lambda_gae, + "normalize_advantage": self._config.normalize_advantage, "entropy_cost": self._config.entropy_cost, "baseline_cost": self._config.baseline_cost, "clipping_epsilon": self._config.clipping_epsilon, diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 7a93b069e..89ecaa6be 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -84,8 +84,8 @@ def __init__( # noqa batch_size: int = 512, minibatch_size: int = None, num_epochs: int = 5, - sequence_length: int = 20, - sequence_period: Optional[int] = 10, + sequence_length: int = 10, + sequence_period: Optional[int] = None, max_executor_steps: int = None, checkpoint: bool = True, checkpoint_subpath: str = "~/mava/", diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index e804018d2..98e1f0b15 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -437,8 +437,12 @@ def _forward_pass(self, inputs: Any) -> None: advantages = train_utils._normalize_advantages( advantages, variance_epsilon=1e-8 ) + raise NotImplementedError( + "Confirm that this is working." + + "It gave zeros when we tried it out." + ) - advantages = tf.stop_gradient(advantages) + advantages = tf.stop_gradient(advantages) # td_lambda_returns returns = advantages + value_pred From b8d27380b7025eff136a6e8a3e003b6844980b39 Mon Sep 17 00:00:00 2001 From: Dries Date: Tue, 19 Apr 2022 09:46:41 +0200 Subject: [PATCH 074/104] Remove variable update inside mappo trainer _step code. --- mava/systems/tf/mappo/training.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 98e1f0b15..c5404c623 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -284,9 +284,6 @@ def _step( batch_size = inputs.data.observations[self._agents[0]].observation.shape[0] dataset = tf.data.Dataset.from_tensor_slices(inputs.data) for _ in range(self._num_epochs): - # Update the variable source and the trainer - self._variable_client.set_and_get_async() - # Split for possible minibatches dataset = dataset.shuffle(batch_size) minibatch_dataset = dataset.batch(self._minibatch_size) From 86858c2ba885865979c110f64720008542f3661a Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Wed, 13 Apr 2022 18:44:02 +0200 Subject: [PATCH 075/104] feat: Added jax docker containers. 
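For context on the normalisation path flagged above, the textbook form of advantage normalisation is shown below; this is only an assumption about what train_utils._normalize_advantages computes, not a copy of it:

import tensorflow as tf


def normalize_advantages(
    advantages: tf.Tensor, variance_epsilon: float = 1e-8
) -> tf.Tensor:
    """Zero-mean, unit-variance advantages over the batch dimension."""
    mean, variance = tf.nn.moments(advantages, axes=[0])
    return (advantages - mean) / tf.sqrt(variance + variance_epsilon)


# If every advantage in a batch is (nearly) identical, the numerator is close
# to zero for all entries, which is one way the "gave zeros" behaviour noted in
# the diff can show up.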
--- Dockerfile | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ Makefile | 46 +++++++++++++++++++++++++++------------ setup.py | 11 ++++++++++ 3 files changed, 107 insertions(+), 14 deletions(-) diff --git a/Dockerfile b/Dockerfile index 3476e89f2..a0f236897 100755 --- a/Dockerfile +++ b/Dockerfile @@ -95,3 +95,67 @@ RUN ./bash_scripts/install_meltingpot.sh # Add meltingpot to python path ENV PYTHONPATH "${PYTHONPATH}:${folder}/../packages/meltingpot" ########################################################## + +# New Jax Images +########################################################## +# Core Mava-Jax image +FROM mava-core as jax-core +# Jax gpu config. +ENV XLA_PYTHON_CLIENT_PREALLOCATE=false +## Install core jax dependencies. +# Install jax gpu +RUN pip install -e .[jax] +RUN pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html +########################################################## + +########################################################## +# PZ image +FROM tf-core AS pz-jax +RUN pip install -e .[pz] +# PettingZoo Atari envs +RUN apt-get update +RUN apt-get install ffmpeg libsm6 libxext6 -y +RUN apt-get install -y unrar-free +RUN pip install autorom +RUN AutoROM -v +########################################################## + +########################################################## +# SMAC image +FROM tf-core AS sc2-jax +## Install smac environment +RUN apt-get -y install git +RUN pip install .[sc2] +# We use the pz wrapper for smac +RUN pip install .[pz] +ENV SC2PATH /home/app/mava/3rdparty/StarCraftII +########################################################## + +########################################################## +# Flatland Image +FROM tf-core AS flatland-jax +RUN pip install -e .[flatland] +########################################################## + +######################################################### +## Robocup Image +FROM tf-core AS robocup-jax +RUN apt-get install sudo -y +RUN ./bash_scripts/install_robocup.sh +########################################################## + +########################################################## +## OpenSpiel Image +FROM tf-core AS openspiel-jax +RUN pip install .[open_spiel] +########################################################## + +########################################################## +# MeltingPot Image +FROM tf-core AS meltingpot-jax +# Install meltingpot +RUN apt-get install -y git +RUN ./bash_scripts/install_meltingpot.sh +# Add meltingpot to python path +ENV PYTHONPATH "${PYTHONPATH}:${folder}/../packages/meltingpot" +########################################################## diff --git a/Makefile b/Makefile index 3720765fe..f37e75ec9 100755 --- a/Makefile +++ b/Makefile @@ -26,20 +26,38 @@ DOCKER_RUN=docker run $(RUN_FLAGS) $(IMAGE) DOCKER_RUN_TENSORBOARD=docker run $(RUN_FLAGS_TENSORBOARD) $(IMAGE) # Choose correct image for example -ifneq (,$(findstring debugging,$(example))) -DOCKER_IMAGE_TAG=tf-core -else ifneq (,$(findstring petting,$(example))) -DOCKER_IMAGE_TAG=pz -else ifneq (,$(findstring flatland,$(example))) -DOCKER_IMAGE_TAG=flatland -else ifneq (,$(findstring openspiel,$(example))) -DOCKER_IMAGE_TAG=openspiel -else ifneq (,$(findstring robocup,$(example))) -DOCKER_IMAGE_TAG=robocup -else ifneq (,$(findstring smac,$(example))) -DOCKER_IMAGE_TAG=sc2 -else ifneq (,$(findstring meltingpot,$(example))) -DOCKER_IMAGE_TAG=meltingpot +ifneq (,$(findstring jax,$(example))) + ifneq (,$(findstring debugging,$(example))) + 
DOCKER_IMAGE_TAG=jax-core + else ifneq (,$(findstring petting,$(example))) + DOCKER_IMAGE_TAG=pz-jax + else ifneq (,$(findstring flatland,$(example))) + DOCKER_IMAGE_TAG=flatland-jax + else ifneq (,$(findstring openspiel,$(example))) + DOCKER_IMAGE_TAG=openspiel-jax + else ifneq (,$(findstring robocup,$(example))) + DOCKER_IMAGE_TAG=robocup-jax + else ifneq (,$(findstring smac,$(example))) + DOCKER_IMAGE_TAG=sc2-jax + else ifneq (,$(findstring meltingpot,$(example))) + DOCKER_IMAGE_TAG=meltingpot-jax + endif +else + ifneq (,$(findstring debugging,$(example))) + DOCKER_IMAGE_TAG=tf-core + else ifneq (,$(findstring petting,$(example))) + DOCKER_IMAGE_TAG=pz + else ifneq (,$(findstring flatland,$(example))) + DOCKER_IMAGE_TAG=flatland + else ifneq (,$(findstring openspiel,$(example))) + DOCKER_IMAGE_TAG=openspiel + else ifneq (,$(findstring robocup,$(example))) + DOCKER_IMAGE_TAG=robocup + else ifneq (,$(findstring smac,$(example))) + DOCKER_IMAGE_TAG=sc2 + else ifneq (,$(findstring meltingpot,$(example))) + DOCKER_IMAGE_TAG=meltingpot + endif endif # make file commands diff --git a/setup.py b/setup.py index d96824bba..f12126f22 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,16 @@ "trfl", ] +jax_requirements = [ + "chex", + "jax", + "jaxlib", + "dm-haiku", + "flax", + "optax", + "rlax", +] + tf_requirements + pettingzoo_requirements = [ "pettingzoo~=1.13.1", "multi_agent_ale_py", @@ -117,6 +127,7 @@ "record_episode": record_episode_requirements, "sc2": smac_requirements, "envs": pettingzoo_requirements + open_spiel_requirements + smac_requirements, + "jax": jax_requirements, }, classifiers=[ "Development Status :: 3 - Alpha", From 8da85626332138ff694a2188042d784be59bce40 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Wed, 13 Apr 2022 18:49:25 +0200 Subject: [PATCH 076/104] feat: Added jax containers to autopush. --- .github/workflows/push_docker.yaml | 128 +++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/.github/workflows/push_docker.yaml b/.github/workflows/push_docker.yaml index 6365e8916..b066b52a5 100644 --- a/.github/workflows/push_docker.yaml +++ b/.github/workflows/push_docker.yaml @@ -98,6 +98,71 @@ jobs: tags: ${{ steps.vars2.outputs.docker_repo }}:meltingpot-latest target: meltingpot + + - + name: Build and push jax-core + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:jax-core-latest + target: jax-core + + - + name: Build and push pz-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:pz-jax-latest + target: pz-jax + + - + name: Build and push sc2-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:sc2-jax-latest + target: sc2-jax + + + - + name: Build and push flatland-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:flatland-jax-latest + target: flatland-jax + + - + name: Build and push robocup-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:robocup-jax-latest + target: robocup-jax + + - + name: Build and push openspiel-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:openspiel-jax-latest + target: openspiel-jax + + - + name: Build and push meltingpot-jax + uses: docker/build-push-action@v2 + with: + context: . 
+ push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:meltingpot-jax-latest + target: meltingpot-jax + # Action pushes new docker image if we have a PR to dev and have the label benchmark build-push-branches: if: ${{ github.event.label.name == 'benchmark' }} @@ -196,3 +261,66 @@ jobs: push: true tags: ${{ steps.vars2.outputs.docker_repo }}:meltingpot-${{ steps.vars.outputs.sha_short }} target: meltingpot + + - + name: Build and push jax-core + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:jax-core-${{ steps.vars.outputs.sha_short }} + target: jax-core + + - + name: Build and push pz-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:pz-jax-${{ steps.vars.outputs.sha_short }} + target: pz-jax + + - + name: Build and push sc2-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:sc2-jax-${{ steps.vars.outputs.sha_short }} + target: sc2-jax + + - + name: Build and push flatland-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:flatland-jax-${{ steps.vars.outputs.sha_short }} + target: flatland-jax + + - + name: Build and push robocup-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:robocup-jax-${{ steps.vars.outputs.sha_short }} + target: robocup-jax + + - + name: Build and push openspiel-jax + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:openspiel-jax-${{ steps.vars.outputs.sha_short }} + target: openspiel-jax + + - + name: Build and push meltingpot + uses: docker/build-push-action@v2 + with: + context: . + push: true + tags: ${{ steps.vars2.outputs.docker_repo }}:meltingpot-jax-${{ steps.vars.outputs.sha_short }} + target: meltingpot-jax From 2f9d56863b8db2a7bcb6c5f9abda1ee09c65c51d Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Wed, 13 Apr 2022 18:52:29 +0200 Subject: [PATCH 077/104] fix: Correct base docker image. 
--- Dockerfile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index a0f236897..013fb7fd5 100755 --- a/Dockerfile +++ b/Dockerfile @@ -110,7 +110,7 @@ RUN pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-rele ########################################################## # PZ image -FROM tf-core AS pz-jax +FROM jax-core AS pz-jax RUN pip install -e .[pz] # PettingZoo Atari envs RUN apt-get update @@ -122,7 +122,7 @@ RUN AutoROM -v ########################################################## # SMAC image -FROM tf-core AS sc2-jax +FROM jax-core AS sc2-jax ## Install smac environment RUN apt-get -y install git RUN pip install .[sc2] @@ -133,26 +133,26 @@ ENV SC2PATH /home/app/mava/3rdparty/StarCraftII ########################################################## # Flatland Image -FROM tf-core AS flatland-jax +FROM jax-core AS flatland-jax RUN pip install -e .[flatland] ########################################################## ######################################################### ## Robocup Image -FROM tf-core AS robocup-jax +FROM jax-core AS robocup-jax RUN apt-get install sudo -y RUN ./bash_scripts/install_robocup.sh ########################################################## ########################################################## ## OpenSpiel Image -FROM tf-core AS openspiel-jax +FROM jax-core AS openspiel-jax RUN pip install .[open_spiel] ########################################################## ########################################################## # MeltingPot Image -FROM tf-core AS meltingpot-jax +FROM jax-core AS meltingpot-jax # Install meltingpot RUN apt-get install -y git RUN ./bash_scripts/install_meltingpot.sh From 2bfa19ffb16ef17b4dde4bb9c04e6914b6c77594 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Thu, 14 Apr 2022 08:47:23 +0200 Subject: [PATCH 078/104] fix: Updated base image to work with jax. --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 013fb7fd5..6a239a66e 100755 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ ########################################################## # Core Mava image -FROM nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04 as mava-core +FROM nvidia/cuda:11.5.1-cudnn8-devel-ubuntu20.04 as mava-core # Flag to record agents ARG record # Ensure no installs try launch interactive screen From f307597b2a6aab7af6e6553643425b7b12fade61 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Thu, 14 Apr 2022 08:47:49 +0200 Subject: [PATCH 079/104] chore: Updated readme with new jax containers. --- README.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 01b24e1bc..551af056c 100644 --- a/README.md +++ b/README.md @@ -150,11 +150,20 @@ docker run --gpus all -it --rm -v $(pwd):/home/app/mava -w /home/app/mava insta 1.1 Only Mava core: + Tensorflow version: ```bash - make build + make build version=tf-core + ``` + + Jax version: + ```bash + make build version=jax-core ``` 1.2 For **optional** environments: + + **Note for jax images, append `-jax` to the build command, e.g. `make build version=pz-jax`.** + * PettingZoo: ``` From 315574867158755b662f48f7244ca80c618c2a2c Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Thu, 14 Apr 2022 09:04:17 +0200 Subject: [PATCH 080/104] chore: Updated python virtual env readme. 
--- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 551af056c..dbfb65829 100644 --- a/README.md +++ b/README.md @@ -272,6 +272,8 @@ docker run --gpus all -it --rm -v $(pwd):/home/app/mava -w /home/app/mava insta pip install git+https://github.com/instadeepai/Mava#egg=id-mava[reverb,tf,launchpad] ``` + **For the jax version of mava, please replace `tf` with `jax`, e.g. `pip install id-mava[jax,reverb,launchpad]`** + 1.2 For **optional** environments: * PettingZoo: From 7b77152245464b539893027e724ac08df3ec7c43 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Tue, 19 Apr 2022 12:34:31 +0200 Subject: [PATCH 081/104] fix: removed jax from reverb requirements. --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index f12126f22..17c70e5e7 100644 --- a/setup.py +++ b/setup.py @@ -27,8 +27,6 @@ reverb_requirements = [ "dm-reverb~=0.6.1", - "jax", - "jaxlib", ] tf_requirements = [ From ecb90a6af27a6460b96a4e61b8fa7056e1b014ea Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Tue, 19 Apr 2022 13:15:09 +0200 Subject: [PATCH 082/104] feat: Updated pettingzoo. --- .../atari/pong/recurrent/decentralised/run_madqn.py | 1 + .../cooperative_pong/feedforward/decentralised/run_mappo.py | 2 +- .../sisl/multiwalker/feedforward/centralised/run_maddpg.py | 2 +- .../sisl/multiwalker/feedforward/decentralised/run_mad4pg.py | 2 +- .../feedforward/decentralised/run_mad4pg_record.py | 2 +- .../sisl/multiwalker/feedforward/decentralised/run_maddpg.py | 2 +- .../sisl/multiwalker/feedforward/decentralised/run_mappo.py | 2 +- .../sisl/multiwalker/recurrent/decentralised/run_mad4pg.py | 2 +- .../sisl/multiwalker/recurrent/decentralised/run_maddpg.py | 2 +- setup.py | 4 ++-- tests/environment_loops/environment_loop_test.py | 4 ++-- tests/wrappers/wrapper_test.py | 4 ++-- 12 files changed, 15 insertions(+), 14 deletions(-) diff --git a/examples/tf/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py b/examples/tf/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py index be53a215c..bf1f0ee01 100644 --- a/examples/tf/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py +++ b/examples/tf/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Example running recurrent MADQN on Atari Pong.""" +# You will need to install AutoRom, follow instructions in OPTIONAL_INSTALL.md. import functools from datetime import datetime diff --git a/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py b/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py index 0c17b820f..ff2a8f177 100644 --- a/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py +++ b/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py @@ -37,7 +37,7 @@ ) flags.DEFINE_string( "env_name", - "cooperative_pong_v3", + "cooperative_pong_v5", "Pettingzoo environment name, e.g. 
pong (str).", ) diff --git a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/centralised/run_maddpg.py b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/centralised/run_maddpg.py index e6b4f2c10..06138c955 100644 --- a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/centralised/run_maddpg.py +++ b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/centralised/run_maddpg.py @@ -39,7 +39,7 @@ flags.DEFINE_string( "env_name", - "multiwalker_v7", + "multiwalker_v8", "Pettingzoo environment name, e.g. pong (str).", ) flags.DEFINE_string( diff --git a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg.py b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg.py index 9bdb9fee6..1a23040e7 100644 --- a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg.py +++ b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg.py @@ -38,7 +38,7 @@ flags.DEFINE_string( "env_name", - "multiwalker_v7", + "multiwalker_v8", "Pettingzoo environment name, e.g. pong (str).", ) flags.DEFINE_string( diff --git a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg_record.py b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg_record.py index 6e12a8916..8e69c01e1 100644 --- a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg_record.py +++ b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg_record.py @@ -39,7 +39,7 @@ flags.DEFINE_string( "env_name", - "multiwalker_v7", + "multiwalker_v8", "Pettingzoo environment name, e.g. pong (str).", ) flags.DEFINE_string( diff --git a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_maddpg.py b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_maddpg.py index 15769046a..1cb137e80 100644 --- a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_maddpg.py +++ b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_maddpg.py @@ -38,7 +38,7 @@ flags.DEFINE_string( "env_name", - "multiwalker_v7", + "multiwalker_v8", "Pettingzoo environment name, e.g. pong (str).", ) flags.DEFINE_string( diff --git a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py index 461acd87a..231803876 100644 --- a/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py +++ b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py @@ -37,7 +37,7 @@ flags.DEFINE_string( "env_name", - "multiwalker_v7", + "multiwalker_v8", "Pettingzoo environment name, e.g. pong (str).", ) flags.DEFINE_string( diff --git a/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_mad4pg.py b/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_mad4pg.py index 7b9c91c7c..f493bae65 100644 --- a/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_mad4pg.py +++ b/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_mad4pg.py @@ -39,7 +39,7 @@ flags.DEFINE_string( "env_name", - "multiwalker_v7", + "multiwalker_v8", "Pettingzoo environment name, e.g. 
pong (str).", ) flags.DEFINE_string( diff --git a/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_maddpg.py b/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_maddpg.py index 0f659f29e..d0566f3ae 100644 --- a/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_maddpg.py +++ b/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_maddpg.py @@ -39,7 +39,7 @@ flags.DEFINE_string( "env_name", - "multiwalker_v7", + "multiwalker_v8", "Pettingzoo environment name, e.g. pong (str).", ) flags.DEFINE_string( diff --git a/setup.py b/setup.py index 17c70e5e7..57ef545f2 100644 --- a/setup.py +++ b/setup.py @@ -47,9 +47,9 @@ ] + tf_requirements pettingzoo_requirements = [ - "pettingzoo~=1.13.1", + "pettingzoo~=1.17.0", "multi_agent_ale_py", - "supersuit==3.3.1", + "supersuit==3.3.4", "pygame", "pysc2", ] diff --git a/tests/environment_loops/environment_loop_test.py b/tests/environment_loops/environment_loop_test.py index 2f741a9e1..ca78660c0 100644 --- a/tests/environment_loops/environment_loop_test.py +++ b/tests/environment_loops/environment_loop_test.py @@ -30,8 +30,8 @@ # Real Environments EnvSpec("pettingzoo.mpe.simple_spread_v2", EnvType.Parallel), EnvSpec("pettingzoo.mpe.simple_spread_v2", EnvType.Sequential), - EnvSpec("pettingzoo.sisl.multiwalker_v7", EnvType.Parallel), - EnvSpec("pettingzoo.sisl.multiwalker_v7", EnvType.Sequential), + EnvSpec("pettingzoo.sisl.multiwalker_v8", EnvType.Parallel), + EnvSpec("pettingzoo.sisl.multiwalker_v8", EnvType.Sequential), ], ) class TestEnvironmentLoop: diff --git a/tests/wrappers/wrapper_test.py b/tests/wrappers/wrapper_test.py index 226aca86c..d262336cf 100644 --- a/tests/wrappers/wrapper_test.py +++ b/tests/wrappers/wrapper_test.py @@ -47,8 +47,8 @@ [ EnvSpec("pettingzoo.mpe.simple_spread_v2", EnvType.Parallel), EnvSpec("pettingzoo.mpe.simple_spread_v2", EnvType.Sequential), - EnvSpec("pettingzoo.sisl.multiwalker_v7", EnvType.Parallel), - EnvSpec("pettingzoo.sisl.multiwalker_v7", EnvType.Sequential), + EnvSpec("pettingzoo.sisl.multiwalker_v8", EnvType.Parallel), + EnvSpec("pettingzoo.sisl.multiwalker_v8", EnvType.Sequential), EnvSpec("flatland", EnvType.Parallel, EnvSource.Flatland) if _has_flatland else None, From a7343c2fc55a1a86739f95f437c0deef47d90561 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Wed, 20 Apr 2022 10:54:31 +0200 Subject: [PATCH 083/104] chore: remove debug code, add comment to ddpg --- mava/systems/tf/mad4pg/system.py | 4 ---- mava/systems/tf/mad4pg/training.py | 34 --------------------------- mava/systems/tf/maddpg/builder.py | 1 - mava/systems/tf/maddpg/networks.py | 4 ++++ mava/systems/tf/maddpg/system.py | 3 --- mava/systems/tf/maddpg/training.py | 37 ++++-------------------------- 6 files changed, 8 insertions(+), 75 deletions(-) diff --git a/mava/systems/tf/mad4pg/system.py b/mava/systems/tf/mad4pg/system.py index c5ebfac0e..906acf5a7 100644 --- a/mava/systems/tf/mad4pg/system.py +++ b/mava/systems/tf/mad4pg/system.py @@ -89,7 +89,6 @@ def __init__( learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, termination_condition: Optional[Dict[str, int]] = None, evaluator_interval: Optional[dict] = None, - update_obs_once: bool = False, ): """Initialise the system @@ -187,8 +186,6 @@ def __init__( happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. 
- update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -231,5 +228,4 @@ def __init__( learning_rate_scheduler_fn=learning_rate_scheduler_fn, termination_condition=termination_condition, evaluator_interval=evaluator_interval, - update_obs_once=update_obs_once, ) diff --git a/mava/systems/tf/mad4pg/training.py b/mava/systems/tf/mad4pg/training.py index 83e738048..4fab17c15 100644 --- a/mava/systems/tf/mad4pg/training.py +++ b/mava/systems/tf/mad4pg/training.py @@ -74,7 +74,6 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise MAD4PG trainer @@ -110,8 +109,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -136,7 +133,6 @@ def __init__( max_gradient_norm=max_gradient_norm, logger=logger, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) # Forward pass that calculates loss. @@ -254,7 +250,6 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise decentralised MAD4PG trainer @@ -290,8 +285,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -316,7 +309,6 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) @@ -346,7 +338,6 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise centralised MAD4PG trainer @@ -382,8 +373,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -408,7 +397,6 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) @@ -438,7 +426,6 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise state-based MAD4PG trainer @@ -474,8 +461,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. 
- update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -500,7 +485,6 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) @@ -535,7 +519,6 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise Recurrent MAD4PG trainer @@ -572,8 +555,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -599,7 +580,6 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) # Forward pass that calculates loss. @@ -780,7 +760,6 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise Recurrent MAD4PG trainer @@ -817,8 +796,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -844,7 +821,6 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) @@ -877,7 +853,6 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise Recurrent MAD4PG trainer @@ -914,8 +889,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -941,7 +914,6 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) @@ -974,7 +946,6 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise Recurrent MAD4PG trainer @@ -1011,8 +982,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( @@ -1038,7 +1007,6 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) @@ -1106,8 +1074,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. 
- update_obs_once: bool = use only critic gradients to update shared - observation network """ super().__init__( diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index 41a68c33c..2069865f3 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -130,7 +130,6 @@ class MADDPGConfig: termination_condition: Optional[Dict[str, int]] = None evaluator_interval: Optional[dict] = None learning_rate_scheduler_fn: Optional[Any] = None - update_obs_once: bool = False class MADDPGBuilder: diff --git a/mava/systems/tf/maddpg/networks.py b/mava/systems/tf/maddpg/networks.py index 8d9738d86..12825f1d9 100644 --- a/mava/systems/tf/maddpg/networks.py +++ b/mava/systems/tf/maddpg/networks.py @@ -108,6 +108,10 @@ def make_default_networks( policy_networks = {} critic_networks = {} for key in specs.keys(): + # Create the shared observation network if not already defined, + # as a state-less operation. + if observation_network is None: + observation_network = tf2_utils.to_sonnet_module(tf.identity) # TODO (dries): Make specs[key].actions # return a list of specs for hybrid action space # Get total number of action dimensions from action spec. diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index 85eba3675..baecaa884 100644 --- a/mava/systems/tf/maddpg/system.py +++ b/mava/systems/tf/maddpg/system.py @@ -108,7 +108,6 @@ def __init__( # noqa termination_condition: Optional[Dict[str, int]] = None, evaluator_interval: Optional[dict] = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise the system @@ -206,7 +205,6 @@ def __init__( # noqa happen at every timestep. E.g. to evaluate a system after every 100 executor episodes, evaluator_interval = {"executor_episodes": 100}. - update_obs_once: bool = use only critic gradients to update shared observation network """ if not environment_spec: @@ -400,7 +398,6 @@ def __init__( # noqa termination_condition=termination_condition, evaluator_interval=evaluator_interval, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ), trainer_fn=trainer_fn, executor_fn=executor_fn, diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 5765df123..6e3bb7ca0 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -71,7 +71,6 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise MADDPG trainer @@ -109,15 +108,12 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ self._agents = agents self._agent_net_keys = agent_net_keys self._variable_client = variable_client self._learning_rate_scheduler_fn = learning_rate_scheduler_fn - self._update_obs_once = update_obs_once # Setup counts self._counts = counts @@ -461,13 +457,8 @@ def _backward(self) -> None: agent_key = self._agent_net_keys[agent] # Get trainable variables. 
- if self._update_obs_once: - policy_variables = self._policy_networks[agent_key].trainable_variables - else: - policy_variables = ( - self._observation_networks[agent_key].trainable_variables - + self._policy_networks[agent_key].trainable_variables - ) + # Note: policy does not update the shared obs network + policy_variables = self._policy_networks[agent_key].trainable_variables critic_variables = ( # In this agent, the critic loss trains the observation network. self._observation_networks[agent_key].trainable_variables @@ -582,7 +573,6 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise the decentralised MADDPG trainer.""" super().__init__( @@ -607,7 +597,6 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) @@ -637,7 +626,6 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise the centralised MADDPG trainer.""" @@ -663,7 +651,6 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) def _get_critic_feed( @@ -735,7 +722,6 @@ def __init__( max_gradient_norm: float = None, logger: loggers.Logger = None, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise the networked MADDPG trainer.""" @@ -761,7 +747,6 @@ def __init__( variable_client=variable_client, counts=counts, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) self._connection_spec = connection_spec @@ -949,7 +934,6 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): """Initialise Recurrent MADDPG trainer @@ -987,8 +971,6 @@ def __init__( learning_rate_scheduler_fn: dict with two functions (one for the policy and one for the critic optimizer), that takes in a trainer step t and returns the current learning rate. - update_obs_once: bool = use only critic gradients to update shared - observation network """ self._bootstrap_n = bootstrap_n @@ -996,7 +978,6 @@ def __init__( self._agent_net_keys = agent_net_keys self._variable_client = variable_client self._learning_rate_scheduler_fn = learning_rate_scheduler_fn - self._update_obs_once = update_obs_once # Setup counts self._counts = counts @@ -1435,14 +1416,8 @@ def _backward(self) -> None: agent_key = self._agent_net_keys[agent] # Get trainable variables. - - if self._update_obs_once: - policy_variables = self._policy_networks[agent_key].trainable_variables - else: - policy_variables = ( - self._observation_networks[agent_key].trainable_variables - + self._policy_networks[agent_key].trainable_variables - ) + # Note: policy does not update the shared obs network + policy_variables = self._policy_networks[agent_key].trainable_variables critic_variables = ( # In this agent, the critic loss trains the observation network. 
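The intent of the change above is that only the critic loss updates the shared observation network; a small sketch of that variable split, using toy sonnet modules rather than the trainer's networks:

import sonnet as snt
import tensorflow as tf

observation_network = snt.Linear(16)
policy_network = snt.Linear(4)
critic_network = snt.Linear(1)

# Run a dummy forward pass so the modules create their variables.
embed = observation_network(tf.random.normal((2, 8)))
_ = policy_network(embed), critic_network(embed)

# Policy gradients touch only the policy head.
policy_variables = policy_network.trainable_variables

# Critic gradients flow through both the critic head and the shared
# observation network.
critic_variables = (
    observation_network.trainable_variables + critic_network.trainable_variables
)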
@@ -1581,7 +1556,6 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): super().__init__( @@ -1607,7 +1581,6 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) @@ -1642,7 +1615,6 @@ def __init__( logger: loggers.Logger = None, bootstrap_n: int = 10, learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, - update_obs_once: bool = False, ): super().__init__( @@ -1668,7 +1640,6 @@ def __init__( counts=counts, bootstrap_n=bootstrap_n, learning_rate_scheduler_fn=learning_rate_scheduler_fn, - update_obs_once=update_obs_once, ) def _get_critic_feed( From 67a2d436d385a2e10b6d1f81899cebec3852a556 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Wed, 20 Apr 2022 10:57:24 +0200 Subject: [PATCH 084/104] chore: remove debug variable --- mava/systems/tf/maddpg/builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index 2069865f3..c01550608 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -604,7 +604,6 @@ def make_trainer( "counts": counts, "logger": logger, "learning_rate_scheduler_fn": self._config.learning_rate_scheduler_fn, - "update_obs_once": self._config.update_obs_once, } if connection_spec: trainer_config["connection_spec"] = connection_spec From 60eb0541d3f3ae01d4d9c3fc8ebb54d18da62d78 Mon Sep 17 00:00:00 2001 From: Dries Date: Thu, 21 Apr 2022 11:20:38 +0200 Subject: [PATCH 085/104] fix: Address PR comments. --- .../recurrent/state_based/run_mappo.py | 14 +++----------- mava/components/tf/architectures/centralised.py | 5 ----- mava/components/tf/architectures/state_based.py | 6 ++++-- mava/utils/sort_utils.py | 2 ++ 4 files changed, 9 insertions(+), 18 deletions(-) diff --git a/examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py b/examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py index 6e2a1c972..64c2d35b2 100644 --- a/examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py +++ b/examples/tf/debugging/simple_spread/recurrent/state_based/run_mappo.py @@ -51,25 +51,19 @@ def main(_: Any) -> None: - - recurrent_test = False - recurrent_ppo = True - # Environment. environment_factory = functools.partial( debugging_utils.make_environment, env_name=FLAGS.env_name, action_space=FLAGS.action_space, return_state_info=True, - recurrent_test=recurrent_test, + recurrent_test=True, ) # Networks. 
network_factory = lp_utils.partial_kwargs( mappo.make_default_networks, - architecture_type=ArchitectureType.recurrent - if recurrent_ppo - else ArchitectureType.feedforward, + architecture_type=ArchitectureType.recurrent, ) # Checkpointer appends "Checkpoints" to checkpoint_dir @@ -96,9 +90,7 @@ def main(_: Any) -> None: critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4), checkpoint_subpath=checkpoint_dir, max_gradient_norm=40.0, - executor_fn=mappo.MAPPORecurrentExecutor - if recurrent_ppo - else mappo.MAPPOFeedForwardExecutor, + executor_fn=mappo.MAPPORecurrentExecutor, architecture=architectures.StateBasedValueActorCritic, trainer_fn=mappo.StateBasedMAPPOTrainer, ).build() diff --git a/mava/components/tf/architectures/centralised.py b/mava/components/tf/architectures/centralised.py index 5f18729de..4ff207b8e 100644 --- a/mava/components/tf/architectures/centralised.py +++ b/mava/components/tf/architectures/centralised.py @@ -146,9 +146,6 @@ def _get_critic_specs( agent_key = agents[0] net_key = self._agent_net_keys[agent_key] - # TODO (dries): Add a check to see if all - # self._embed_specs[agent_key].shape are of the same shape - critic_obs_shape = list(copy.copy(self._embed_specs[net_key].shape)) critic_obs_shape.insert(0, len(agents)) obs_specs_per_type[agent_type] = tf.TensorSpec( @@ -159,8 +156,6 @@ def _get_critic_specs( critic_act_shape = list( copy.copy(self._agent_specs[agents[0]].actions.shape) ) - # TODO (dries): Add a check to see if all - # self._agent_specs[agents[0]].actions.shape are of the same shape critic_act_shape.insert(0, len(agents)) action_specs_per_type[agent_type] = tf.TensorSpec( diff --git a/mava/components/tf/architectures/state_based.py b/mava/components/tf/architectures/state_based.py index 2e8b887b5..8cdd1d8d1 100644 --- a/mava/components/tf/architectures/state_based.py +++ b/mava/components/tf/architectures/state_based.py @@ -115,7 +115,8 @@ def _get_critic_specs( ) ) - # TODO (dries): Fix this + # TODO (dries): Fix the line below and remove the assert. + # Don't assume len(critic_obs_spec) == 1. assert len(critic_obs_spec) == 1 critic_obs_spec = critic_obs_spec[0] @@ -208,7 +209,8 @@ def _get_critic_specs( ) ) - # TODO (dries): Fix this + # TODO (dries): Fix the line below and remove the assert. + # Don't assume len(critic_obs_spec) == 1. assert len(critic_obs_spec) == 1 critic_obs_spec = critic_obs_spec[0] diff --git a/mava/utils/sort_utils.py b/mava/utils/sort_utils.py index 2fbe24d2e..f3824c3bd 100644 --- a/mava/utils/sort_utils.py +++ b/mava/utils/sort_utils.py @@ -37,6 +37,8 @@ def sample_new_agent_keys( sampled from by the executors at the start of an environment run. shared_weights: whether agents should share weights or not. net_keys_to_ids: Dictionary mapping network keys to network ids. + fix_sampler: Optional list that can fix the executor sampler to sample + in a specific way. Returns: Tuple of dictionaries mapping network keys to ids. """ From 80dd1dfb4b26e0d78a25ddeae968b192cba23862 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Thu, 21 Apr 2022 14:43:24 +0200 Subject: [PATCH 086/104] chore: remove redundant code --- mava/systems/tf/maddpg/networks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mava/systems/tf/maddpg/networks.py b/mava/systems/tf/maddpg/networks.py index 12825f1d9..e95713676 100644 --- a/mava/systems/tf/maddpg/networks.py +++ b/mava/systems/tf/maddpg/networks.py @@ -135,8 +135,6 @@ def make_default_networks( # Create the shared observation network; here simply a state-less operation. 
if observation_network is None: observation_network = tf2_utils.to_sonnet_module(tf.identity) - else: - observation_network = observation_network # Create the policy network. if architecture_type == ArchitectureType.feedforward: policy_network = [ From 897a08b1dfc1b071240ac5b46fb4e4e78da26277 Mon Sep 17 00:00:00 2001 From: AsadJeewa Date: Thu, 21 Apr 2022 14:53:37 +0200 Subject: [PATCH 087/104] chore: remove redundant code --- mava/systems/tf/maddpg/networks.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mava/systems/tf/maddpg/networks.py b/mava/systems/tf/maddpg/networks.py index e95713676..c90be8369 100644 --- a/mava/systems/tf/maddpg/networks.py +++ b/mava/systems/tf/maddpg/networks.py @@ -132,9 +132,6 @@ def make_default_networks( # Get total number of action dimensions from action spec. num_dimensions = np.prod(agent_act_spec.shape, dtype=int) - # Create the shared observation network; here simply a state-less operation. - if observation_network is None: - observation_network = tf2_utils.to_sonnet_module(tf.identity) # Create the policy network. if architecture_type == ArchitectureType.feedforward: policy_network = [ From 1f7994a0154e72b29c7116925857deadbd608eee Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 22 Apr 2022 09:58:14 +0200 Subject: [PATCH 088/104] fix: Update 'list_of_networks' variable name to be 'networks_used_by_trainer'. --- mava/systems/tf/maddpg/builder.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index 66a5187b4..28c651529 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -206,13 +206,13 @@ def convert_discrete_to_bounded( return env_adder_spec def covert_specs( - self, spec: Dict[str, Any], list_of_networks: List + self, spec: Dict[str, Any], networks_used_by_trainer: List ) -> Dict[str, Any]: if type(spec) is not dict: return spec agents = [] - for network in list_of_networks: + for network in networks_used_by_trainer: agents.append(self._config.net_spec_keys[network]) agents = sort_str_num(agents) @@ -224,7 +224,9 @@ def covert_specs( else: # For the extras for key in spec.keys(): - converted_spec[key] = self.covert_specs(spec[key], list_of_networks) + converted_spec[key] = self.covert_specs( + spec[key], networks_used_by_trainer + ) return converted_spec def make_replay_tables( @@ -291,18 +293,20 @@ def limiter_fn() -> reverb.rate_limiters: # TODO (dries): Clean the below coverter code up. # Convert a Mava spec - list_of_networks = self._config.table_network_config[table_key] + networks_used_by_trainer = self._config.table_network_config[table_key] env_spec = copy.deepcopy(env_adder_spec) - env_spec._specs = self.covert_specs(env_spec._specs, list_of_networks) + env_spec._specs = self.covert_specs( + env_spec._specs, networks_used_by_trainer + ) env_spec._keys = list(sort_str_num(env_spec._specs.keys())) if env_spec.extra_specs is not None: env_spec.extra_specs = self.covert_specs( - env_spec.extra_specs, list_of_networks + env_spec.extra_specs, networks_used_by_trainer ) extra_specs = self.covert_specs( self._extra_specs, - list_of_networks, + networks_used_by_trainer, ) replay_tables.append( From 019da426a9a894a1d91b8d513c3ba770c4db7729 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 22 Apr 2022 10:33:23 +0200 Subject: [PATCH 089/104] fix: Small update to variable naming. 
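
A rough, standalone sketch of the idea behind the renamed argument (the
helper name specs_for_trainer and the plain sorted() call are stand-ins;
the real covert_specs uses the numeric-aware sort_str_num and also recurses
into the extras dict): each trainer only needs the spec entries of the
agents whose networks it trains.

    from typing import Any, Dict, List

    def specs_for_trainer(
        spec: Dict[str, Any],
        trainer_network_names: List[str],
        net_spec_keys: Dict[str, str],
    ) -> Dict[str, Any]:
        # Map each network the trainer owns back to the agent it covers,
        # then keep only those agents' entries of the spec.
        agents = sorted(net_spec_keys[name] for name in trainer_network_names)
        return {agent: spec[agent] for agent in agents if agent in spec}

    # Example: a trainer that only trains network_0 sees only agent_0's spec.
    full_spec = {"agent_0": "obs_spec_0", "agent_1": "obs_spec_1"}
    assert specs_for_trainer(
        full_spec, ["network_0"], {"network_0": "agent_0"}
    ) == {"agent_0": "obs_spec_0"}
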
--- mava/systems/tf/maddpg/builder.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/mava/systems/tf/maddpg/builder.py b/mava/systems/tf/maddpg/builder.py index 28c651529..41ce9d00d 100644 --- a/mava/systems/tf/maddpg/builder.py +++ b/mava/systems/tf/maddpg/builder.py @@ -206,13 +206,13 @@ def convert_discrete_to_bounded( return env_adder_spec def covert_specs( - self, spec: Dict[str, Any], networks_used_by_trainer: List + self, spec: Dict[str, Any], trainer_network_names: List ) -> Dict[str, Any]: if type(spec) is not dict: return spec agents = [] - for network in networks_used_by_trainer: + for network in trainer_network_names: agents.append(self._config.net_spec_keys[network]) agents = sort_str_num(agents) @@ -225,7 +225,7 @@ def covert_specs( # For the extras for key in spec.keys(): converted_spec[key] = self.covert_specs( - spec[key], networks_used_by_trainer + spec[key], trainer_network_names ) return converted_spec @@ -293,20 +293,18 @@ def limiter_fn() -> reverb.rate_limiters: # TODO (dries): Clean the below coverter code up. # Convert a Mava spec - networks_used_by_trainer = self._config.table_network_config[table_key] + trainer_network_names = self._config.table_network_config[table_key] env_spec = copy.deepcopy(env_adder_spec) - env_spec._specs = self.covert_specs( - env_spec._specs, networks_used_by_trainer - ) + env_spec._specs = self.covert_specs(env_spec._specs, trainer_network_names) env_spec._keys = list(sort_str_num(env_spec._specs.keys())) if env_spec.extra_specs is not None: env_spec.extra_specs = self.covert_specs( - env_spec.extra_specs, networks_used_by_trainer + env_spec.extra_specs, trainer_network_names ) extra_specs = self.covert_specs( self._extra_specs, - networks_used_by_trainer, + trainer_network_names, ) replay_tables.append( From f494b836e72d026b05a7510028a1e8fd29f380b6 Mon Sep 17 00:00:00 2001 From: Dries Date: Fri, 22 Apr 2022 12:07:55 +0200 Subject: [PATCH 090/104] fix: Small fix to PPO variable naming. --- mava/systems/tf/mappo/builder.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 935d2418e..ab306fd76 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -184,12 +184,12 @@ def add_log_prob_to_spec( return env_adder_spec def get_nets_specific_specs( - self, spec: Dict[str, Any], network_names: List + self, spec: Dict[str, Any], trainer_network_names: List ) -> Dict[str, Any]: """Convert specs. Args: spec: agent specs - network_names: names of the networks in the desired reverb table + trainer_network_names: names of the networks in the desired reverb table (to get specs for) Returns: distilled version of agent specs containing only specs related to @@ -199,7 +199,7 @@ def get_nets_specific_specs( return spec agents = [] - for network in network_names: + for network in trainer_network_names: agents.append(self._config.net_spec_keys[network]) agents = sort_str_num(agents) @@ -211,7 +211,7 @@ def get_nets_specific_specs( # For the extras for key in spec.keys(): converted_spec[key] = self.get_nets_specific_specs( - spec[key], network_names + spec[key], trainer_network_names ) return converted_spec @@ -236,20 +236,20 @@ def make_replay_tables( for table_key in self._config.table_network_config.keys(): # TODO (dries): Clean the below converter code up. 
# Convert a Mava spec - tables_network_names = self._config.table_network_config[table_key] + trainer_network_names = self._config.table_network_config[table_key] env_spec = copy.deepcopy(adder_env_spec) env_spec._specs = self.get_nets_specific_specs( - env_spec._specs, tables_network_names + env_spec._specs, trainer_network_names ) env_spec._keys = list(sort_str_num(env_spec._specs.keys())) if env_spec.extra_specs is not None: env_spec.extra_specs = self.get_nets_specific_specs( - env_spec.extra_specs, tables_network_names + env_spec.extra_specs, trainer_network_names ) extra_specs = self.get_nets_specific_specs( self._extra_specs, - tables_network_names, + trainer_network_names, ) signature = reverb_adders.ParallelSequenceAdder.signature( From fa57447f1ea24b7c8bcb3004ec1248cd82a73520 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Fri, 22 Apr 2022 15:10:03 +0200 Subject: [PATCH 091/104] chore: Up the patch version of mava. --- mava/_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mava/_metadata.py b/mava/_metadata.py index 59fbb2e4c..0671725a0 100644 --- a/mava/_metadata.py +++ b/mava/_metadata.py @@ -22,7 +22,7 @@ # We follow Semantic Versioning (https://semver.org/) _MAJOR_VERSION = "0" _MINOR_VERSION = "1" -_PATCH_VERSION = "2" +_PATCH_VERSION = "3" # Example: '0.4.2' __version__ = ".".join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) From 06773eebd536b1afbb5224c246cd268a01e49b1d Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Mon, 6 Jun 2022 12:47:52 +0200 Subject: [PATCH 092/104] fix: Fixed keys in state based arch. --- mava/components/tf/architectures/state_based.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mava/components/tf/architectures/state_based.py b/mava/components/tf/architectures/state_based.py index 8cdd1d8d1..4cbabb41e 100644 --- a/mava/components/tf/architectures/state_based.py +++ b/mava/components/tf/architectures/state_based.py @@ -137,7 +137,7 @@ def _get_critic_specs( net_key = self._agent_net_keys[agent_key] # Get observation and action spec for critic. critic_obs_specs[net_key] = critic_obs_spec - critic_act_specs[net_key] = action_specs_per_type[agent_type] + critic_act_specs[agent_key] = action_specs_per_type[agent_type] return critic_obs_specs, critic_act_specs @@ -288,5 +288,5 @@ def _get_critic_specs( net_key = self._agent_net_keys[agent_key] # Get observation and action spec for critic. critic_obs_specs[net_key] = critic_obs_spec - critic_act_specs[net_key] = action_specs_per_type[agent_type] + critic_act_specs[agent_key] = action_specs_per_type[agent_type] return critic_obs_specs, critic_act_specs From 486d9636489102fa7fd83d14f86ad5299b14bac1 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Mon, 6 Jun 2022 13:00:29 +0200 Subject: [PATCH 093/104] fix: Fixed keys in centralised arch. --- mava/components/tf/architectures/centralised.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mava/components/tf/architectures/centralised.py b/mava/components/tf/architectures/centralised.py index 4ff207b8e..f51d1445e 100644 --- a/mava/components/tf/architectures/centralised.py +++ b/mava/components/tf/architectures/centralised.py @@ -167,8 +167,10 @@ def _get_critic_specs( critic_act_specs = {} for agent_key in self._agents: agent_type = agent_key.split("_")[0] + net_key = self._agent_net_keys[agent_key] + # Get observation and action spec for critic. 
- critic_obs_specs[agent_key] = obs_specs_per_type[agent_type] + critic_obs_specs[net_key] = obs_specs_per_type[agent_type] critic_act_specs[agent_key] = action_specs_per_type[agent_type] return critic_obs_specs, critic_act_specs From 9c00fa1e6f95788cab4d86e19da990bbd2a25f19 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Mon, 6 Jun 2022 15:11:26 +0200 Subject: [PATCH 094/104] feat: Update dependencies and remove lp dependency (it is already part of acme). --- Dockerfile | 4 ++-- README.md | 6 +++--- setup.py | 9 +++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/Dockerfile b/Dockerfile index 6a239a66e..9f59a7393 100755 --- a/Dockerfile +++ b/Dockerfile @@ -23,8 +23,8 @@ WORKDIR ${folder} COPY . /home/app/mava # For box2d RUN apt-get install swig -y -## Install core dependencies. -RUN pip install -e .[reverb,launchpad] +## Install core dependencies + reverb. +RUN pip install -e .[reverb] ## Optional install for screen recording. ENV DISPLAY=:0 RUN if [ "$record" = "true" ]; then \ diff --git a/README.md b/README.md index dbfb65829..3adaff944 100644 --- a/README.md +++ b/README.md @@ -263,16 +263,16 @@ docker run --gpus all -it --rm -v $(pwd):/home/app/mava -w /home/app/mava insta * Install core dependencies: ```bash - pip install id-mava[tf,reverb,launchpad] + pip install id-mava[tf,reverb] ``` * Or for the latest version of mava from source (**you can do this for all pip install commands below for the latest depedencies**): ```bash - pip install git+https://github.com/instadeepai/Mava#egg=id-mava[reverb,tf,launchpad] + pip install git+https://github.com/instadeepai/Mava#egg=id-mava[reverb,tf] ``` - **For the jax version of mava, please replace `tf` with `jax`, e.g. `pip install id-mava[jax,reverb,launchpad]`** + **For the jax version of mava, please replace `tf` with `jax`, e.g. `pip install id-mava[jax,reverb]`** 1.2 For **optional** environments: * PettingZoo: diff --git a/setup.py b/setup.py index 57ef545f2..923b06d58 100644 --- a/setup.py +++ b/setup.py @@ -26,11 +26,11 @@ spec.loader.exec_module(_metadata) # type: ignore reverb_requirements = [ - "dm-reverb~=0.6.1", + "dm-reverb~=0.7.2", ] tf_requirements = [ - "tensorflow~=2.7.0", + "tensorflow~=2.8.0", "tensorflow_probability~=0.15.0", "dm-sonnet", "trfl", @@ -54,8 +54,6 @@ "pysc2", ] -launchpad_requirements = ["dm-launchpad~=0.3.2"] - smac_requirements = ["pysc2", "SMAC @ git+https://github.com/oxwhirl/smac.git"] testing_formatting_requirements = [ "pytest==6.2.4", @@ -102,7 +100,7 @@ keywords="multi-agent reinforcement-learning python machine learning", packages=find_packages(), install_requires=[ - "dm-acme~=0.2.4", + "dm-acme~=0.4.0", "chex", "absl-py", "dm_env", @@ -120,7 +118,6 @@ "flatland": flatland_requirements, "open_spiel": open_spiel_requirements, "reverb": reverb_requirements, - "launchpad": launchpad_requirements, "testing_formatting": testing_formatting_requirements, "record_episode": record_episode_requirements, "sc2": smac_requirements, From 1e40dd746aac5008911ba1b3e49a1bc629aa1700 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Mon, 6 Jun 2022 15:15:35 +0200 Subject: [PATCH 095/104] chore: Removed independent lp install. 
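
The separate launchpad extra is dropped because dm-acme~=0.4.0 already pulls
Launchpad in, so install targets only need the tf/reverb extras. A minimal
sketch of the resulting extras layout (abridged to the packages named in
setup.py; the full lists live there):

    # setup.py (fragment)
    tf_requirements = [
        "tensorflow~=2.8.0",
        "tensorflow_probability~=0.15.0",
        "dm-sonnet",
        "trfl",
    ]
    reverb_requirements = [
        "dm-reverb~=0.7.2",
    ]

    extras_require = {
        "tf": tf_requirements,
        "reverb": reverb_requirements,
        # No separate "launchpad" extra any more; it comes via dm-acme.
    }
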
--- bash_scripts/tests.sh | 2 +- docs/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bash_scripts/tests.sh b/bash_scripts/tests.sh index 9db8f3ecc..2e031bd19 100755 --- a/bash_scripts/tests.sh +++ b/bash_scripts/tests.sh @@ -46,7 +46,7 @@ apt-get -y install git apt-get install swig -y # Install depedencies -pip install .[tf,envs,reverb,testing_formatting,launchpad,record_episode] +pip install .[tf,envs,reverb,testing_formatting,record_episode] # For atari envs apt-get -y install unrar-free diff --git a/docs/requirements.txt b/docs/requirements.txt index 79b4eb5b9..3ec968d47 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/instadeepai/Mava#egg=id-mava[tf,envs,reverb,launchpad] +git+https://github.com/instadeepai/Mava#egg=id-mava[tf,envs,reverb] mkdocstrings mkdocs-material pymdown-extensions From f1c346ec4044201ba57a76d566f000038c5e9b3f Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Tue, 7 Jun 2022 14:50:50 +0200 Subject: [PATCH 096/104] chore :Improve consistency env_states -> s_t. --- mava/components/tf/architectures/state_based.py | 8 ++++---- mava/systems/tf/maddpg/training.py | 12 ++++++------ mava/systems/tf/mappo/training.py | 6 +++--- mava/utils/debugging/environment.py | 2 +- mava/utils/debugging/environments/two_step.py | 10 +++++----- .../RoboCup_env/robocup_utils/util_functions.py | 2 +- mava/wrappers/debugging_envs.py | 6 +++--- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/mava/components/tf/architectures/state_based.py b/mava/components/tf/architectures/state_based.py index 4cbabb41e..ef82904b8 100644 --- a/mava/components/tf/architectures/state_based.py +++ b/mava/components/tf/architectures/state_based.py @@ -56,7 +56,7 @@ def _get_actor_specs( agents_by_type = self._env_spec.get_agents_by_type() for agent_type, agents in agents_by_type.items(): - actor_state_shape = self._env_spec.get_extra_specs()["env_states"].shape + actor_state_shape = self._env_spec.get_extra_specs()["s_t"].shape obs_specs_per_type[agent_type] = tf.TensorSpec( shape=actor_state_shape, dtype=tf.dtypes.float32, @@ -99,7 +99,7 @@ def _get_critic_specs( # Create one critic per agent. Each critic gets # absolute state information of the environment. - critic_env_state_spec = self._env_spec.get_extra_specs()["env_states"] + critic_env_state_spec = self._env_spec.get_extra_specs()["s_t"] if type(critic_env_state_spec) == dict: critic_env_state_spec = list(critic_env_state_spec.values())[0] @@ -193,7 +193,7 @@ def _get_critic_specs( # Create one critic per agent. Each critic gets # absolute state information of the environment. - critic_env_state_spec = self._env_spec.get_extra_specs()["env_states"] + critic_env_state_spec = self._env_spec.get_extra_specs()["s_t"] if type(critic_env_state_spec) == dict: critic_env_state_spec = list(critic_env_state_spec.values())[0] @@ -255,7 +255,7 @@ def _get_critic_specs( # Create one critic per agent. Each critic gets # absolute state information of the environment. 
- critic_env_state_spec = self._env_spec.get_extra_specs()["env_states"] + critic_env_state_spec = self._env_spec.get_extra_specs()["s_t"] if type(critic_env_state_spec) == dict: critic_env_state_spec = list(critic_env_state_spec.values())[0] diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 626e0555c..b4ace468b 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -864,8 +864,8 @@ def _get_critic_feed( """Get critic feed.""" # State based - o_tm1_feed = e_tm1["env_states"] - o_t_feed = e_t["env_states"] + o_tm1_feed = e_tm1["s_t"] + o_t_feed = e_t["s_t"] a_tm1_feed = tf.stack([a_tm1[agent] for agent in self._agents], 1) a_t_feed = tf.stack([a_t[agent] for agent in self._agents], 1) @@ -1733,8 +1733,8 @@ def _get_critic_feed( ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: # State based - obs_trans_feed = extras["env_states"] - target_obs_trans_feed = extras["env_states"] + obs_trans_feed = extras["s_t"] + target_obs_trans_feed = extras["s_t"] if type(obs_trans_feed) == dict: # type: ignore obs_trans_feed = obs_trans_feed[agent] target_obs_trans_feed = target_obs_trans_feed[agent] @@ -1828,8 +1828,8 @@ def _get_critic_feed( agent: str, ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: # State based - obs_trans_feed = extras["env_states"][agent] - target_obs_trans_feed = extras["env_states"][agent] + obs_trans_feed = extras["s_t"][agent] + target_obs_trans_feed = extras["s_t"][agent] actions_feed = actions[agent] target_actions_feed = target_actions[agent] return obs_trans_feed, target_obs_trans_feed, actions_feed, target_actions_feed diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index c5404c623..6f45b18ee 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -793,7 +793,7 @@ def _get_critic_feed( agent: str, ) -> tf.Tensor: # State based - if type(extras["env_states"]) == dict: # type: ignore - return extras["env_states"][agent] + if type(extras["s_t"]) == dict: # type: ignore + return extras["s_t"][agent] else: - return extras["env_states"] + return extras["s_t"] diff --git a/mava/utils/debugging/environment.py b/mava/utils/debugging/environment.py index efc2080b9..b73a902bb 100644 --- a/mava/utils/debugging/environment.py +++ b/mava/utils/debugging/environment.py @@ -180,7 +180,7 @@ def reset(self) -> Tuple[Dict[str, Any], Dict[str, Any]]: agent = self.agents[agent_id] obs_n[agent_id] = self._get_obs(a_i, agent) state_n = self._get_state() - return obs_n, {"env_states": state_n} + return obs_n, {"s_t": state_n} # get info used for benchmarking def _get_info(self, agent: Agent) -> Dict: diff --git a/mava/utils/debugging/environments/two_step.py b/mava/utils/debugging/environments/two_step.py index b5f060c5b..d5efce894 100644 --- a/mava/utils/debugging/environments/two_step.py +++ b/mava/utils/debugging/environments/two_step.py @@ -62,7 +62,7 @@ def reset( "agent_0": np.array([0.0], dtype=np.float32), "agent_1": np.array([0.0], dtype=np.float32), } - return self.obs_n, {"env_states": np.array(self.state, dtype=np.int64)} + return self.obs_n, {"s_t": np.array(self.state, dtype=np.int64)} def step( self, action_n: Dict[str, int] @@ -83,7 +83,7 @@ def step( self.obs_n, self.reward_n, self.done_n, - {"env_states": np.array(self.state, dtype=np.int64)}, + {"s_t": np.array(self.state, dtype=np.int64)}, ) # Go to 2A else: self.state = 2 @@ -95,7 +95,7 @@ def step( self.obs_n, self.reward_n, self.done_n, - {"env_states": 
np.array(self.state, dtype=np.int64)}, + {"s_t": np.array(self.state, dtype=np.int64)}, ) # Go to 2B elif self.state == 1: # State 2A @@ -110,7 +110,7 @@ def step( self.obs_n, self.reward_n, self.done_n, - {"env_states": np.array(self.state, dtype=np.int64)}, + {"s_t": np.array(self.state, dtype=np.int64)}, ) elif self.state == 2: # State 2B @@ -141,7 +141,7 @@ def step( self.obs_n, self.reward_n, self.done_n, - {"env_states": np.array(self.state, dtype=np.int64)}, + {"s_t": np.array(self.state, dtype=np.int64)}, ) else: diff --git a/mava/utils/environments/RoboCup_env/robocup_utils/util_functions.py b/mava/utils/environments/RoboCup_env/robocup_utils/util_functions.py index 50bd74fe7..cbdbfcd21 100644 --- a/mava/utils/environments/RoboCup_env/robocup_utils/util_functions.py +++ b/mava/utils/environments/RoboCup_env/robocup_utils/util_functions.py @@ -308,7 +308,7 @@ def discount_spec(self) -> Dict[str, specs.BoundedArray]: return discount_specs def extra_spec(self) -> Dict[str, specs.BoundedArray]: - return {"env_states": self._state_spec} + return {"s_t": self._state_spec} def _proc_robocup_obs( self, observations: Dict, done: bool, nn_actions: Dict = None diff --git a/mava/wrappers/debugging_envs.py b/mava/wrappers/debugging_envs.py index 8f9cfabc8..27363aa67 100644 --- a/mava/wrappers/debugging_envs.py +++ b/mava/wrappers/debugging_envs.py @@ -105,7 +105,7 @@ def step( step_type=self._step_type, ) if self.return_state_info: - return timestep, {"env_states": state} + return timestep, {"s_t": state} else: return timestep @@ -182,7 +182,7 @@ def extra_spec(self) -> Dict[str, specs.BoundedArray]: minimum=[float("-inf")] * shape[0], maximum=[float("inf")] * shape[0], ) - extras.update({"env_states": ex_spec}) + extras.update({"s_t": ex_spec}) return extras @@ -228,7 +228,7 @@ def __init__(self, environment: TwoStepEnv) -> None: self.environment.action_spaces = {} self.environment.observation_spaces = {} self.environment.extra_specs = { - "env_states": spaces.Discrete(3) + "s_t": spaces.Discrete(3) } # Global state 1, 2, or 3 for agent_id in self.environment.agent_ids: From 4b163c0d1c7900b0499080d5b897816e2c2ef548 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Tue, 7 Jun 2022 14:51:39 +0200 Subject: [PATCH 097/104] fix: Make all tf tests use shared weights. 
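
For context, a small hypothetical illustration of what the shared-weights
setting controls (the real mapping is produced by the system's sampling
utilities; the agent-type convention below follows the split("_")[0]
pattern used in the architectures):

    agents = ["agent_0", "agent_1", "agent_2"]

    # shared_weights=True: all agents of a type point at one network key,
    # so a single set of policy/critic parameters is trained.
    shared_agent_net_keys = {agent: agent.split("_")[0] for agent in agents}
    # -> {"agent_0": "agent", "agent_1": "agent", "agent_2": "agent"}

    # shared_weights=False: every agent keeps its own network key and
    # therefore its own parameters.
    separate_agent_net_keys = {agent: agent for agent in agents}
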
--- tests/systems/mad4pg_system_test.py | 2 -- tests/systems/maddpg_system_test.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/tests/systems/mad4pg_system_test.py b/tests/systems/mad4pg_system_test.py index 12cb0cc1d..12eea347f 100644 --- a/tests/systems/mad4pg_system_test.py +++ b/tests/systems/mad4pg_system_test.py @@ -167,7 +167,6 @@ def test_centralised_mad4pg_on_debugging_env(self) -> None: checkpoint=False, architecture=architectures.CentralisedQValueCritic, trainer_fn=mad4pg.MAD4PGCentralisedTrainer, - shared_weights=False, ) program = system.build() @@ -221,7 +220,6 @@ def test_state_based_mad4pg_on_debugging_env(self) -> None: checkpoint=False, trainer_fn=mad4pg.MAD4PGStateBasedTrainer, architecture=architectures.StateBasedQValueCritic, - shared_weights=False, ) program = system.build() diff --git a/tests/systems/maddpg_system_test.py b/tests/systems/maddpg_system_test.py index def0b1bbc..42e72d67b 100644 --- a/tests/systems/maddpg_system_test.py +++ b/tests/systems/maddpg_system_test.py @@ -162,7 +162,6 @@ def test_centralised_maddpg_on_debugging_env(self) -> None: checkpoint=False, architecture=architectures.CentralisedQValueCritic, trainer_fn=maddpg.MADDPGCentralisedTrainer, - shared_weights=False, ) program = system.build() @@ -214,7 +213,6 @@ def test_networked_maddpg_on_debugging_env(self) -> None: trainer_fn=maddpg.MADDPGNetworkedTrainer, architecture=architectures.NetworkedQValueCritic, connection_spec=fully_connected_network_spec, - shared_weights=False, ) program = system.build() @@ -266,7 +264,6 @@ def test_state_based_maddpg_on_debugging_env(self) -> None: checkpoint=False, trainer_fn=maddpg.MADDPGStateBasedTrainer, architecture=architectures.StateBasedQValueCritic, - shared_weights=False, ) program = system.build() From 776b6127e1761478d20631ba73773f3616c5b4b8 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Tue, 7 Jun 2022 15:59:34 +0200 Subject: [PATCH 098/104] fix: Fix networked arch. 
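
A simplified sketch (hypothetical sizes and connectivity) of the keying
convention the fix below moves to: critic observation specs are stored per
network key, so weight sharing resolves correctly, while action specs stay
keyed per agent:

    import tensorflow as tf

    embed_dim = 10  # assumed embedding size
    agent_net_keys = {"agent_0": "agent", "agent_1": "agent"}
    # Which agents each agent's critic is connected to.
    connections = {
        "agent_0": ["agent_0", "agent_1"],
        "agent_1": ["agent_0", "agent_1"],
    }

    critic_obs_specs = {}
    for agent, net_key in agent_net_keys.items():
        # Stack the connected agents' embeddings along a new leading axis.
        critic_obs_specs[net_key] = tf.TensorSpec(
            shape=(len(connections[agent]), embed_dim),
            dtype=tf.dtypes.float32,
        )
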
--- mava/components/tf/architectures/networked.py | 5 +++-- tests/systems/maddpg_system_test.py | 20 ++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/mava/components/tf/architectures/networked.py b/mava/components/tf/architectures/networked.py index 2cd384adc..c6b40346f 100644 --- a/mava/components/tf/architectures/networked.py +++ b/mava/components/tf/architectures/networked.py @@ -114,12 +114,13 @@ def _get_critic_specs( for agent_type, agents in agents_by_type.items(): for agent in agents: - critic_obs_shape = list(copy.copy(self._embed_specs[agent].shape)) + net_key = self._agent_net_keys[agent] + critic_obs_shape = list(copy.copy(self._embed_specs[net_key].shape)) critic_act_shape = list( copy.copy(self._agent_specs[agent].actions.shape) ) critic_obs_shape.insert(0, len(self._network_spec[agent])) - critic_obs_specs[agent] = tf.TensorSpec( + critic_obs_specs[net_key] = tf.TensorSpec( shape=critic_obs_shape, dtype=tf.dtypes.float32, ) diff --git a/tests/systems/maddpg_system_test.py b/tests/systems/maddpg_system_test.py index 42e72d67b..28b767472 100644 --- a/tests/systems/maddpg_system_test.py +++ b/tests/systems/maddpg_system_test.py @@ -18,6 +18,7 @@ import functools import launchpad as lp +import pytest import sonnet as snt import mava @@ -44,7 +45,9 @@ def test_maddpg_on_debugging_env(self) -> None: # networks network_factory = lp_utils.partial_kwargs( - maddpg.make_default_networks, policy_networks_layer_sizes=(64, 64) + maddpg.make_default_networks, + policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), ) # system @@ -94,6 +97,7 @@ def test_recurrent_maddpg_on_debugging_env(self) -> None: maddpg.make_default_networks, architecture_type=ArchitectureType.recurrent, policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), ) # system @@ -147,6 +151,7 @@ def test_centralised_maddpg_on_debugging_env(self) -> None: network_factory = lp_utils.partial_kwargs( maddpg.make_default_networks, policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), ) # system @@ -184,6 +189,16 @@ def test_centralised_maddpg_on_debugging_env(self) -> None: for _ in range(2): trainer.step() + @pytest.mark.skip( + reason=""" + Running tests with shared_weights=False pass when running indepedently + (other tests commented out), but fail when run with other tests and not + enough parallel cores (2 or less). This is likely a race condition, + hangling process from previous tests or related to network sampling + (TODO @Dries investigate if you have a chance). Only the test fails, + the examples run. 
+ """ + ) def test_networked_maddpg_on_debugging_env(self) -> None: """Test networked maddpg.""" # environment @@ -197,6 +212,7 @@ def test_networked_maddpg_on_debugging_env(self) -> None: network_factory = lp_utils.partial_kwargs( maddpg.make_default_networks, policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), ) # system @@ -213,6 +229,7 @@ def test_networked_maddpg_on_debugging_env(self) -> None: trainer_fn=maddpg.MADDPGNetworkedTrainer, architecture=architectures.NetworkedQValueCritic, connection_spec=fully_connected_network_spec, + shared_weights=False, ) program = system.build() @@ -249,6 +266,7 @@ def test_state_based_maddpg_on_debugging_env(self) -> None: network_factory = lp_utils.partial_kwargs( maddpg.make_default_networks, policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), ) # system From 4db7532dc101c763fab4215058ac5f590703a3f1 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Tue, 7 Jun 2022 16:00:12 +0200 Subject: [PATCH 099/104] chore(tests): speed up tests by using smaller networks. --- tests/systems/mad4pg_system_test.py | 6 +++++- tests/systems/maddpg_system_test.py | 6 +++--- tests/systems/mappo_system_test.py | 4 +++- tests/systems/qmix_system_test.py | 1 + tests/systems/vdn_system_test.py | 1 + 5 files changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/systems/mad4pg_system_test.py b/tests/systems/mad4pg_system_test.py index 12eea347f..66fdd6595 100644 --- a/tests/systems/mad4pg_system_test.py +++ b/tests/systems/mad4pg_system_test.py @@ -43,7 +43,8 @@ def test_mad4pg_on_debugging_env(self) -> None: # networks network_factory = lp_utils.partial_kwargs( mad4pg.make_default_networks, - policy_networks_layer_sizes=(64, 64), + policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), vmin=-10, vmax=50, ) @@ -95,6 +96,7 @@ def test_recurrent_mad4pg_on_debugging_env(self) -> None: mad4pg.make_default_networks, architecture_type=ArchitectureType.recurrent, policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), vmin=-10, vmax=50, ) @@ -150,6 +152,7 @@ def test_centralised_mad4pg_on_debugging_env(self) -> None: network_factory = lp_utils.partial_kwargs( mad4pg.make_default_networks, policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), vmin=-10, vmax=50, ) @@ -203,6 +206,7 @@ def test_state_based_mad4pg_on_debugging_env(self) -> None: network_factory = lp_utils.partial_kwargs( mad4pg.make_default_networks, policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), vmin=-10, vmax=50, ) diff --git a/tests/systems/maddpg_system_test.py b/tests/systems/maddpg_system_test.py index 28b767472..10ae5f0ad 100644 --- a/tests/systems/maddpg_system_test.py +++ b/tests/systems/maddpg_system_test.py @@ -191,9 +191,9 @@ def test_centralised_maddpg_on_debugging_env(self) -> None: @pytest.mark.skip( reason=""" - Running tests with shared_weights=False pass when running indepedently - (other tests commented out), but fail when run with other tests and not - enough parallel cores (2 or less). This is likely a race condition, + Running tests with shared_weights=False pass when running indepedently + (other tests commented out), but fail when run with other tests and not + enough parallel cores (2 or less). This is likely a race condition, hangling process from previous tests or related to network sampling (TODO @Dries investigate if you have a chance). Only the test fails, the examples run. 
diff --git a/tests/systems/mappo_system_test.py b/tests/systems/mappo_system_test.py index 9ca505f6d..d1a3ff126 100644 --- a/tests/systems/mappo_system_test.py +++ b/tests/systems/mappo_system_test.py @@ -40,7 +40,9 @@ def test_mappo_on_debugging_env(self) -> None: # networks network_factory = lp_utils.partial_kwargs( - mappo.make_default_networks, policy_networks_layer_sizes=(64, 64) + mappo.make_default_networks, + policy_networks_layer_sizes=(32, 32), + critic_networks_layer_sizes=(64, 64), ) # system diff --git a/tests/systems/qmix_system_test.py b/tests/systems/qmix_system_test.py index 3949ed094..cf6ccc1dc 100644 --- a/tests/systems/qmix_system_test.py +++ b/tests/systems/qmix_system_test.py @@ -47,6 +47,7 @@ def test_qmix_on_debug_simple_spread(self) -> None: # Networks. network_factory = lp_utils.partial_kwargs( value_decomposition.make_default_networks, + value_networks_layer_sizes=(10, 10), ) # system diff --git a/tests/systems/vdn_system_test.py b/tests/systems/vdn_system_test.py index 5adf5373d..23dbb2b8a 100644 --- a/tests/systems/vdn_system_test.py +++ b/tests/systems/vdn_system_test.py @@ -46,6 +46,7 @@ def test_vdn_on_debug_simple_spread(self) -> None: # Networks. network_factory = lp_utils.partial_kwargs( value_decomposition.make_default_networks, + value_networks_layer_sizes=(10, 10), ) # system From 2cd12ba1559f23df08958f75ac220601b73a6af1 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Tue, 7 Jun 2022 16:35:47 +0200 Subject: [PATCH 100/104] chore: Better test skip reason. --- tests/systems/maddpg_system_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/systems/maddpg_system_test.py b/tests/systems/maddpg_system_test.py index 10ae5f0ad..80da74777 100644 --- a/tests/systems/maddpg_system_test.py +++ b/tests/systems/maddpg_system_test.py @@ -193,10 +193,10 @@ def test_centralised_maddpg_on_debugging_env(self) -> None: reason=""" Running tests with shared_weights=False pass when running indepedently (other tests commented out), but fail when run with other tests and not - enough parallel cores (2 or less). This is likely a race condition, - hangling process from previous tests or related to network sampling - (TODO @Dries investigate if you have a chance). Only the test fails, - the examples run. + enough parallel cores (fail with 2 or less cores, pass with 4 cpu cores). + This is likely a race condition, hangling process from previous tests + or related to network sampling (TODO @Dries investigate if you have a + chance). Only the test fails, the examples run. """ ) def test_networked_maddpg_on_debugging_env(self) -> None: From 01b7140c052bb2b670e7e2d4ec5c9c151f4f4623 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Tue, 7 Jun 2022 19:24:45 +0200 Subject: [PATCH 101/104] fix: Update python version for melting pot images. 
--- Dockerfile | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 9f59a7393..93ea356e2 100755 --- a/Dockerfile +++ b/Dockerfile @@ -29,7 +29,7 @@ RUN pip install -e .[reverb] ENV DISPLAY=:0 RUN if [ "$record" = "true" ]; then \ ./bash_scripts/install_record.sh; \ -fi + fi EXPOSE 6006 ########################################################## @@ -89,6 +89,10 @@ RUN pip install .[open_spiel] ########################################################## # MeltingPot Image FROM tf-core AS meltingpot + +# Melting pot requires python>=3.9 +RUN apt-get install python3.9 -y && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 && rm -rf /root/.cache && apt-get clean + # Install meltingpot RUN apt-get install -y git RUN ./bash_scripts/install_meltingpot.sh @@ -153,6 +157,10 @@ RUN pip install .[open_spiel] ########################################################## # MeltingPot Image FROM jax-core AS meltingpot-jax + +# Melting pot requires python>=3.9 +RUN apt-get install python3.9 -y && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 && rm -rf /root/.cache && apt-get clean + # Install meltingpot RUN apt-get install -y git RUN ./bash_scripts/install_meltingpot.sh From 05b69e4069cb372b104a8b03fa7fe587f2805320 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Wed, 8 Jun 2022 11:23:12 +0200 Subject: [PATCH 102/104] fix: pin meltingpot to before py3.9 requirement. --- Dockerfile | 8 -------- bash_scripts/install_meltingpot.sh | 5 ++++- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index 93ea356e2..d2ae962bf 100755 --- a/Dockerfile +++ b/Dockerfile @@ -89,10 +89,6 @@ RUN pip install .[open_spiel] ########################################################## # MeltingPot Image FROM tf-core AS meltingpot - -# Melting pot requires python>=3.9 -RUN apt-get install python3.9 -y && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 && rm -rf /root/.cache && apt-get clean - # Install meltingpot RUN apt-get install -y git RUN ./bash_scripts/install_meltingpot.sh @@ -157,10 +153,6 @@ RUN pip install .[open_spiel] ########################################################## # MeltingPot Image FROM jax-core AS meltingpot-jax - -# Melting pot requires python>=3.9 -RUN apt-get install python3.9 -y && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 && rm -rf /root/.cache && apt-get clean - # Install meltingpot RUN apt-get install -y git RUN ./bash_scripts/install_meltingpot.sh diff --git a/bash_scripts/install_meltingpot.sh b/bash_scripts/install_meltingpot.sh index 4eca5828e..258030e55 100755 --- a/bash_scripts/install_meltingpot.sh +++ b/bash_scripts/install_meltingpot.sh @@ -20,4 +20,7 @@ pip install https://github.com/deepmind/lab2d/releases/download/release_candidat mkdir ../packages cd ../packages git clone -b main https://github.com/deepmind/meltingpot -pip install meltingpot/ +# Temp - Pin version before python 3.9 requirement. +cd meltingpot/ +git checkout c4fea35d3b57b0c4fa10db823a81e16244826e8d +pip install . From a6d01fa8a30770931e23663d523563ffa7e2ce90 Mon Sep 17 00:00:00 2001 From: Kale-ab Date: Thu, 9 Jun 2022 12:44:35 +0200 Subject: [PATCH 103/104] chore: More explicit on shared weights in tests. 
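
A minimal sketch (hypothetical dict form, flag values taken from the diffs
below and the follow-up patch) of the convention this makes explicit: each
test states its weight-sharing choice instead of relying on the system
default:

    base_test_config = {"checkpoint": False}

    centralised_config = {**base_test_config, "shared_weights": True}
    state_based_config = {**base_test_config, "shared_weights": True}
    # The networked test is the one case that ends up with
    # shared_weights=False (see the follow-up patch).
    networked_config = {**base_test_config, "shared_weights": False}
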
--- tests/systems/mad4pg_system_test.py | 2 ++
 tests/systems/maddpg_system_test.py | 3 ++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/systems/mad4pg_system_test.py b/tests/systems/mad4pg_system_test.py
index 66fdd6595..1ff270892 100644
--- a/tests/systems/mad4pg_system_test.py
+++ b/tests/systems/mad4pg_system_test.py
@@ -170,6 +170,7 @@ def test_centralised_mad4pg_on_debugging_env(self) -> None:
             checkpoint=False,
             architecture=architectures.CentralisedQValueCritic,
             trainer_fn=mad4pg.MAD4PGCentralisedTrainer,
+            shared_weights=True,
         )
         program = system.build()
@@ -224,6 +225,7 @@ def test_state_based_mad4pg_on_debugging_env(self) -> None:
             checkpoint=False,
             trainer_fn=mad4pg.MAD4PGStateBasedTrainer,
             architecture=architectures.StateBasedQValueCritic,
+            shared_weights=True,
         )
         program = system.build()
diff --git a/tests/systems/maddpg_system_test.py b/tests/systems/maddpg_system_test.py
index 80da74777..b3b8ffd1d 100644
--- a/tests/systems/maddpg_system_test.py
+++ b/tests/systems/maddpg_system_test.py
@@ -167,6 +167,7 @@ def test_centralised_maddpg_on_debugging_env(self) -> None:
             checkpoint=False,
             architecture=architectures.CentralisedQValueCritic,
             trainer_fn=maddpg.MADDPGCentralisedTrainer,
+            shared_weights=True,
         )
         program = system.build()
@@ -229,7 +230,6 @@ def test_networked_maddpg_on_debugging_env(self) -> None:
             trainer_fn=maddpg.MADDPGNetworkedTrainer,
             architecture=architectures.NetworkedQValueCritic,
             connection_spec=fully_connected_network_spec,
-            shared_weights=False,
         )
         program = system.build()

From deb6d6feea7784ec86af61b0736e5b05486e2ba8 Mon Sep 17 00:00:00 2001
From: Kale-ab
Date: Thu, 9 Jun 2022 12:46:55 +0200
Subject: [PATCH 104/104] fix: networked arch only supports shared weights.

---
 tests/systems/maddpg_system_test.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/systems/maddpg_system_test.py b/tests/systems/maddpg_system_test.py
index b3b8ffd1d..e7fec0d19 100644
--- a/tests/systems/maddpg_system_test.py
+++ b/tests/systems/maddpg_system_test.py
@@ -230,6 +230,7 @@ def test_networked_maddpg_on_debugging_env(self) -> None:
             trainer_fn=maddpg.MADDPGNetworkedTrainer,
             architecture=architectures.NetworkedQValueCritic,
             connection_spec=fully_connected_network_spec,
+            shared_weights=False,
         )