From ef03bf6b16b9248bebaf11f6635e9c5b26f8329b Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 5 Feb 2021 15:06:41 -0800 Subject: [PATCH] setuptools-based plugin for StatsWriters (#4788) --- .github/workflows/pytest.yml | 1 + com.unity.ml-agents/CHANGELOG.md | 3 + docs/Training-Plugins.md | 58 +++++++++++++++++ ml-agents-plugin-examples/README.md | 3 + .../mlagents_plugin_examples/__init__.py | 0 .../example_stats_writer.py | 27 ++++++++ .../tests/__init__.py | 0 .../tests/test_stats_writer_plugin.py | 13 ++++ ml-agents-plugin-examples/setup.py | 17 +++++ ml-agents/mlagents/plugins/__init__.py | 1 + ml-agents/mlagents/plugins/stats_writer.py | 63 +++++++++++++++++++ ml-agents/mlagents/trainers/cli_utils.py | 3 + ml-agents/mlagents/trainers/learn.py | 41 ++++-------- ml-agents/mlagents/trainers/settings.py | 18 ++++++ ml-agents/mlagents/trainers/stats.py | 7 +++ ml-agents/setup.py | 8 ++- 16 files changed, 234 insertions(+), 29 deletions(-) create mode 100644 docs/Training-Plugins.md create mode 100644 ml-agents-plugin-examples/README.md create mode 100644 ml-agents-plugin-examples/mlagents_plugin_examples/__init__.py create mode 100644 ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py create mode 100644 ml-agents-plugin-examples/mlagents_plugin_examples/tests/__init__.py create mode 100644 ml-agents-plugin-examples/mlagents_plugin_examples/tests/test_stats_writer_plugin.py create mode 100644 ml-agents-plugin-examples/setup.py create mode 100644 ml-agents/mlagents/plugins/__init__.py create mode 100644 ml-agents/mlagents/plugins/stats_writer.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 2fdbccb54e..65f930003b 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -45,6 +45,7 @@ jobs: python -m pip install --progress-bar=off -e ./ml-agents python -m pip install --progress-bar=off -r test_requirements.txt python -m pip install --progress-bar=off -e ./gym-unity + python -m pip 
install --progress-bar=off -e ./ml-agents-plugin-examples - name: Save python dependencies run: | pip freeze > pip_versions-${{ matrix.python-version }}.txt diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 0af2fc19c7..4e1acad93e 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -12,6 +12,9 @@ and this project adheres to #### com.unity.ml-agents (C#) #### ml-agents / ml-agents-envs / gym-unity (Python) - TensorFlow trainers have been removed, please use the Torch trainers instead. (#4707) +- A plugin system for `mlagents-learn` has been added. You can now define custom + `StatsWriter` implementations and register them to be called during training. + More types of plugins will be added in the future. (#4788) ### Minor Changes #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) diff --git a/docs/Training-Plugins.md b/docs/Training-Plugins.md new file mode 100644 index 0000000000..a66d676bf6 --- /dev/null +++ b/docs/Training-Plugins.md @@ -0,0 +1,58 @@ +# Customizing Training via Plugins + +ML-Agents provides support for running your own python implementations of specific interfaces during the training +process. These interfaces are currently fairly limited, but will be expanded in the future. + +**Note** Plugin interfaces should currently be considered "in beta", and they may change in future releases. + +## How to Write Your Own Plugin +[This video](https://www.youtube.com/watch?v=fY3Y_xPKWNA) explains the basics of how to create a plugin system using +setuptools, and is the same approach that ML-Agents' plugin system is based on. + +The `ml-agents-plugin-examples` directory contains a reference implementation of each plugin interface, so it's a good +starting point. + +### setup.py +If you don't already have a `setup.py` file for your python code, you'll need to add one. `ml-agents-plugin-examples` +has a [minimal example](../ml-agents-plugin-examples/setup.py) of this. 
+ +In the call to `setup()`, you'll need to add to the `entry_points` dictionary for each plugin interface that you +implement. The form of this is `{entry point name}={plugin module}:{plugin function}`. For example, in + `ml-agents-plugin-examples`: +```python +entry_points={ + ML_AGENTS_STATS_WRITER: [ + "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer" + ] +} +``` +* `ML_AGENTS_STATS_WRITER` (which is a string constant, `mlagents.stats_writer`) is the name of the plugin interface. +This must be one of the provided interfaces ([see below](#plugin-interfaces)). +* `example` is the plugin implementation name. This can be anything. +* `mlagents_plugin_examples.example_stats_writer` is the plugin module. This points to the module where the +plugin registration function is defined. +* `get_example_stats_writer` is the plugin registration function. This is called when running `mlagents-learn`. The +arguments and expected return type for this are different for each plugin interface. + +### Local Installation +Once you've defined `entry_points` in your `setup.py`, you will need to run +``` +pip install -e [path to your plugin code] +``` +in the same Python virtual environment in which `mlagents` is installed. + +## Plugin Interfaces + +### StatsWriter +The StatsWriter class receives various information from the training process, such as the average Agent reward in +each summary period. By default, we log this information to the console and write it to +[TensorBoard](Using-Tensorboard.md). + +#### Interface +The `StatsWriter.write_stats()` method must be implemented in any derived classes. It takes a "category" parameter, +which typically is the behavior name of the Agents being trained, a dictionary of `StatsSummary` values with +string keys, and the current training step. + +#### Registration +The `StatsWriter` registration function takes a `RunOptions` argument and returns a list of `StatsWriter`s. 
An +example implementation is provided in [`mlagents_plugin_examples`](../ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py) diff --git a/ml-agents-plugin-examples/README.md b/ml-agents-plugin-examples/README.md new file mode 100644 index 0000000000..db66662d70 --- /dev/null +++ b/ml-agents-plugin-examples/README.md @@ -0,0 +1,3 @@ +# ML-Agents Plugins + +See the [Plugins documentation](../docs/Training-Plugins.md) for more information. diff --git a/ml-agents-plugin-examples/mlagents_plugin_examples/__init__.py b/ml-agents-plugin-examples/mlagents_plugin_examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py b/ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py new file mode 100644 index 0000000000..17c4afe007 --- /dev/null +++ b/ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py @@ -0,0 +1,27 @@ +from typing import Dict, List +from mlagents.trainers.settings import RunOptions +from mlagents.trainers.stats import StatsWriter, StatsSummary + + +class ExampleStatsWriter(StatsWriter): + """ + Example implementation of the StatsWriter abstract class. + This doesn't do anything interesting, just prints the stats that it gets. + """ + + def write_stats( + self, category: str, values: Dict[str, StatsSummary], step: int + ) -> None: + print(f"ExampleStatsWriter category: {category} values: {values}") + + +def get_example_stats_writer(run_options: RunOptions) -> List[StatsWriter]: + """ + Registration function. This is referenced in setup.py and will + be called by mlagents-learn when it starts to determine the + list of StatsWriters to use. + + It must return a list of StatsWriters. + """ + print("Creating a new stats writer! 
This is so exciting!") + return [ExampleStatsWriter()] diff --git a/ml-agents-plugin-examples/mlagents_plugin_examples/tests/__init__.py b/ml-agents-plugin-examples/mlagents_plugin_examples/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ml-agents-plugin-examples/mlagents_plugin_examples/tests/test_stats_writer_plugin.py b/ml-agents-plugin-examples/mlagents_plugin_examples/tests/test_stats_writer_plugin.py new file mode 100644 index 0000000000..574af9b753 --- /dev/null +++ b/ml-agents-plugin-examples/mlagents_plugin_examples/tests/test_stats_writer_plugin.py @@ -0,0 +1,13 @@ +import pytest + +from mlagents.plugins.stats_writer import register_stats_writer_plugins +from mlagents.trainers.settings import RunOptions + +from mlagents_plugin_examples.example_stats_writer import ExampleStatsWriter + + +@pytest.mark.check_environment_trains +def test_register_stats_writers(): + # Make sure that the ExampleStatsWriter gets returned from the list of all StatsWriters + stats_writers = register_stats_writer_plugins(RunOptions()) + assert any(isinstance(sw, ExampleStatsWriter) for sw in stats_writers) diff --git a/ml-agents-plugin-examples/setup.py b/ml-agents-plugin-examples/setup.py new file mode 100644 index 0000000000..f3fb24e0b5 --- /dev/null +++ b/ml-agents-plugin-examples/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup +from mlagents.plugins import ML_AGENTS_STATS_WRITER + +setup( + name="mlagents_plugin_examples", + version="0.0.1", + # Example of how to add your own registration functions that will be called + # by mlagents-learn. + # + # Here, the get_example_stats_writer() function in mlagents_plugin_examples/example_stats_writer.py + # will get registered with the ML_AGENTS_STATS_WRITER plugin interface. 
+ entry_points={ + ML_AGENTS_STATS_WRITER: [ + "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer" + ] + }, +) diff --git a/ml-agents/mlagents/plugins/__init__.py b/ml-agents/mlagents/plugins/__init__.py new file mode 100644 index 0000000000..a5a5353a15 --- /dev/null +++ b/ml-agents/mlagents/plugins/__init__.py @@ -0,0 +1 @@ +ML_AGENTS_STATS_WRITER = "mlagents.stats_writer" diff --git a/ml-agents/mlagents/plugins/stats_writer.py b/ml-agents/mlagents/plugins/stats_writer.py new file mode 100644 index 0000000000..17d55926ea --- /dev/null +++ b/ml-agents/mlagents/plugins/stats_writer.py @@ -0,0 +1,63 @@ +import sys +from typing import List + +# importlib.metadata is new in python3.8 +# We use the backport for older python versions. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata # pylint: disable=E0611 + +from mlagents.trainers.stats import StatsWriter + +from mlagents_envs import logging_util +from mlagents.plugins import ML_AGENTS_STATS_WRITER +from mlagents.trainers.settings import RunOptions +from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter + + +logger = logging_util.get_logger(__name__) + + +def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]: + """ + The StatsWriters that mlagents-learn always uses: + * A TensorboardWriter to write information to TensorBoard + * A GaugeWriter to record our internal stats + * A ConsoleWriter to output to stdout. + """ + checkpoint_settings = run_options.checkpoint_settings + return [ + TensorboardWriter( + checkpoint_settings.write_path, + clear_past_data=not checkpoint_settings.resume, + ), + GaugeWriter(), + ConsoleWriter(), + ] + + +def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]: + """ + Registers all StatsWriter plugins (including the default one), + and evaluates them, and returns the list of all the StatsWriter implementations. 
+ """ + all_stats_writers: List[StatsWriter] = [] + entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER] + + for entry_point in entry_points: + + try: + logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}") + plugin_func = entry_point.load() + plugin_stats_writers = plugin_func(run_options) + logger.debug( + f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}" + ) + all_stats_writers += plugin_stats_writers + except BaseException: + # Catch all exceptions from setting up the plugin, so that bad user code doesn't break things. + logger.exception( + f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used." + ) + return all_stats_writers diff --git a/ml-agents/mlagents/trainers/cli_utils.py b/ml-agents/mlagents/trainers/cli_utils.py index 6849731600..fcb8b14b89 100644 --- a/ml-agents/mlagents/trainers/cli_utils.py +++ b/ml-agents/mlagents/trainers/cli_utils.py @@ -189,6 +189,9 @@ def _create_parser() -> argparse.ArgumentParser: action=RaiseRemovedWarning, help="(Removed) Use the TensorFlow framework.", ) + argparser.add_argument( + "--results-dir", default="results", help="Results base directory" + ) eng_conf = argparser.add_argument_group(title="Engine Configuration") eng_conf.add_argument( diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py index b43b5b24fa..66939345b9 100644 --- a/ml-agents/mlagents/trainers/learn.py +++ b/ml-agents/mlagents/trainers/learn.py @@ -14,12 +14,7 @@ from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager from mlagents.trainers.trainer import TrainerFactory from mlagents.trainers.directory_utils import validate_existing_directories -from mlagents.trainers.stats import ( - TensorboardWriter, - StatsReporter, - GaugeWriter, - ConsoleWriter, -) +from mlagents.trainers.stats import StatsReporter from mlagents.trainers.cli_utils import parser from mlagents_envs.environment 
import UnityEnvironment from mlagents.trainers.settings import RunOptions @@ -34,6 +29,7 @@ add_metadata as add_timer_metadata, ) from mlagents_envs import logging_util +from mlagents.plugins.stats_writer import register_stats_writer_plugins logger = logging_util.get_logger(__name__) @@ -65,21 +61,15 @@ def run_training(run_seed: int, options: RunOptions) -> None: checkpoint_settings = options.checkpoint_settings env_settings = options.env_settings engine_settings = options.engine_settings - base_path = "results" - write_path = os.path.join(base_path, checkpoint_settings.run_id) - maybe_init_path = ( - os.path.join(base_path, checkpoint_settings.initialize_from) - if checkpoint_settings.initialize_from is not None - else None - ) - run_logs_dir = os.path.join(write_path, "run_logs") + + run_logs_dir = checkpoint_settings.run_logs_dir port: Optional[int] = env_settings.base_port # Check if directory exists validate_existing_directories( - write_path, + checkpoint_settings.write_path, checkpoint_settings.resume, checkpoint_settings.force, - maybe_init_path, + checkpoint_settings.maybe_init_path, ) # Make run logs directory os.makedirs(run_logs_dir, exist_ok=True) @@ -90,14 +80,9 @@ def run_training(run_seed: int, options: RunOptions) -> None: ) # Configure Tensorboard Writers and StatsReporter - tb_writer = TensorboardWriter( - write_path, clear_past_data=not checkpoint_settings.resume - ) - gauge_write = GaugeWriter() - console_writer = ConsoleWriter() - StatsReporter.add_writer(tb_writer) - StatsReporter.add_writer(gauge_write) - StatsReporter.add_writer(console_writer) + stats_writers = register_stats_writer_plugins(options) + for sw in stats_writers: + StatsReporter.add_writer(sw) if env_settings.env_path is None: port = None @@ -117,18 +102,18 @@ def run_training(run_seed: int, options: RunOptions) -> None: trainer_factory = TrainerFactory( trainer_config=options.behaviors, - output_path=write_path, + output_path=checkpoint_settings.write_path, train_model=not 
checkpoint_settings.inference, load_model=checkpoint_settings.resume, seed=run_seed, param_manager=env_parameter_manager, - init_path=maybe_init_path, + init_path=checkpoint_settings.maybe_init_path, multi_gpu=False, ) # Create controller and begin training. tc = TrainerController( trainer_factory, - write_path, + checkpoint_settings.write_path, checkpoint_settings.run_id, env_parameter_manager, not checkpoint_settings.inference, @@ -140,7 +125,7 @@ def run_training(run_seed: int, options: RunOptions) -> None: tc.start_learning(env_manager) finally: env_manager.close() - write_run_options(write_path, options) + write_run_options(checkpoint_settings.write_path, options) write_timing_tree(run_logs_dir) write_training_status(run_logs_dir) diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py index 02865c96f3..16905949f5 100644 --- a/ml-agents/mlagents/trainers/settings.py +++ b/ml-agents/mlagents/trainers/settings.py @@ -1,3 +1,4 @@ +import os.path import warnings import attr @@ -706,6 +707,23 @@ class CheckpointSettings: force: bool = parser.get_default("force") train_model: bool = parser.get_default("train_model") inference: bool = parser.get_default("inference") + results_dir: str = parser.get_default("results_dir") + + @property + def write_path(self) -> str: + return os.path.join(self.results_dir, self.run_id) + + @property + def maybe_init_path(self) -> Optional[str]: + return ( + os.path.join(self.results_dir, self.initialize_from) + if self.initialize_from is not None + else None + ) + + @property + def run_logs_dir(self) -> str: + return os.path.join(self.write_path, "run_logs") @attr.s(auto_attribs=True) diff --git a/ml-agents/mlagents/trainers/stats.py b/ml-agents/mlagents/trainers/stats.py index a466b95ac9..7e1996f90a 100644 --- a/ml-agents/mlagents/trainers/stats.py +++ b/ml-agents/mlagents/trainers/stats.py @@ -87,6 +87,13 @@ class StatsWriter(abc.ABC): def write_stats( self, category: str, values: Dict[str, 
StatsSummary], step: int ) -> None: + """ + Callback to record training information + :param category: Category of the statistics. Usually this is the behavior name. + :param values: Dictionary of statistics. + :param step: The current training step. + :return: + """ pass def add_property( diff --git a/ml-agents/setup.py b/ml-agents/setup.py index 55fcc94467..30a01a25e2 100644 --- a/ml-agents/setup.py +++ b/ml-agents/setup.py @@ -3,6 +3,7 @@ from setuptools import setup, find_packages from setuptools.command.install import install +from mlagents.plugins import ML_AGENTS_STATS_WRITER import mlagents.trainers VERSION = mlagents.trainers.__version__ @@ -71,13 +72,18 @@ def run(self): "cattrs>=1.0.0,<1.1.0", "attrs>=19.3.0", 'pypiwin32==223;platform_system=="Windows"', + "importlib_metadata; python_version<'3.8'", ], python_requires=">=3.6.1", entry_points={ "console_scripts": [ "mlagents-learn=mlagents.trainers.learn:main", "mlagents-run-experiment=mlagents.trainers.run_experiment:main", - ] + ], + # Plugins - each plugin type should have an entry here for the default behavior + ML_AGENTS_STATS_WRITER: [ + "default=mlagents.plugins.stats_writer:get_default_stats_writers" + ], }, cmdclass={"verify": VerifyVersionCommand}, )