diff --git a/.github/ISSUE_TEMPLATE/bugfix_internal.md b/.github/ISSUE_TEMPLATE/bugfix_internal.md index 8cb718241..d4a544e4b 100644 --- a/.github/ISSUE_TEMPLATE/bugfix_internal.md +++ b/.github/ISSUE_TEMPLATE/bugfix_internal.md @@ -12,10 +12,10 @@ A clear and concise description of what the bug is. ### To Reproduce Steps to reproduce the behavior: -1. -2. -3. -4. +1. +2. +3. +4. ### Expected behavior A clear and concise description of what you expected to happen. diff --git a/.github/ISSUE_TEMPLATE/investigation_internal.md b/.github/ISSUE_TEMPLATE/investigation_internal.md new file mode 100644 index 000000000..c3431bfe8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/investigation_internal.md @@ -0,0 +1,29 @@ +--- +name: Investigation +about: Outline the structure for an investigation. This would commonly be used to measure the impact of various design/implementation choices. +title: '[INVESTIGATION]' +labels: investigation +assignees: '' + +--- + +### What do you want to investigate? +A brief description of what you would like to investigate. Do you have a hypothesis? + +### Definition of done +A precise outline for the investigation to be considered complete. + +### [***Optional***] Results + +Results from experiments/derivations. This could be linked to a benchmarking issue. + +### What was the conclusion of your investigation? + +- What are the findings from the investigation? +- Was your hypothesis correct? + +### [***Optional***] Discussion/Future Investigations + +This could be a link to a Github [discussions page](https://github.com/instadeepai/Mava/discussions). + + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 09fef9b6c..e4872232f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: additional_dependencies: [flake8-isort] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910 + rev: v0.941 hooks: - id: mypy exclude: ^docs/ diff --git a/examples/README.md b/examples/README.md index 876fba3a7..03788b0a7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -11,7 +11,7 @@ We include a number of systems running on continuous control tasks. - **MADDPG**: a MADDPG system running on the continuous action space simple_spread MPE environment. - *Feedforward*: - - decentralised + - Decentralised - [decentralised][debug_maddpg_ff_dec] - [decentralised record agents][debug_maddpg_ff_dec_record] (***recording agents acting in the environment***) - [decentralised executor scaling][debug_maddpg_ff_dec_scaling_executors] (***scaling to 4 executors***) @@ -20,7 +20,7 @@ We include a number of systems running on continuous control tasks. - [decentralised lr scheduling][debug_maddpg_ff_dec_lr_scheduling](***using lr schedule***) - [decentralised evaluator intervals][debug_maddpg_ff_dec_eval_intervals](***running the evaluation loop at intervals***) - - [centralised][debug_maddpg_cen] , [networked][debug_maddpg_networked] (***using a fully-connected, networked architecture***), [networked with custom architecture][debug_maddpg_networked_custom] (***using a custom, sparse, networked architecture***) and [state_based][debug_maddpg_state_based]. + - [centralised][debug_maddpg_cen] , [networked][debug_maddpg_networked] (***using a fully-connected, networked architecture***), [networked with custom architecture][debug_maddpg_networked_custom] (***using a custom, sparse, networked architecture***) and [state_based][debug_maddpg_state_based]. 
- *Recurrent* - [decentralised][debug_maddpg_rec_dec] and [state_based][debug_maddpg_state_based]. @@ -45,17 +45,19 @@ We include a number of systems running on continuous control tasks. - **MAD4PG**: a MAD4PG system running on the Multiwalker environment. - *Feedforward* - - [decentralised][pz_mad4pg_ff_dec] and [decentralised record agents][pz_mad4pg_ff_dec_record] (***recording agents acting in the environment***). + - [decentralised][pz_mad4pg_ff_dec] + - [decentralised record agents][pz_mad4pg_ff_dec_record] (***recording agents acting in the environment***). - - **MAPPO** - - *Feedforward* +- **MAPPO** + - *Feedforward* - [decentralised][pz_mappo_ff_dec]. ### 2D RoboCup - **MAD4PG**: a MAD4PG system running on the RoboCup environment. - - *Recurrent* [state_based][robocup_mad4pg_ff_state_based]. + - *Recurrent* + - [state_based][robocup_mad4pg_ff_state_based]. ## Discrete control @@ -71,29 +73,42 @@ We also include a number of systems running on discrete action space environment - **MADQN**: a MADQN system running on the discrete action space simple_spread MPE environment. - *Feedforward* - - [decentralised][debug_madqn_ff_dec], [decentralised lr scheduling][debug_madqn_ff_dec_lr_schedule] (***using lr schedule***), [decentralised custom lr scheduling][debug_madqn_ff_dec_custom_lr_schedule] (***using custom lr schedule***) and [decentralised custom epsilon decay scheduling][debug_madqn_ff_dec_custom_eps_schedule] (***using configurable epsilon scheduling***). + - Decentralised + - [decentralised][debug_madqn_ff_dec] + - [decentralised lr scheduling][debug_madqn_ff_dec_lr_schedule] (***using lr schedule***) + - [decentralised custom lr scheduling][debug_madqn_ff_dec_custom_lr_schedule] (***using custom lr schedule***) + - [decentralised custom epsilon decay scheduling][debug_madqn_ff_dec_custom_eps_schedule] (***using configurable epsilon scheduling***). - *Recurrent* - [decentralised][debug_madqn_rec_dec]. - **VDN**: a VDN system running on the discrete action space simple_spread MPE environment. - - *Recurrent* [centralised][debug_vdn_rec_cen]. + - *Recurrent* + - [centralised][debug_vdn_rec_cen]. ### PettingZoo - Multi-Agent Atari - **MADQN**: a MADQN system running on the two-player competitive Atari Pong environment. - - *Recurrent* [decentralised][pz_madqn_pong_ff_dec]. + - *Recurrent* + - [decentralised][pz_madqn_pong_rec_dec]. + +- **MAPPO**: + a MAPPO system running on two-player cooperative Atari Pong. + - *feedforward* + - [decentralised][pz_mappo_coop_pong_ff_dec]. ### PettingZoo - Multi-Agent Particle Environment - **MADDPG**: a MADDPG system running on the Simple Speaker Listener environment. - - *Feedforward* [decentralised][pz_maddpg_mpe_ssl_ff_dec]. + - *Feedforward* + - [decentralised][pz_maddpg_mpe_ssl_ff_dec]. - **MADDPG**: a MADDPG system running on the Simple Spread environment. - - *Feedforward* [decentralised][pz_maddpg_mpe_ss_ff_dec]. + - *Feedforward* + - [decentralised][pz_maddpg_mpe_ss_ff_dec]. ### SMAC - StarCraft Multi-Agent Challenge @@ -106,17 +121,20 @@ We also include a number of systems running on discrete action space environment - **QMIX**: a QMIX system running on the SMAC environment. - - *Recurrent* [centralised][smac_qmix_rec_cen]. + - *Recurrent* + - [centralised][smac_qmix_rec_cen]. - **VDN**: a VDN system running on the SMAC environment. - - *Recurrent* [centralised][smac_vdn_rec_cen]. + - *Recurrent* + - [centralised][smac_vdn_rec_cen]. ### OpenSpiel - Tic Tac Toe - **MADQN**: a MADQN system running on the OpenSpiel environment. 
- - *Feedforward* [decentralised][openspiel_madqn_ff_dec]. + - *Feedforward* + - [decentralised][openspiel_madqn_ff_dec]. [quickstart]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/quickstart.ipynb @@ -151,6 +169,7 @@ We also include a number of systems running on discrete action space environment [pz_mad4pg_ff_dec_record]: https://github.com/instadeepai/Mava/blob/develop/examples/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mad4pg_record.py [pz_mappo_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py +[robocup_mad4pg_ff_state_based]:https://github.com/instadeepai/Mava/blob/develop/examples/tf/robocup/recurrent/state_based/run_mad4pg.py [debug_mappo_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/debugging/simple_spread/feedforward/decentralised/run_mappo.py [debug_mappo_ff_cen]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/debugging/simple_spread/feedforward/centralised/run_mappo.py @@ -163,12 +182,15 @@ We also include a number of systems running on discrete action space environment [debug_vdn_rec_cen]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/debugging/simple_spread/recurrent/centralised/run_vdn.py -[pz_madqn_pong_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/atari/pong/recurrent/centralised/run_madqn.py +[pz_madqn_pong_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/atari/pong/recurrent/decentralised/run_madqn.py + +[pz_mappo_coop_pong_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py [pz_maddpg_mpe_ssl_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/mpe/simple_speaker_listener/feedforward/decentralised/run_maddpg.py [pz_maddpg_mpe_ss_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/petting_zoo/mpe/simple_spread/feedforward/decentralised/run_maddpg.py +[smac_madqn_ff_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/smac/feedforward/decentralised/run_madqn.py [smac_madqn_rec_dec]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/smac/recurrent/decentralised/run_madqn.py [smac_qmix_rec_cen]: https://github.com/instadeepai/Mava/blob/develop/examples/tf/smac/recurrent/centralised/run_qmix.py diff --git a/examples/tf/debugging/simple_spread/feedforward/centralised/run_mappo.py b/examples/tf/debugging/simple_spread/feedforward/centralised/run_mappo.py index ca698d0dd..f97708b02 100644 --- a/examples/tf/debugging/simple_spread/feedforward/centralised/run_mappo.py +++ b/examples/tf/debugging/simple_spread/feedforward/centralised/run_mappo.py @@ -88,7 +88,8 @@ def main(_: Any) -> None: network_factory=network_factory, logger_factory=logger_factory, num_executors=1, - optimizer=snt.optimizers.Adam(learning_rate=5e-4), + policy_optimizer=snt.optimizers.Adam(learning_rate=5e-4), + critic_optimizer=snt.optimizers.Adam(learning_rate=5e-4), checkpoint_subpath=checkpoint_dir, max_gradient_norm=40.0, architecture=architectures.CentralisedValueCritic, diff --git a/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py b/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py new file mode 100644 index 000000000..0c17b820f --- /dev/null +++ 
b/examples/tf/petting_zoo/butterfly/cooperative_pong/feedforward/decentralised/run_mappo.py @@ -0,0 +1,108 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Example running MAPPO on Cooperative Atari Pong.""" + +import functools +from datetime import datetime +from typing import Any + +import launchpad as lp +import numpy as np +from absl import app, flags +from acme.tf.networks import AtariTorso +from supersuit import dtype_v0 + +from mava.systems.tf import mappo +from mava.utils import lp_utils +from mava.utils.environments import pettingzoo_utils +from mava.utils.loggers import logger_utils + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "env_class", + "butterfly", + "Pettingzoo environment class, e.g. atari (str).", +) +flags.DEFINE_string( + "env_name", + "cooperative_pong_v3", + "Pettingzoo environment name, e.g. pong (str).", +) + +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") + + +def main(_: Any) -> None: + """Run example.""" + + # Environment + environment_factory = functools.partial( + pettingzoo_utils.make_environment, + env_class=FLAGS.env_class, + env_name=FLAGS.env_name, + env_preprocess_wrappers=[(dtype_v0, {"dtype": np.float32})], + ) + + # Networks. + network_factory = lp_utils.partial_kwargs( + mappo.make_default_networks, observation_network=AtariTorso() + ) + + # Checkpointer appends "Checkpoints" to checkpoint_dir + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. + log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # Distributed program + program = mappo.MAPPO( + environment_factory=environment_factory, + network_factory=network_factory, + logger_factory=logger_factory, + num_executors=1, + checkpoint_subpath=checkpoint_dir, + num_epochs=5, + batch_size=32, + ).build() + + # Ensure only trainer runs on gpu, while other processes run on cpu. + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + + # Launch. 
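+    # LOCAL_MULTI_PROCESSING runs each Launchpad node (e.g. replay, trainer,
+    # executors and evaluator) as a separate process on this machine; the
+    # local_resources mapping above keeps every node except the trainer on CPU.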
+ lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py b/examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py similarity index 100% rename from examples/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py rename to examples/tf/petting_zoo/sisl/multiwalker/feedforward/decentralised/run_mappo.py diff --git a/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_mad4pg.py b/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_mad4pg.py new file mode 100644 index 000000000..7b9c91c7c --- /dev/null +++ b/examples/tf/petting_zoo/sisl/multiwalker/recurrent/decentralised/run_mad4pg.py @@ -0,0 +1,127 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running MADDPG on pettinzoo MPE environments.""" + +import functools +from datetime import datetime +from typing import Any + +import launchpad as lp +import sonnet as snt +from absl import app, flags + +from mava.systems.tf import mad4pg +from mava.utils import lp_utils +from mava.utils.enums import ArchitectureType +from mava.utils.environments import pettingzoo_utils +from mava.utils.loggers import logger_utils + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "env_class", + "sisl", + "Pettingzoo environment class, e.g. atari (str).", +) + +flags.DEFINE_string( + "env_name", + "multiwalker_v7", + "Pettingzoo environment name, e.g. pong (str).", +) +flags.DEFINE_string( + "mava_id", + str(datetime.now()), + "Experiment identifier that can be used to continue experiments.", +) +flags.DEFINE_string("base_dir", "~/mava", "Base dir to store experiments.") + + +def main(_: Any) -> None: + """Run example. + + Args: + _ (Any): None + """ + + # Environment. + environment_factory = functools.partial( + pettingzoo_utils.make_environment, + env_class=FLAGS.env_class, + env_name=FLAGS.env_name, + ) + + # Networks. + network_factory = lp_utils.partial_kwargs( + mad4pg.make_default_networks, + architecture_type=ArchitectureType.recurrent, + vmin=-150, + vmax=150, + num_atoms=101, + ) + + # Checkpointer appends "Checkpoints" to checkpoint_dir. + checkpoint_dir = f"{FLAGS.base_dir}/{FLAGS.mava_id}" + + # Log every [log_every] seconds. + log_every = 10 + logger_factory = functools.partial( + logger_utils.make_logger, + directory=FLAGS.base_dir, + to_terminal=True, + to_tensorboard=True, + time_stamp=FLAGS.mava_id, + time_delta=log_every, + ) + + # Distributed program. 
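+    # The recurrent variant is selected explicitly below via trainer_fn and
+    # executor_fn. sequence_length and period control how episodes are chunked
+    # into replay sequences (period == sequence_length gives non-overlapping
+    # chunks), and samples_per_insert=None is understood here as not enforcing
+    # a fixed insert-to-sample ratio.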
+ program = mad4pg.MAD4PG( + environment_factory=environment_factory, + network_factory=network_factory, + logger_factory=logger_factory, + num_executors=1, + policy_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + critic_optimizer=snt.optimizers.Adam(learning_rate=1e-4), + checkpoint_subpath=checkpoint_dir, + max_gradient_norm=40.0, + trainer_fn=mad4pg.training.MAD4PGDecentralisedRecurrentTrainer, + executor_fn=mad4pg.execution.MAD4PGRecurrentExecutor, + batch_size=32, + sequence_length=20, + period=20, + min_replay_size=1000, + max_replay_size=100000, + prefetch_size=4, + n_step=5, + samples_per_insert=None, + ).build() + + # Ensure only trainer runs on gpu, while other processes run on cpu. + local_resources = lp_utils.to_device( + program_nodes=program.groups.keys(), nodes_on_gpu=["trainer"] + ) + + # Launch. + lp.launch( + program, + lp.LaunchType.LOCAL_MULTI_PROCESSING, + terminal="current_terminal", + local_resources=local_resources, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/mava/_metadata.py b/mava/_metadata.py index c34f8b551..59fbb2e4c 100644 --- a/mava/_metadata.py +++ b/mava/_metadata.py @@ -22,7 +22,7 @@ # We follow Semantic Versioning (https://semver.org/) _MAJOR_VERSION = "0" _MINOR_VERSION = "1" -_PATCH_VERSION = "1" +_PATCH_VERSION = "2" # Example: '0.4.2' __version__ = ".".join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) diff --git a/mava/adders/reverb/base.py b/mava/adders/reverb/base.py index deb99d6a4..b05354c17 100644 --- a/mava/adders/reverb/base.py +++ b/mava/adders/reverb/base.py @@ -94,8 +94,7 @@ def get_trajectory_net_agents( trajectory: Union[Trajectory, mava_types.Transition], trajectory_net_keys: Dict[str, str], ) -> Tuple[List, Dict[str, List]]: - """Returns a dictionary that maps network_keys to a list of agents using that - specific network. + """Maps network_keys to a list of agents using that specific network. Args: trajectory: Episode experience recorded by diff --git a/mava/core_jax.py b/mava/core_jax.py index de17ebe39..8a1693f4b 100644 --- a/mava/core_jax.py +++ b/mava/core_jax.py @@ -69,3 +69,59 @@ def launch( are primarily for debugging name : name of the system """ + + +class SystemBuilder(abc.ABC): + """Abstract system builder.""" + + @abc.abstractmethod + def data_server(self) -> List[Any]: + """Data server to store and serve transition data from and to system. + + Returns: + System data server + """ + + @abc.abstractmethod + def parameter_server(self) -> Any: + """Parameter server to store and serve system network parameters. + + Returns: + System parameter server + """ + + @abc.abstractmethod + def executor( + self, executor_id: str, data_server_client: Any, parameter_server_client: Any + ) -> Any: + """Executor, a collection of agents in an environment to gather experience. + + Args: + executor_id : id to identify the executor process for logging purposes + data_server_client : data server client for pushing transition data + parameter_server_client : parameter server client for pulling parameters + Returns: + System executor + """ + + @abc.abstractmethod + def trainer( + self, trainer_id: str, data_server_client: Any, parameter_server_client: Any + ) -> Any: + """Trainer, a system process for updating agent specific network parameters. 
+ + Args: + trainer_id : id to identify the trainer process for logging purposes + data_server_client : data server client for pulling transition data + parameter_server_client : parameter server client for pushing parameters + Returns: + System trainer + """ + + @abc.abstractmethod + def build(self) -> None: + """Construct program nodes.""" + + @abc.abstractmethod + def launch(self) -> None: + """Run the graph program.""" diff --git a/mava/core_jax_test.py b/mava/core_jax_test.py index 65f9bde48..26190de21 100644 --- a/mava/core_jax_test.py +++ b/mava/core_jax_test.py @@ -16,14 +16,14 @@ """Tests for core Mava interfaces for Jax systems.""" -from typing import Any +from typing import Any, List import pytest -from mava.core_jax import BaseSystem +from mava.core_jax import BaseSystem, SystemBuilder -def test_exception_for_incomplete_child_class() -> None: +def test_exception_for_incomplete_child_system_class() -> None: """Test if error is thrown for missing abstract class overwrites.""" with pytest.raises(TypeError): @@ -41,3 +41,22 @@ def configure(self, **kwargs: Any) -> None: pass TestIncompleteDummySystem() # type: ignore + + +def test_exception_for_incomplete_child_builder_class() -> None: + """Test if error is thrown for missing abstract class overwrites.""" + with pytest.raises(TypeError): + + class TestIncompleteDummySystemBuilder(SystemBuilder): + def data_server(self) -> List[Any]: + pass + + def executor( + self, + executor_id: str, + data_server_client: Any, + parameter_server_client: Any, + ) -> Any: + pass + + TestIncompleteDummySystemBuilder() # type: ignore diff --git a/mava/systems/jax/__init__.py b/mava/systems/jax/__init__.py index d87a16ee5..0dfa26d17 100644 --- a/mava/systems/jax/__init__.py +++ b/mava/systems/jax/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Imports for Jax-based Mava systems""" +"""Jax-based Mava system implementation.""" +from mava.systems.jax.builder import Builder from mava.systems.jax.config import Config diff --git a/mava/systems/jax/builder.py b/mava/systems/jax/builder.py new file mode 100644 index 000000000..e4bdc1287 --- /dev/null +++ b/mava/systems/jax/builder.py @@ -0,0 +1,40 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
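+# Note: the Builder below only stores the list of callback components it is
+# given; build() and launch() are currently stubs, with the actual node
+# construction expected to be driven by those component callbacks.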
+ +"""Jax-based Mava system builder implementation.""" + +from typing import Any, List + + +class Builder: + def __init__( + self, + components: List[Any], + ) -> None: + """System building init + + Args: + components: system callback components + """ + + self.callbacks = components + + def build(self) -> None: + """Build the system.""" + pass + + def launch(self) -> None: + """Launch the system""" + pass diff --git a/mava/systems/jax/config.py b/mava/systems/jax/config.py index 3691c583e..84f43bae9 100644 --- a/mava/systems/jax/config.py +++ b/mava/systems/jax/config.py @@ -78,6 +78,17 @@ def update(self, **kwargs: Any) -> None: for name, dataclass in kwargs.items(): if is_dataclass(dataclass): if name in list(self._config.keys()): + # When updating a component, the list of current parameter names + # might contain the parameter names of the new component + # with additional new parameter names that still need to be + # checked with other components. Therefore, we first take the + # difference between the current set and the component being + # updated. + self._current_params = list( + set(self._current_params).difference( + list(self._config[name].__dict__.keys()) + ) + ) new_param_names = list(dataclass.__dict__.keys()) if set(self._current_params) & set(new_param_names): raise Exception( @@ -87,7 +98,7 @@ def update(self, **kwargs: Any) -> None: ) else: self._current_params.extend(new_param_names) - self._config[name] = dataclass + self._config[name] = dataclass else: raise Exception( "The given component config is not part of the current \ diff --git a/mava/systems/jax/config_test.py b/mava/systems/jax/config_test.py index da711b6be..81626edfe 100644 --- a/mava/systems/jax/config_test.py +++ b/mava/systems/jax/config_test.py @@ -107,7 +107,7 @@ def test_add_multiple_configs( assert conf.param_1 == 3.8 -def test_add_configs_twice( +def test_add_config_twice( config: Config, dummy_component_config: type, dummy_hyperparameter_config: type ) -> None: """Test add two configs, one after the other. @@ -149,6 +149,28 @@ def test_update_config( assert not hasattr(config, "setting") +def test_update_config_twice( + config: Config, dummy_component_config: type, dummy_hyperparameter_config: type +) -> None: + """Test add two configs, one after the other. + + Args: + config : Mava config + dummy_component_config : component config dataclass + dummy_hyperparameter_config : component config dataclass of hyperparameters + """ + config.add(component=dummy_component_config) + config.update(component=dummy_hyperparameter_config) + config.update(component=dummy_component_config) + config.build() + conf = config.get() + + assert conf.name == "component" + assert conf.setting == 5 + assert not hasattr(config, "param_0") + assert not hasattr(config, "param_1") + + def test_set_existing_parameter_on_the_fly( config: Config, dummy_component_config: type ) -> None: @@ -229,11 +251,11 @@ def test_parameter_setting_that_does_not_exist_exception( config.set_parameters(unknown_param="new_value") -def test_accidental_parameter_override_exception( +def test_accidental_parameter_override_with_add_exception( config: Config, dummy_hyperparameter_config: type ) -> None: """Test that exception is thrown when two component config dataclasses share the \ - same name for a specific hyperparameter. + same name for a specific hyperparameter when adding a new config. 
Args: config : Mava config @@ -248,3 +270,28 @@ def test_accidental_parameter_override_exception( # as an already existing component parameter name other_hyperparamter_config = SameParameterNameConfig(param_0=2, param_2="param") config.add(other_hyperparameter=other_hyperparamter_config) + + +def test_accidental_parameter_override_with_update_exception( + config: Config, dummy_component_config: type, dummy_hyperparameter_config: type +) -> None: + """Test that exception is thrown when two component config dataclasses share the \ + same name for a specific hyperparameter when updating an existing config. + + Args: + config : Mava config + dummy_component_config : component config dataclass + dummy_hyperparameter_config : component config dataclass of hyperparameters + """ + + with pytest.raises(Exception): + # add component dataclasses and build config + config.add(component_0=dummy_component_config) + config.add(component_1=dummy_hyperparameter_config) + + # add new component dataclass with a parameter of the same name + # as an already existing component parameter name + other_hyperparameter_config = SameParameterNameConfig( + param_0=2, param_2="param" + ) + config.update(component_0=other_hyperparameter_config) diff --git a/mava/systems/jax/system.py b/mava/systems/jax/system.py new file mode 100644 index 000000000..1aae0ad5b --- /dev/null +++ b/mava/systems/jax/system.py @@ -0,0 +1,138 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Jax-based Mava system implementation.""" +import abc +from types import SimpleNamespace +from typing import Any, Callable, List + +from mava.core_jax import BaseSystem +from mava.systems.jax import Builder, Config + + +# TODO(Arnu): replace component types with Callback when class is ready. +class System(BaseSystem): + """General Mava system class for Jax-based systems.""" + + def __init__(self) -> None: + """System Initialisation""" + self._design = self.design() + self.config = Config() # Mava config + self.components: List = [] + + # make config from build + self._make_config() + + def _make_config(self) -> None: + """Private method to construct system config upon initialisation.""" + for component in self._design.__dict__.values(): + comp = component() + input = {comp.name: comp.config} + self.config.add(**input) + + @abc.abstractmethod + def design(self) -> SimpleNamespace: + """System design specifying the list of components to use. + + Returns: + system callback components + """ + + def update(self, component: Any) -> None: + """Update a component that has already been added to the system. 
+ + Args: + component : system callback component + name : component name + """ + comp = component() + name = comp.name + if name in list(self._design.__dict__.keys()): + self._design.__dict__[name] = component + config_feed = {name: comp.config} + self.config.update(**config_feed) + else: + raise Exception( + "The given component is not part of the current system.\ + Perhaps try adding it instead using .add()." + ) + + def add(self, component: Any) -> None: + """Add a new component to the system. + + Args: + component : system callback component + name : component name + """ + comp = component() + name = comp.name + if name in list(self._design.__dict__.keys()): + raise Exception( + "The given component is already part of the current system.\ + Perhaps try updating it instead using .update()." + ) + else: + self._design.__dict__[name] = component + config_feed = {name: comp.config} + self.config.add(**config_feed) + + def configure(self, **kwargs: Any) -> None: + """Configure system hyperparameters.""" + self.config.build() + self.config.set_parameters(**kwargs) + + def launch( + self, + num_executors: int, + nodes_on_gpu: List[str], + multi_process: bool = True, + name: str = "system", + builder_class: Callable = Builder, + ) -> None: + """Run the system. + + Args + num_executors : number of executor processes to run in parallel + nodes_on_gpu : which processes to run on gpu + multi_process : whether to run locally or distributed, local runs are + for debugging + name : name of the system + builder_class: callable builder class. + """ + # build config is not already built + if not self.config._built: + self.config.build() + + # update distributor config + self.config.set_parameters( + num_executors=num_executors, + nodes_on_gpu=nodes_on_gpu, + multi_process=multi_process, + name=name, + ) + + # get system config to feed to component list to update hyperparamete settings + system_config = self.config.get() + + # update default system component configs + for component in self._design.__dict__.values(): + self.components.append(component(system_config)) + + # Build system + self._builder = builder_class(components=self.components) + self._builder.build() + + # Launch system + self._builder.launch() diff --git a/mava/systems/jax/system_test.py b/mava/systems/jax/system_test.py new file mode 100644 index 000000000..da47fc8b6 --- /dev/null +++ b/mava/systems/jax/system_test.py @@ -0,0 +1,477 @@ +# python3 +# Copyright 2021 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
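+# The tests below exercise the System interface by subclassing it: each mock
+# system overrides design() to return a SimpleNamespace of component classes,
+# then swaps or adds components with .update()/.add(), sets hyperparameters
+# with .configure(), and finally calls .launch() with a mock Builder.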
+ + +"""Tests for Jax-based Mava system implementation.""" +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any, Callable, List + +import pytest + +from mava.systems.jax import Builder +from mava.systems.jax.system import System + +# Mock callbacks + + +class MockCallbackHookMixin: + + callbacks: List + + def dummy_int_plus_str(self) -> None: + """Called when the builder calls hook.""" + for callback in self.callbacks: + callback.dummy_int_plus_str(self) + + def dummy_float_plus_bool(self) -> None: + """Called when the builder calls hook.""" + for callback in self.callbacks: + callback.dummy_float_plus_bool(self) + + def dummy_str_plus_bool(self) -> None: + """Called when the builder calls hook.""" + for callback in self.callbacks: + callback.dummy_str_plus_bool(self) + + +# Mock builder +class MockBuilder(Builder, MockCallbackHookMixin): + def __init__(self, components: List[Any]) -> None: + """Init for mock builder. + + Args: + components : List of components. + """ + super().__init__(components) + + def add_different_data_types(self) -> None: + """Hooks for adding different data types.""" + + self.int_plus_str = 0 + self.float_plus_bool = 0.0 + self.str_plus_bool = 0 + + self.dummy_int_plus_str() + self.dummy_float_plus_bool() + self.dummy_str_plus_bool() + + +class MockCallback: + def dummy_int_plus_str(self, builder: MockBuilder) -> None: + """Dummy hook.""" + pass + + def dummy_float_plus_bool(self, builder: MockBuilder) -> None: + """Dummy hook.""" + pass + + def dummy_str_plus_bool(self, builder: MockBuilder) -> None: + """Dummy hook.""" + pass + + +# Mock components +class MainComponent: + @property + def name(self) -> str: + """Component type name, e.g. 'dataset' or 'executor'. + + Returns: + Component type name + """ + return "main_component" + + +class SubComponent: + @property + def name(self) -> str: + """Component type name, e.g. 'dataset' or 'executor'. + + Returns: + Component type name + """ + return "sub_component" + + +@dataclass +class ComponentZeroDefaultConfig: + param_0: int = 1 + param_1: str = "1" + + +class ComponentZero(MockCallback, MainComponent): + def __init__( + self, config: ComponentZeroDefaultConfig = ComponentZeroDefaultConfig() + ) -> None: + """Mock system component. + + Args: + config : dataclass configuration for setting component hyperparameters + """ + self.config = config + + def dummy_int_plus_str(self, builder: MockBuilder) -> None: + """Dummy component function. + + Returns: + config int plus string cast to int + """ + builder.int_plus_str = self.config.param_0 + int(self.config.param_1) + + +@dataclass +class ComponentOneDefaultConfig: + param_2: float = 1.2 + param_3: bool = True + + +class ComponentOne(MockCallback, SubComponent): + def __init__( + self, config: ComponentOneDefaultConfig = ComponentOneDefaultConfig() + ) -> None: + """Mock system component. + + Args: + config : dataclass configuration for setting component hyperparameters + """ + self.config = config + + def dummy_float_plus_bool(self, builder: MockBuilder) -> None: + """Dummy component function. + + Returns: + float plus boolean cast as float + """ + builder.float_plus_bool = self.config.param_2 + float(self.config.param_3) + + +@dataclass +class ComponentTwoDefaultConfig: + param_4: str = "2" + param_5: bool = True + + +class ComponentTwo(MockCallback, MainComponent): + def __init__( + self, config: ComponentTwoDefaultConfig = ComponentTwoDefaultConfig() + ) -> None: + """Mock system component. 
+ + Args: + config : dataclass configuration for setting component hyperparameters + """ + self.config = config + + def dummy_str_plus_bool(self, builder: MockBuilder) -> None: + """Dummy component function. + + Returns: + string cast as int plus boolean cast as in + """ + builder.str_plus_bool = int(self.config.param_4) + int(self.config.param_5) + + +@dataclass +class DistributorDefaultConfig: + num_executors: int = 1 + nodes_on_gpu: List[str] = field(default_factory=list) + multi_process: bool = True + name: str = "system" + + +class MockDistributorComponent(MockCallback): + def __init__( + self, config: DistributorDefaultConfig = DistributorDefaultConfig() + ) -> None: + """Mock system distributor component. + + Args: + config : dataclass configuration for setting component hyperparameters + """ + self.config = config + + @property + def name(self) -> str: + """Component type name, e.g. 'dataset' or 'executor'. + + Returns: + Component type name + """ + return "distributor" + + +# Test Systems +class TestSystem(System): + def launch( + self, + num_executors: int, + nodes_on_gpu: List[str], + multi_process: bool = True, + name: str = "system", + builder_class: Callable = MockBuilder, + ) -> None: + """Run the system. + + Args: + config : system configuration including + num_executors : number of executor processes to run in parallel + nodes_on_gpu : which processes to run on gpu + multi_process : whether to run locally or distributed, local runs are + for debugging + name : name of the system + builder_class: callable builder class. + """ + return super().launch( + num_executors, nodes_on_gpu, multi_process, name, builder_class + ) + + +class TestSystemWithZeroComponents(TestSystem): + def design(self) -> SimpleNamespace: + """Mock system design with zero components. + + Returns: + system callback components + """ + components = SimpleNamespace(distributor=MockDistributorComponent) + return components + + +class TestSystemWithOneComponent(TestSystem): + def design(self) -> SimpleNamespace: + """Mock system design with one component. + + Returns: + system callback components + """ + components = SimpleNamespace( + main_component=ComponentZero, distributor=MockDistributorComponent + ) + return components + + +class TestSystemWithTwoComponents(TestSystem): + def design(self) -> SimpleNamespace: + """Mock system design with two components. + + Returns: + system callback components + """ + components = SimpleNamespace( + main_component=ComponentZero, + sub_component=ComponentOne, + distributor=MockDistributorComponent, + ) + return components + + +# Test fixtures +@pytest.fixture +def system_with_zero_components() -> System: + """Dummy system with zero components. + + Returns: + mock system + """ + return TestSystemWithZeroComponents() + + +@pytest.fixture +def system_with_one_component() -> System: + """Dummy system with one component. + + Returns: + mock system + """ + return TestSystemWithOneComponent() + + +@pytest.fixture +def system_with_two_components() -> System: + """Dummy system with two components. + + Returns: + mock system + """ + return TestSystemWithTwoComponents() + + +# Tests +def test_system_launch_without_configure( + system_with_two_components: System, +) -> None: + """Test if system can launch without having had changed (configured) the default \ + config. 
+ + Args: + system_with_two_components : mock system + """ + system_with_two_components.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_two_components._builder.add_different_data_types() + assert system_with_two_components._builder.int_plus_str == 2 + assert system_with_two_components._builder.float_plus_bool == 2.2 + assert system_with_two_components._builder.str_plus_bool == 0 + + +def test_system_launch_with_configure( + system_with_two_components: System, +) -> None: + """Test if system can launch having had changed (configured) the default \ + config. + + Args: + system_with_two_components : mock system + """ + system_with_two_components.configure(param_0=2, param_3=False) + system_with_two_components.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_two_components._builder.add_different_data_types() + assert system_with_two_components._builder.int_plus_str == 3 + assert system_with_two_components._builder.float_plus_bool == 1.2 + assert system_with_two_components._builder.str_plus_bool == 0 + + +def test_system_update_with_existing_component( + system_with_two_components: System, +) -> None: + """Test if system can update existing component. + + Args: + system_with_two_components : mock system + """ + system_with_two_components.update(ComponentTwo) + system_with_two_components.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_two_components._builder.add_different_data_types() + assert system_with_two_components._builder.int_plus_str == 0 + assert system_with_two_components._builder.float_plus_bool == 2.2 + assert system_with_two_components._builder.str_plus_bool == 3 + + +def test_system_update_with_non_existing_component( + system_with_one_component: System, +) -> None: + """Test if system raises an error when trying to update a component that has not \ + yet been added to the system. + + Args: + system_with_one_component : mock system + """ + with pytest.raises(Exception): + system_with_one_component.update(ComponentOne) + + +def test_system_add_with_existing_component(system_with_one_component: System) -> None: + """Test if system raises an error when trying to add a component that has already \ + been added to the system, i.e. we don't want to overwrite a component by \ + mistake. + + Args: + system_with_one_component : mock system + """ + with pytest.raises(Exception): + system_with_one_component.add(ComponentTwo) + + +def test_system_add_with_non_existing_component( + system_with_one_component: System, +) -> None: + """Test if system can add a new component. + + Args: + system_with_one_component : mock system + """ + system_with_one_component.add(ComponentOne) + system_with_one_component.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_one_component._builder.add_different_data_types() + assert system_with_one_component._builder.int_plus_str == 2 + assert system_with_one_component._builder.float_plus_bool == 2.2 + assert system_with_one_component._builder.str_plus_bool == 0 + + +def test_system_update_twice(system_with_two_components: System) -> None: + """Test if system can update a component twice. 
+ + Args: + system_with_two_components : mock system + """ + system_with_two_components.update(ComponentTwo) + system_with_two_components.update(ComponentZero) + system_with_two_components.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_two_components._builder.add_different_data_types() + assert system_with_two_components._builder.int_plus_str == 2 + assert system_with_two_components._builder.float_plus_bool == 2.2 + assert system_with_two_components._builder.str_plus_bool == 0 + + +def test_system_add_twice(system_with_zero_components: System) -> None: + """Test if system can add two components. + + Args: + system_with_zero_components : mock system + """ + system_with_zero_components.add(ComponentOne) + system_with_zero_components.add(ComponentTwo) + system_with_zero_components.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_zero_components._builder.add_different_data_types() + assert system_with_zero_components._builder.int_plus_str == 0 + assert system_with_zero_components._builder.float_plus_bool == 2.2 + assert system_with_zero_components._builder.str_plus_bool == 3 + + +def test_system_add_and_update(system_with_zero_components: System) -> None: + """Test if system can add and then update a component. + + Args: + system_with_zero_components : mock system + """ + system_with_zero_components.add(ComponentZero) + system_with_zero_components.update(ComponentTwo) + system_with_zero_components.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_zero_components._builder.add_different_data_types() + assert system_with_zero_components._builder.int_plus_str == 0 + assert system_with_zero_components._builder.float_plus_bool == 0 + assert system_with_zero_components._builder.str_plus_bool == 3 + + +def test_system_configure_one_component_params( + system_with_two_components: System, +) -> None: + """Test if system can configure a single component's parameters. + + Args: + system_with_two_components : mock system + """ + system_with_two_components.configure(param_0=2, param_1="2") + system_with_two_components.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_two_components._builder.add_different_data_types() + assert system_with_two_components._builder.int_plus_str == 4 + assert system_with_two_components._builder.float_plus_bool == 2.2 + assert system_with_two_components._builder.str_plus_bool == 0 + + +def test_system_configure_two_component_params( + system_with_two_components: System, +) -> None: + """Test if system can configure multiple component parameters. + + Args: + system_with_two_components : mock system + """ + system_with_two_components.configure(param_0=2, param_3=False) + system_with_two_components.launch(num_executors=1, nodes_on_gpu=["process"]) + system_with_two_components._builder.add_different_data_types() + assert system_with_two_components._builder.int_plus_str == 3 + assert system_with_two_components._builder.float_plus_bool == 1.2 + assert system_with_two_components._builder.str_plus_bool == 0 diff --git a/mava/systems/tf/mad4pg/training.py b/mava/systems/tf/mad4pg/training.py index 3237fa245..4fab17c15 100644 --- a/mava/systems/tf/mad4pg/training.py +++ b/mava/systems/tf/mad4pg/training.py @@ -197,7 +197,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: critic_loss = losses.categorical( q_tm1, r_t[agent], discount * d_t[agent], q_t ) - self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0) + self.critic_losses[agent] = tf.reduce_mean(critic_loss) # Actor learning. 
o_t_agent_feed = o_t_trans[agent] dpg_a_t = self._policy_networks[agent_key](o_t_agent_feed) @@ -214,13 +214,13 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: clip_norm = True if self._max_gradient_norm is not None else False policy_loss = losses.dpg( - dpg_q_t, - dpg_a_t, + q_max=dpg_q_t, + a_max=dpg_a_t, tape=tape, dqda_clipping=dqda_clipping, clip_norm=clip_norm, ) - self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0) + self.policy_losses[agent] = tf.reduce_mean(policy_loss) self.tape = tape @@ -594,7 +594,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: data: Trajectory = inputs.data # Note (dries): The unused variable is start_of_episodes. - observations, actions, rewards, discounts, _, extras = ( + observations, actions, rewards, end_of_episode, _, extras = ( data.observations, data.actions, data.rewards, @@ -660,19 +660,24 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Cast the additional discount to match # the environment discount dtype. - agent_discount = discounts[agent] + agent_discount = end_of_episode[agent] discount = tf.cast(self._discount, dtype=agent_discount.dtype) + agent_end_of_episode = end_of_episode[agent] + ones_mask = tf.ones(shape=(agent_end_of_episode.shape[0], 1)) + step_not_padded = tf.concat( + [ones_mask, agent_end_of_episode[:, :-1]], axis=1 + ) # Critic loss. critic_loss = recurrent_n_step_critic_loss( - q_values, - target_q_values, - rewards[agent], - discount * agent_discount, + q_values=q_values, + target_q_values=target_q_values, + rewards=rewards[agent], + discounts=discount * agent_discount, bootstrap_n=self._bootstrap_n, loss_fn=losses.categorical, ) - self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0) + self.critic_losses[agent] = tf.reduce_mean(critic_loss) # Actor learning. obs_agent_feed = target_obs_trans[agent] @@ -718,7 +723,11 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: dqda_clipping=dqda_clipping, clip_norm=clip_norm, ) - self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0) + policy_mask = tf.reshape(step_not_padded, policy_loss.shape) + policy_loss = policy_loss * policy_mask + self.policy_losses[agent] = tf.reduce_sum(policy_loss) / tf.reduce_sum( + policy_mask + ) self.tape = tape diff --git a/mava/systems/tf/maddpg/execution.py b/mava/systems/tf/maddpg/execution.py index 576bc63c9..8d43f8351 100644 --- a/mava/systems/tf/maddpg/execution.py +++ b/mava/systems/tf/maddpg/execution.py @@ -21,6 +21,7 @@ import sonnet as snt import tensorflow as tf import tensorflow_probability as tfp +import tree from acme import types from acme.specs import EnvironmentSpec @@ -93,7 +94,6 @@ def __init__( variable_client=variable_client, ) - @tf.function def _policy( self, agent: str, observation: types.NestedTensor ) -> types.NestedTensor: @@ -147,12 +147,19 @@ def select_action( # Step the recurrent policy/value network forward # given the current observation and state. action, policy = self._policy(agent, observation.observation) - - # Return a numpy array with squeezed out batch dimension. 
- action = tf2_utils.to_numpy_squeeze(action) - policy = tf2_utils.to_numpy_squeeze(policy) return action, policy + @tf.function + def _select_actions( + self, observations: Dict[str, types.NestedArray] + ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + """Select the actions for all agents in the system""" + actions = {} + policies = {} + for agent, observation in observations.items(): + actions[agent], policies[agent] = self.select_action(agent, observation) + return actions, policies + def select_actions( self, observations: Dict[str, types.NestedArray] ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: @@ -166,10 +173,9 @@ def select_actions( actions and policies for all agents in the system. """ - actions = {} - policies = {} - for agent, observation in observations.items(): - actions[agent], policies[agent] = self.select_action(agent, observation) + actions, policies = self._select_actions(observations) + actions = tree.map_structure(tf2_utils.to_numpy_squeeze, actions) + policies = tree.map_structure(tf2_utils.to_numpy_squeeze, policies) return actions, policies def observe_first( @@ -284,7 +290,6 @@ def __init__( store_recurrent_state=store_recurrent_state, ) - @tf.function def _policy( self, agent: str, @@ -322,41 +327,36 @@ def _policy( raise NotImplementedError return action, policy, new_state - def select_action( - self, agent: str, observation: types.NestedArray - ) -> types.NestedArray: - """select an action for a single agent in the system + def select_actions( + self, observations: Dict[str, types.NestedArray] + ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + """select the actions for all agents in the system + Args: - agent: agent id - observation: observation tensor received from the + observations: agent observations from the environment. + Returns: - action and policy. + actions and policies for all agents in the system. """ - # TODO Mask actions here using observation.legal_actions - # Initialize the RNN state if necessary. - if self._states[agent] is None: - # index network either on agent type or on agent id - agent_key = self._agent_net_keys[agent] - self._states[agent] = self._policy_networks[agent_key].initia_state(1) - - # Step the recurrent policy forward given the current observation and state. - action, policy, new_state = self._policy( - agent, observation.observation, self._states[agent] + actions, policies, self._states = self._select_actions( + observations, self._states ) + actions = tree.map_structure(tf2_utils.to_numpy_squeeze, actions) + policies = tree.map_structure(tf2_utils.to_numpy_squeeze, policies) + return actions, policies - # Bookkeeping of recurrent states for the observe method. - self._update_state(agent, new_state) - - # Return a numpy array with squeezed out batch dimension. 
- action = tf2_utils.to_numpy_squeeze(action) - policy = tf2_utils.to_numpy_squeeze(policy) - return action, policy - - def select_actions( - self, observations: Dict[str, types.NestedArray] - ) -> Tuple[Dict[str, types.NestedArray], Dict[str, types.NestedArray]]: + @tf.function + def _select_actions( + self, + observations: Dict[str, types.NestedArray], + states: Dict[str, types.NestedArray], + ) -> Tuple[ + Dict[str, types.NestedArray], + Dict[str, types.NestedArray], + Dict[str, types.NestedArray], + ]: """select the actions for all agents in the system Args: observations: agent observations from the @@ -367,9 +367,32 @@ def select_actions( actions = {} policies = {} + new_states = {} for agent, observation in observations.items(): - actions[agent], policies[agent] = self.select_action(agent, observation) - return actions, policies + actions[agent], policies[agent], new_states[agent] = self.select_action( + agent, observation, states[agent] + ) + return actions, policies, new_states + + def select_action( # type: ignore + self, + agent: str, + observation: types.NestedArray, + agent_state: types.NestedArray, + ) -> Tuple[types.NestedArray, types.NestedArray, types.NestedArray]: + """select an action for a single agent in the system + Args: + agent: agent id + observation: observation tensor received from the + environment. + Returns: + action and policy. + """ + # Step the recurrent policy forward given the current observation and state. + action, policy, new_state = self._policy( + agent, observation.observation, agent_state + ) + return action, policy, new_state def observe_first( self, diff --git a/mava/systems/tf/maddpg/system.py b/mava/systems/tf/maddpg/system.py index 383524417..baecaa884 100644 --- a/mava/systems/tf/maddpg/system.py +++ b/mava/systems/tf/maddpg/system.py @@ -110,6 +110,7 @@ def __init__( # noqa learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, ): """Initialise the system + Args: environment_factory: function to instantiate an environment. @@ -226,11 +227,14 @@ def __init__( # noqa if type(network_sampling_setup) is not list: if network_sampling_setup == enums.NetworkSampler.fixed_agent_networks: - # if no network_sampling_setup is fixed, use shared_weights to - # determine setup + # if no network_sampling_setup is specified, assign a single network + # to all agents of the same type if weights are shared + # else assign seperate networks to each agent self._agent_net_keys = { - agent: "network_0" if shared_weights else f"network_{i}" - for i, agent in enumerate(agents) + agent: f"network_{agent.split('_')[0]}" + if shared_weights + else f"network_{agent}" + for agent in agents } self._network_sampling_setup = [ [ @@ -260,7 +264,6 @@ def __init__( # noqa raise ValueError( "network_sampling_setup must be a dict or fixed_agent_networks" ) - else: # if a dictionary is provided, use network_sampling_setup to determine setup _, self._agent_net_keys = sample_new_agent_keys( @@ -402,7 +405,8 @@ def __init__( # noqa ) def _get_extra_specs(self) -> Any: - """helper to establish specs for extra information + """Helper to establish specs for extra information + Returns: dictionary containing extra specs """ @@ -425,8 +429,10 @@ def _get_extra_specs(self) -> Any: def replay(self) -> Any: """Step counter + Args: checkpoint: whether to checkpoint the counter. + Returns: step counter object. 
""" @@ -487,11 +493,13 @@ def executor( variable_source: acme.VariableSource, ) -> mava.ParallelEnvironmentLoop: """System executor + Args: executor_id: id to identify the executor process for logging purposes. replay: replay data table to push data to. variable_source: variable server for updating network variables. + Returns: mava.ParallelEnvironmentLoop: environment-executor loop instance. """ @@ -538,10 +546,12 @@ def evaluator( logger: loggers.Logger = None, ) -> Any: """System evaluator (an executor process not connected to a dataset) + Args: variable_source: variable server for updating network variables. logger: logger object. + Returns: environment-executor evaluation loop instance for evaluating the performance of a system. @@ -588,11 +598,13 @@ def trainer( variable_source: MavaVariableSource, ) -> mava.core.Trainer: """System trainer + Args: trainer_id: Id of the trainer being created. replay: replay data table to pull data from. variable_source: variable server for updating network variables. + Returns: system trainer. """ @@ -622,8 +634,10 @@ def trainer( def build(self, name: str = "maddpg") -> Any: """Build the distributed system as a graph program. + Args: name: system name. + Returns: graph program for distributed system training. """ diff --git a/mava/systems/tf/maddpg/training.py b/mava/systems/tf/maddpg/training.py index 8fb792fa6..8c169aafd 100644 --- a/mava/systems/tf/maddpg/training.py +++ b/mava/systems/tf/maddpg/training.py @@ -418,7 +418,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: critic_loss = trfl.td_learning( q_tm1, r_t[agent], discount * d_t[agent], q_t ).loss - self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0) + self.critic_losses[agent] = tf.reduce_mean(critic_loss) # Actor learning. o_t_agent_feed = o_t_trans[agent] @@ -442,7 +442,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: clip_norm=clip_norm, ) - self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0) + self.policy_losses[agent] = tf.reduce_mean(policy_loss) self.tape = tape # Backward pass that calculates gradients and updates network. @@ -1262,7 +1262,7 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: data: Trajectory = inputs.data # Note (dries): The unused variable is start_of_episodes. - observations, actions, rewards, discounts, _, extras = ( + observations, actions, rewards, end_of_episode, _, extras = ( data.observations, data.actions, data.rewards, @@ -1333,20 +1333,25 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: # Cast the additional discount to match # the environment discount dtype. - agent_discount = discounts[agent] + agent_discount = end_of_episode[agent] discount = tf.cast(self._discount, dtype=agent_discount.dtype) + agent_end_of_episode = end_of_episode[agent] + ones_mask = tf.ones(shape=(agent_end_of_episode.shape[0], 1)) + step_not_padded = tf.concat( + [ones_mask, agent_end_of_episode[:, :-1]], axis=1 + ) # Critic loss. critic_loss = recurrent_n_step_critic_loss( - q_values, - rewards[agent], - discount * agent_discount, - target_q_values, + q_values=q_values, + target_q_values=target_q_values, + rewards=rewards[agent], + discounts=discount * agent_discount, bootstrap_n=self._bootstrap_n, loss_fn=trfl.td_learning, ) - self.critic_losses[agent] = tf.reduce_mean(critic_loss, axis=0) + self.critic_losses[agent] = tf.reduce_mean(critic_loss) # Actor learning. 
obs_agent_feed = target_obs_trans[agent] @@ -1386,13 +1391,19 @@ def _forward(self, inputs: reverb.ReplaySample) -> None: clip_norm = True if self._max_gradient_norm is not None else False policy_loss = losses.dpg( - dpg_q_values, - dpg_actions_comb, + q_max=dpg_q_values, + a_max=dpg_actions_comb, tape=tape, dqda_clipping=dqda_clipping, clip_norm=clip_norm, ) - self.policy_losses[agent] = tf.reduce_mean(policy_loss, axis=0) + + # Multiply by discounts to not train on padded data. + policy_mask = tf.reshape(step_not_padded, policy_loss.shape) + policy_loss = policy_loss * policy_mask + self.policy_losses[agent] = tf.reduce_sum(policy_loss) / tf.reduce_sum( + policy_mask + ) self.tape = tape # Backward pass that calculates gradients and updates network. diff --git a/mava/systems/tf/madqn/system.py b/mava/systems/tf/madqn/system.py index f26397fb0..cacf23595 100644 --- a/mava/systems/tf/madqn/system.py +++ b/mava/systems/tf/madqn/system.py @@ -214,11 +214,14 @@ def __init__( # noqa if type(network_sampling_setup) is not list: if network_sampling_setup == enums.NetworkSampler.fixed_agent_networks: - # if no network_sampling_setup is fixed, use shared_weights to - # determine setup + # if no network_sampling_setup is specified, assign a single network + # to all agents of the same type if weights are shared + # else assign seperate networks to each agent self._agent_net_keys = { - agent: "network_0" if shared_weights else f"network_{i}" - for i, agent in enumerate(agents) + agent: f"network_{agent.split('_')[0]}" + if shared_weights + else f"network_{agent}" + for agent in agents } self._network_sampling_setup = [ [ @@ -255,7 +258,6 @@ def __init__( # noqa agents, self._network_sampling_setup, # type: ignore ) - # Check that the environment and agent_net_keys has the same amount of agents sample_length = len(self._network_sampling_setup[0]) # type: ignore assert len(environment_spec.get_agent_ids()) == len(self._agent_net_keys.keys()) diff --git a/mava/systems/tf/mappo/builder.py b/mava/systems/tf/mappo/builder.py index 3cc3321a0..54c9187d1 100644 --- a/mava/systems/tf/mappo/builder.py +++ b/mava/systems/tf/mappo/builder.py @@ -40,7 +40,9 @@ class MAPPOConfig: Args: environment_spec: description of the action and observation spaces etc. for each agent in the system. - optimizer: optimizer(s) for updating networks. + policy_optimizer: optimizer(s) for updating policy networks. + critic_optimizer: optimizer for updating critic networks. This is not + used if using single optim. agent_net_keys: (dict, optional): specifies what network each agent uses. Defaults to {}. 
checkpoint_minute_interval (int): The number of minutes to wait between @@ -74,7 +76,8 @@ class MAPPOConfig: """ environment_spec: specs.EnvironmentSpec - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] + policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] + critic_optimizer: snt.Optimizer agent_net_keys: Dict[str, str] checkpoint_minute_interval: int sequence_length: int = 10 @@ -319,7 +322,8 @@ def make_trainer( critic_networks=critic_networks, dataset=dataset, agent_net_keys=agent_net_keys, - optimizer=self._config.optimizer, + critic_optimizer=self._config.critic_optimizer, + policy_optimizer=self._config.policy_optimizer, minibatch_size=self._config.minibatch_size, num_epochs=self._config.num_epochs, discount=self._config.discount, diff --git a/mava/systems/tf/mappo/networks.py b/mava/systems/tf/mappo/networks.py index 3f758960f..be6e60e8a 100644 --- a/mava/systems/tf/mappo/networks.py +++ b/mava/systems/tf/mappo/networks.py @@ -39,6 +39,7 @@ def make_default_networks( 256, ), critic_networks_layer_sizes: Union[Dict[str, Sequence], Sequence] = (512, 512, 256), + observation_network: snt.Module = None, seed: Optional[int] = None, ) -> Dict[str, snt.Module]: """Default networks for mappo. @@ -81,7 +82,10 @@ def make_default_networks( for key in specs.keys(): # Create the shared observation network; here simply a state-less operation. - observation_network = tf2_utils.to_sonnet_module(tf.identity) + if observation_network is None: + observation_network = tf2_utils.to_sonnet_module(tf.identity) + else: + observation_network = observation_network # Note: The discrete case must be placed first as it inherits from BoundedArray. if isinstance(specs[key].actions, dm_env.specs.DiscreteArray): # discrete diff --git a/mava/systems/tf/mappo/system.py b/mava/systems/tf/mappo/system.py index 7fab6dde4..c2df8e4b5 100644 --- a/mava/systems/tf/mappo/system.py +++ b/mava/systems/tf/mappo/system.py @@ -57,7 +57,10 @@ def __init__( shared_weights: bool = True, agent_net_keys: Dict[str, str] = {}, executor_variable_update_period: int = 100, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]] = snt.optimizers.Adam( + policy_optimizer: Union[ + snt.Optimizer, Dict[str, snt.Optimizer] + ] = snt.optimizers.Adam(learning_rate=5e-4), + critic_optimizer: Optional[snt.Optimizer] = snt.optimizers.Adam( learning_rate=5e-4 ), discount: float = 0.99, @@ -82,7 +85,7 @@ def __init__( train_loop_fn_kwargs: Dict = {}, eval_loop_fn_kwargs: Dict = {}, evaluator_interval: Optional[dict] = None, - learning_rate_scheduler_fn: Optional[Callable[[int], None]] = None, + learning_rate_scheduler_fn: Optional[Dict[str, Callable[[int], None]]] = None, normalize_advantage: bool = False, ): """Initialise the system @@ -114,8 +117,10 @@ def __init__( Defaults to {}. executor_variable_update_period : number of steps before updating executor variables from the variable source. Defaults to 100. - optimizer : optimizer(s) for updating networks. + policy_optimizer : optimizer(s) for updating policy networks. Defaults to snt.optimizers.Adam(learning_rate=5e-4). + critic_optimizer : optimizer for updating critic + networks. This is not used if using single optim. discount : discount factor to use for TD updates. Defaults to 0.99. lambda_gae : scalar determining the mix of bootstrapping @@ -160,8 +165,13 @@ def __init__( to the training loop. Defaults to {}. eval_loop_fn_kwargs: possible keyword arguments to send to the evaluation loop. Defaults to {}. 
- learning_rate_scheduler_fn: an optional learning rate scheduler for - the optimiser. + learning_rate_scheduler_fn: dict with two functions/classes (one for the + policy and one for the critic optimizer), that takes in a trainer + step t and returns the current learning rate, + e.g. {"policy": policy_lr_schedule ,"critic": critic_lr_schedule}. + See + examples/debugging/simple_spread/feedforward/decentralised/run_maddpg_lr_schedule.py + for an example. evaluator_interval: An optional condition that is used to evaluate/test system performance after [evaluator_interval] condition has been met. If None, evaluation will @@ -225,12 +235,15 @@ def __init__( self._network_factory = network_factory self._logger_factory = logger_factory self._environment_spec = environment_spec - # Setup agent networks + # Setup agent networks to assign a single network to all agents of the same type + # if weights are shared else assign separate networks to each agent self._agent_net_keys = agent_net_keys if not agent_net_keys: agents = environment_spec.get_agent_ids() self._agent_net_keys = { - agent: agent.split("_")[0] if shared_weights else agent + agent: f"network_{agent.split('_')[0]}" + if shared_weights + else f"network_{agent}" for agent in agents } self._num_exectors = num_executors @@ -264,7 +277,8 @@ def __init__( sequence_length=self._sequence_length, sequence_period=self._sequence_period, checkpoint=checkpoint, - optimizer=optimizer, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, checkpoint_subpath=checkpoint_subpath, checkpoint_minute_interval=checkpoint_minute_interval, evaluator_interval=evaluator_interval, diff --git a/mava/systems/tf/mappo/training.py b/mava/systems/tf/mappo/training.py index 9e0afc7a0..5733cfbdf 100644 --- a/mava/systems/tf/mappo/training.py +++ b/mava/systems/tf/mappo/training.py @@ -53,7 +53,8 @@ def __init__( policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], dataset: tf.data.Dataset, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + critic_optimizer: Optional[Union[snt.Optimizer, Dict[str, snt.Optimizer]]], agent_net_keys: Dict[str, str], checkpoint_minute_interval: int, minibatch_size: Optional[int] = None, @@ -83,7 +84,11 @@ def __init__( critic_networks: critic network(s), shared or for each agent in the system. dataset (tf.data.Dataset): training dataset. - optimizer: optimizer for updating policy networks. + policy_optimizer: optimizer + for updating policy networks. + critic_optimizer: optimizer + for updating critic networks. This is not used if using + single optim. agent_net_keys: specifies what network each agent uses. Defaults to {}. checkpoint_minute_interval: The number of minutes to wait between @@ -137,12 +142,16 @@ def __init__( self._normalize_advantage = normalize_advantage # Create optimizers for different agent types. - if not isinstance(optimizer, dict): - self._optimizer: Dict[str, snt.Optimizer] = {} + if not isinstance(policy_optimizer, dict): + self._policy_optimizers: Dict[str, snt.Optimizer] = {} for agent in self.unique_net_keys: - self._optimizer[agent] = copy.deepcopy(optimizer) + self._policy_optimizers[agent] = copy.deepcopy(policy_optimizer) else: - self._optimizer = optimizer + self._policy_optimizers = policy_optimizer + + self._critic_optimizers: Dict[str, snt.Optimizer] = {} + for agent in self.unique_net_keys: + self._critic_optimizers[agent] = copy.deepcopy(critic_optimizer) # Expose the variables. 
policy_networks_to_expose = {} @@ -196,7 +205,8 @@ def __init__( "policy": self._policy_networks[agent_key], "critic": self._critic_networks[agent_key], "observation": self._observation_networks[agent_key], - "optimizer": self._optimizer, + "policy_optimizer": self._policy_optimizers, + "critic_optimizer": self._critic_optimizers, } subdir = os.path.join("trainer", agent_key) @@ -282,11 +292,7 @@ def _step( """ losses: Dict[str, NestedArray] = { - agent: { - "critic_loss": tf.zeros(()), - "policy_loss": tf.zeros(()), - "total_loss": tf.zeros(()), - } + agent: {"critic_loss": tf.zeros(()), "policy_loss": tf.zeros(())} for agent in self._agents } # Get data from replay. @@ -306,8 +312,6 @@ def _step( + loss[agent]["critic_loss"], "policy_loss": losses[agent]["policy_loss"] + loss[agent]["policy_loss"], - "total_loss": losses[agent]["total_loss"] - + loss[agent]["total_loss"], } # Log losses per agent @@ -326,7 +330,7 @@ def forward_backward(self, inputs: Any) -> Dict[str, Dict[str, Any]]: self._backward_pass() # Log losses per agent return train_utils.map_losses_per_agent_ac( - self.critic_losses, self.policy_losses, self.total_losses + self.critic_losses, self.policy_losses ) # Forward pass that calculates loss. @@ -349,9 +353,6 @@ def _forward_pass(self, inputs: Any) -> None: data.extras, ) - # transform observation using observation networks - observations_trans = self._transform_observations(observations) - # Get log_probs. log_probs = extras["log_probs"] @@ -361,6 +362,8 @@ def _forward_pass(self, inputs: Any) -> None: total_losses: Dict[str, Any] = {} with tf.GradientTape(persistent=True) as tape: + # transform observation using observation networks + observations_trans = self._transform_observations(observations) for agent in self._agents: action, reward, termination, behaviour_log_prob = ( actions[agent], @@ -369,6 +372,10 @@ def _forward_pass(self, inputs: Any) -> None: log_probs[agent], ) + loss_mask = tf.concat( + (tf.ones((1, termination.shape[1])), termination[:-1]), 0 + ) + actor_observation = observations_trans[agent] critic_observation = self._get_critic_feed(observations_trans, agent) @@ -425,9 +432,9 @@ def _forward_pass(self, inputs: Any) -> None: # TODO Clip values to reduce variablility # Need to keep track of old value estimates (either in replay or in # training state) and clip them. - masked_critic_loss = unclipped_critic_loss * termination[:-1] + masked_critic_loss = unclipped_critic_loss * loss_mask[:-1] critic_loss = tf.reduce_sum(masked_critic_loss) / tf.reduce_sum( - termination[:-1] + loss_mask[:-1] ) critic_loss = critic_loss * self._baseline_cost @@ -445,16 +452,16 @@ def _forward_pass(self, inputs: Any) -> None: rhos * advantages, clipped_rhos * advantages ) - masked_policy_grad_loss = clipped_objective * termination[:-1] + masked_policy_grad_loss = clipped_objective * loss_mask[:-1] policy_gradient_loss = tf.reduce_sum( masked_policy_grad_loss - ) / tf.reduce_sum(termination[:-1]) + ) / tf.reduce_sum(loss_mask[:-1]) # Entropy regularization. Only implemented for categorical dist. 
try: - masked_entropy_loss = policy.entropy()[:-1] * termination[:-1] + masked_entropy_loss = policy.entropy()[:-1] * loss_mask[:-1] entropy_loss = -tf.reduce_sum(masked_entropy_loss) / tf.reduce_sum( - termination[:-1] + loss_mask[:-1] ) except NotImplementedError: @@ -479,28 +486,41 @@ def _backward_pass(self) -> None: """Trainer backward pass updating network parameters""" # Calculate the gradients and update the networks - total_loss = self.total_losses + policy_losses = self.policy_losses + critic_losses = self.critic_losses tape = self.tape for agent in self._agents: # Get agent_key. agent_key = self._agent_net_keys[agent] - # Get trainable variables. - variables = ( - self._policy_networks[agent_key].trainable_variables - + self._critic_networks[agent_key].trainable_variables + policy_variables = self._policy_networks[agent_key].trainable_variables + # Only use critic vars to update the observation network + # if we have two optims. + critic_variables = ( + self._critic_networks[agent_key].trainable_variables + self._observation_networks[agent_key].trainable_variables ) # Get gradients. - gradients = tape.gradient(total_loss, variables) + critic_gradients = tape.gradient(critic_losses[agent], critic_variables) + # Optionally apply clipping. + critic_grads = tf.clip_by_global_norm( + critic_gradients, self._max_gradient_norm + )[0] + # Apply gradients. + self._critic_optimizers[agent_key].apply(critic_grads, critic_variables) + + # Get gradients. + policy_gradients = tape.gradient(policy_losses[agent], policy_variables) # Optionally apply clipping. - grads = tf.clip_by_global_norm(gradients, self._max_gradient_norm)[0] + policy_grads = tf.clip_by_global_norm( + policy_gradients, self._max_gradient_norm + )[0] # Apply gradients. - self._optimizer[agent_key].apply(grads, variables) + self._policy_optimizers[agent_key].apply(policy_grads, policy_variables) train_utils.safe_del(self, "tape") @@ -553,7 +573,10 @@ def after_trainer_step(self) -> None: info: Dict[str, Dict[str, float]] = {} for agent in self._agents: info[agent] = {} - info[agent]["policy_learning_rate"] = self._optimizer[ + info[agent]["policy_learning_rate"] = self._policy_optimizers[ + self._agent_net_keys[agent] + ].learning_rate + info[agent]["critic_learning_rate"] = self._critic_optimizers[ self._agent_net_keys[agent] ].learning_rate if self._logger: @@ -565,9 +588,10 @@ def _decay_lr(self, trainer_step: int) -> None: Args: trainer_step : trainer step time t. """ - train_utils.decay_lr( - self._learning_rate_scheduler_fn, # type: ignore - self._optimizer, + train_utils.decay_lr_actor_critic( + self._learning_rate_scheduler_fn, + self._policy_optimizers, + self._critic_optimizers, trainer_step, ) @@ -583,7 +607,8 @@ def __init__( policy_networks: Dict[str, snt.Module], critic_networks: Dict[str, snt.Module], dataset: tf.data.Dataset, - optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + policy_optimizer: Union[snt.Optimizer, Dict[str, snt.Optimizer]], + critic_optimizer: Optional[Union[snt.Optimizer, Dict[str, snt.Optimizer]]], agent_net_keys: Dict[str, str], checkpoint_minute_interval: int, minibatch_size: Optional[int] = None, @@ -610,7 +635,9 @@ def __init__( policy_networks : policy networks for each agent in the system. critic_networks : critic network(s), shared or for each agent in the system. dataset : training dataset. - optimizer : optimizer for updating networks. + policy_optimizer : optimizer for updating policy networks. + critic_optimizer : optimizer for updating critic networks. 
This is not + necessary if using single optim. agent_net_keys : specifies what network each agent uses. checkpoint_minute_interval : The number of minutes to wait between checkpoints. @@ -652,7 +679,8 @@ def __init__( dataset=dataset, agent_net_keys=agent_net_keys, checkpoint_minute_interval=checkpoint_minute_interval, - optimizer=optimizer, + policy_optimizer=policy_optimizer, + critic_optimizer=critic_optimizer, minibatch_size=minibatch_size, num_epochs=num_epochs, discount=discount, diff --git a/mava/systems/tf/value_decomposition/system.py b/mava/systems/tf/value_decomposition/system.py index 8bf848748..45ba6ed8b 100644 --- a/mava/systems/tf/value_decomposition/system.py +++ b/mava/systems/tf/value_decomposition/system.py @@ -47,7 +47,7 @@ def __init__( self, environment_factory: Callable[[bool], dm_env.Environment], network_factory: Callable[[acme_specs.BoundedArray], Dict[str, snt.Module]], - mixer: snt.Module, + mixer: Union[snt.Module, str], exploration_scheduler_fn: Union[ EpsilonScheduler, Mapping[str, EpsilonScheduler], @@ -97,7 +97,7 @@ def __init__( environment_factory: function to instantiate an environment. network_factory: function to instantiate system networks. - mixer: mixing network + mixer: mixing network. Either a sonnet module or the strings "qmix"/"vdn" exploration_scheduler_fn: function to schedule exploration. e.g. epsilon greedy logger_factory: function to diff --git a/mava/utils/environments/pettingzoo_utils.py b/mava/utils/environments/pettingzoo_utils.py index 8bc2bc845..03dc8bbd2 100644 --- a/mava/utils/environments/pettingzoo_utils.py +++ b/mava/utils/environments/pettingzoo_utils.py @@ -126,7 +126,7 @@ def make_environment( if env_type == "parallel": env = env_module.parallel_env(**kwargs) # type: ignore - if env_class == "atari": + if env_class == "atari" or "pong" in env_name: env = atari_preprocessing(env) # wrap parallel environment environment = PettingZooParallelEnvWrapper( diff --git a/mava/utils/lp_utils.py b/mava/utils/lp_utils.py index 0d14eeaab..66b697a31 100644 --- a/mava/utils/lp_utils.py +++ b/mava/utils/lp_utils.py @@ -57,9 +57,11 @@ def partial_kwargs(function: Callable[..., Any], **kwargs: Any) -> Callable[..., are not defined by `function` or if they do not have defaults. This is useful as a way to define a factory function with default parameters and then to override them in a safe way. + Args: function: the base function before partial application. **kwargs: keyword argument overrides. + Returns: A function. """ @@ -85,14 +87,13 @@ def partial_kwargs(function: Callable[..., Any], **kwargs: Any) -> Callable[..., class StepsLimiter: - """Process that terminates an experiment when `max_steps` is reached.""" - def __init__( self, counter: counting.Counter, max_steps: Optional[int], steps_key: str = "executor_steps", ): + """Process that terminates an experiment when `max_steps` is reached.""" self._counter = counter self._max_steps = max_steps self._steps_key = steps_key diff --git a/mava/utils/training_utils.py b/mava/utils/training_utils.py index 44ed824c7..2cee89a75 100644 --- a/mava/utils/training_utils.py +++ b/mava/utils/training_utils.py @@ -1,5 +1,6 @@ import os import time +import warnings from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np @@ -227,10 +228,13 @@ def checkpoint_networks(system_checkpointer: Dict) -> None: Args: system_checkpointer : checkpointer used by the system. 
""" - if system_checkpointer and len(system_checkpointer.keys()) > 0: - for network_key in system_checkpointer.keys(): - checkpointer = system_checkpointer[network_key] - checkpointer.save() + try: + if system_checkpointer and len(system_checkpointer.keys()) > 0: + for network_key in system_checkpointer.keys(): + checkpointer = system_checkpointer[network_key] + checkpointer.save() + except Exception as ex: + warnings.warn(f"Failed to checkpoint. Error: {ex}") def set_growing_gpu_memory() -> None: diff --git a/mava/wrappers/flatland.py b/mava/wrappers/flatland.py index 971de0577..91ce68659 100644 --- a/mava/wrappers/flatland.py +++ b/mava/wrappers/flatland.py @@ -423,7 +423,6 @@ def environment(self) -> RailEnv: @property def num_agents(self) -> int: """Returns the number of trains/agents in the flatland environment""" - print(self._environment.number_of_agents) return int(self._environment.number_of_agents) def __getattr__(self, name: str) -> Any: diff --git a/setup.py b/setup.py index 265fa2b78..ff712240f 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ testing_formatting_requirements = [ "pytest==6.2.4", "pre-commit", - "mypy==0.910", + "mypy==0.941", "pytest-xdist", "flake8==3.8.2", "black==21.4b1", @@ -104,7 +104,7 @@ "matplotlib", "dataclasses", "box2d-py", - "gym", + "gym<=0.23.0", ], extras_require={ "tf": tf_requirements, diff --git a/tests/systems/mappo_system_test.py b/tests/systems/mappo_system_test.py index cfe3cd88e..9ca505f6d 100644 --- a/tests/systems/mappo_system_test.py +++ b/tests/systems/mappo_system_test.py @@ -50,7 +50,8 @@ def test_mappo_on_debugging_env(self) -> None: num_executors=1, batch_size=32, max_queue_size=1000, - optimizer=snt.optimizers.Adam(learning_rate=1e-3), + policy_optimizer=snt.optimizers.Adam(learning_rate=1e-3), + critic_optimizer=snt.optimizers.Adam(learning_rate=1e-3), checkpoint=False, ) program = system.build()