forked from Unity-Technologies/ml-agents
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
setuptools-based plugin for StatsWriters (Unity-Technologies#4788)
- Loading branch information
Chris Elion
authored
Feb 5, 2021
1 parent
2352955
commit ef03bf6
Showing
16 changed files
with
234 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Customizing Training via Plugins | ||
|
||
ML-Agents provides support for running your own python implementations of specific interfaces during the training | ||
process. These interfaces are currently fairly limited, but will be expanded in the future. | ||
|
||
** Note ** Plugin interfaces should currently be considered "in beta", and they may change in future releases. | ||
|
||
## How to Write Your Own Plugin | ||
[This video](https://www.youtube.com/watch?v=fY3Y_xPKWNA) explains the basics of how to create a plugin system using | ||
setuptools, and is the same approach that ML-Agents' plugin system is based on. | ||
|
||
The `ml-agents-plugin-examples` directory contains a reference implementation of each plugin interface, so it's a good | ||
starting point. | ||
|
||
### setup.py | ||
If you don't already have a `setup.py` file for your python code, you'll need to add one. `ml-agents-plugin-examples` | ||
has a [minimal example](../ml-agents-plugin-examples/setup.py) of this. | ||
|
||
In the call to `setup()`, you'll need to add to the `entry_points` dictionary for each plugin interface that you | ||
implement. The form of this is `{entry point name}={plugin module}:{plugin function}`. For example, in | ||
`ml-agents-plugin-examples`: | ||
```python | ||
entry_points={ | ||
ML_AGENTS_STATS_WRITER: [ | ||
"example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer" | ||
] | ||
} | ||
``` | ||
* `ML_AGENTS_STATS_WRITER` (which is a string constant, `mlagents.stats_writer`) is the name of the plugin interface. | ||
This must be one of the provided interfaces ([see below](#plugin-interfaces)). | ||
* `example` is the plugin implementation name. This can be anything. | ||
* `mlagents_plugin_examples.example_stats_writer` is the plugin module. This points to the module where the | ||
plugin registration function is defined. | ||
* `get_example_stats_writer` is the plugin registration function. This is called when running `mlagents-learn`. The | ||
arguments and expected return type for this are different for each plugin interface. | ||
|
||
### Local Installation | ||
Once you've defined `entry_points` in your `setup.py`, you will need to run | ||
``` | ||
pip install -e [path to your plugin code] | ||
``` | ||
in the same python virtual environment that you have `mlagents` installed. | ||
|
||
## Plugin Interfaces | ||
|
||
### StatsWriter | ||
The StatsWriter class receives various information from the training process, such as the average Agent reward in | ||
each summary period. By default, we log this information to the console and write it to | ||
[TensorBoard](Using-Tensorboard.md). | ||
|
||
#### Interface | ||
The `StatsWriter.write_stats()` method must be implemented in any derived classes. It takes a "category" parameter, | ||
which typically is the behavior name of the Agents being trained, and a dictionary of `StatSummary` values with | ||
string keys. | ||
|
||
#### Registration | ||
The `StatsWriter` registration function takes a `RunOptions` argument and returns a list of `StatsWriter`s. An | ||
example implementation is provided in [`mlagents_plugin_examples`](../ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# ML-Agents Plugins | ||
|
||
See the [Plugins documentation](../docs/Training-Plugins.md) for more information. |
Empty file.
27 changes: 27 additions & 0 deletions
27
ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import Dict, List | ||
from mlagents.trainers.settings import RunOptions | ||
from mlagents.trainers.stats import StatsWriter, StatsSummary | ||
|
||
|
||
class ExampleStatsWriter(StatsWriter): | ||
""" | ||
Example implementation of the StatsWriter abstract class. | ||
This doesn't do anything interesting, just prints the stats that it gets. | ||
""" | ||
|
||
def write_stats( | ||
self, category: str, values: Dict[str, StatsSummary], step: int | ||
) -> None: | ||
print(f"ExampleStatsWriter category: {category} values: {values}") | ||
|
||
|
||
def get_example_stats_writer(run_options: RunOptions) -> List[StatsWriter]: | ||
""" | ||
Registration function. This is referenced in setup.py and will | ||
be called by mlagents-learn when it starts to determine the | ||
list of StatsWriters to use. | ||
It must return a list of StatsWriters. | ||
""" | ||
print("Creating a new stats writer! This is so exciting!") | ||
return [ExampleStatsWriter()] |
Empty file.
13 changes: 13 additions & 0 deletions
13
ml-agents-plugin-examples/mlagents_plugin_examples/tests/test_stats_writer_plugin.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import pytest | ||
|
||
from mlagents.plugins.stats_writer import register_stats_writer_plugins | ||
from mlagents.trainers.settings import RunOptions | ||
|
||
from mlagents_plugin_examples.example_stats_writer import ExampleStatsWriter | ||
|
||
|
||
@pytest.mark.check_environment_trains | ||
def test_register_stats_writers(): | ||
# Make sure that the ExampleStatsWriter gets returned from the list of all StatsWriters | ||
stats_writers = register_stats_writer_plugins(RunOptions()) | ||
assert any(isinstance(sw, ExampleStatsWriter) for sw in stats_writers) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from setuptools import setup | ||
from mlagents.plugins import ML_AGENTS_STATS_WRITER | ||
|
||
setup( | ||
name="mlagents_plugin_examples", | ||
version="0.0.1", | ||
# Example of how to add your own registration functions that will be called | ||
# by mlagents-learn. | ||
# | ||
# Here, the get_example_stats_writer() function in mlagents_plugin_examples/example_stats_writer.py | ||
# will get registered with the ML_AGENTS_STATS_WRITER plugin interface. | ||
entry_points={ | ||
ML_AGENTS_STATS_WRITER: [ | ||
"example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer" | ||
] | ||
}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
ML_AGENTS_STATS_WRITER = "mlagents.stats_writer" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import sys | ||
from typing import List | ||
|
||
# importlib.metadata is new in python3.8 | ||
# We use the backport for older python versions. | ||
if sys.version_info < (3, 8): | ||
import importlib_metadata | ||
else: | ||
import importlib.metadata as importlib_metadata # pylint: disable=E0611 | ||
|
||
from mlagents.trainers.stats import StatsWriter | ||
|
||
from mlagents_envs import logging_util | ||
from mlagents.plugins import ML_AGENTS_STATS_WRITER | ||
from mlagents.trainers.settings import RunOptions | ||
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter | ||
|
||
|
||
logger = logging_util.get_logger(__name__) | ||
|
||
|
||
def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]: | ||
""" | ||
The StatsWriters that mlagents-learn always uses: | ||
* A TensorboardWriter to write information to TensorBoard | ||
* A GaugeWriter to record our internal stats | ||
* A ConsoleWriter to output to stdout. | ||
""" | ||
checkpoint_settings = run_options.checkpoint_settings | ||
return [ | ||
TensorboardWriter( | ||
checkpoint_settings.write_path, | ||
clear_past_data=not checkpoint_settings.resume, | ||
), | ||
GaugeWriter(), | ||
ConsoleWriter(), | ||
] | ||
|
||
|
||
def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]: | ||
""" | ||
Registers all StatsWriter plugins (including the default one), | ||
and evaluates them, and returns the list of all the StatsWriter implementations. | ||
""" | ||
all_stats_writers: List[StatsWriter] = [] | ||
entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER] | ||
|
||
for entry_point in entry_points: | ||
|
||
try: | ||
logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}") | ||
plugin_func = entry_point.load() | ||
plugin_stats_writers = plugin_func(run_options) | ||
logger.debug( | ||
f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}" | ||
) | ||
all_stats_writers += plugin_stats_writers | ||
except BaseException: | ||
# Catch all exceptions from setting up the plugin, so that bad user code doesn't break things. | ||
logger.exception( | ||
f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used." | ||
) | ||
return all_stats_writers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters