setuptools-based plugin for StatsWriters (Unity-Technologies#4788)
Chris Elion authored Feb 5, 2021
1 parent 2352955 commit ef03bf6
Showing 16 changed files with 234 additions and 29 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
@@ -45,6 +45,7 @@ jobs:
        python -m pip install --progress-bar=off -e ./ml-agents
        python -m pip install --progress-bar=off -r test_requirements.txt
        python -m pip install --progress-bar=off -e ./gym-unity
        python -m pip install --progress-bar=off -e ./ml-agents-plugin-examples
    - name: Save python dependencies
      run: |
        pip freeze > pip_versions-${{ matrix.python-version }}.txt
3 changes: 3 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -12,6 +12,9 @@ and this project adheres to
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- TensorFlow trainers have been removed, please use the Torch trainers instead. (#4707)
- A plugin system for `mlagents-learn` has been added. You can now define custom
`StatsWriter` implementations and register them to be called during training.
More types of plugins will be added in the future. (#4788)

### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
58 changes: 58 additions & 0 deletions docs/Training-Plugins.md
@@ -0,0 +1,58 @@
# Customizing Training via Plugins

ML-Agents supports running your own Python implementations of specific interfaces during the training
process. The set of available interfaces is currently limited, but it will be expanded in the future.

**Note:** Plugin interfaces should currently be considered "in beta", and they may change in future releases.

## How to Write Your Own Plugin
[This video](https://www.youtube.com/watch?v=fY3Y_xPKWNA) explains the basics of creating a plugin system with
setuptools, which is the approach that ML-Agents' plugin system is based on.

The `ml-agents-plugin-examples` directory contains a reference implementation of each plugin interface, so it's a good
starting point.

### setup.py
If you don't already have a `setup.py` file for your Python code, you'll need to add one. `ml-agents-plugin-examples`
has a [minimal example](../ml-agents-plugin-examples/setup.py) of this.

In the call to `setup()`, you'll need to add an entry to the `entry_points` dictionary for each plugin interface that
you implement. Each entry has the form `{entry point name}={plugin module}:{plugin function}`. For example, in
`ml-agents-plugin-examples`:
```python
entry_points={
    ML_AGENTS_STATS_WRITER: [
        "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
    ]
}
```
* `ML_AGENTS_STATS_WRITER` (which is a string constant, `mlagents.stats_writer`) is the name of the plugin interface.
This must be one of the provided interfaces ([see below](#plugin-interfaces)).
* `example` is the plugin implementation name. This can be anything.
* `mlagents_plugin_examples.example_stats_writer` is the plugin module. This points to the module where the
plugin registration function is defined.
* `get_example_stats_writer` is the plugin registration function. This is called when running `mlagents-learn`. The
arguments and expected return type for this are different for each plugin interface.
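
To make the three parts of an entry point string concrete, the sketch below splits the example string into its name,
module, and function. The `split_entry_point` helper and `EntryPointParts` type are hypothetical and exist only for
illustration; setuptools and `importlib.metadata` do this parsing for you.

```python
from typing import NamedTuple


class EntryPointParts(NamedTuple):
    name: str      # plugin implementation name
    module: str    # module that defines the registration function
    function: str  # plugin registration function


def split_entry_point(spec: str) -> EntryPointParts:
    """Hypothetical helper: split "name=module:function" into its parts."""
    name, equals, target = spec.partition("=")
    module, colon, function = target.partition(":")
    if not (equals and colon and name and module and function):
        raise ValueError(f"Expected 'name=module:function', got {spec!r}")
    return EntryPointParts(name.strip(), module.strip(), function.strip())


parts = split_entry_point(
    "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
)
print(parts.name)      # example
print(parts.module)    # mlagents_plugin_examples.example_stats_writer
print(parts.function)  # get_example_stats_writer
```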

### Local Installation
Once you've defined `entry_points` in your `setup.py`, you will need to run
```
pip install -e [path to your plugin code]
```
in the same Python virtual environment in which `mlagents` is installed.
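
One way to check that the installation registered your entry point is to list everything in the `mlagents.stats_writer`
group. This is a minimal sketch rather than an ML-Agents tool: `list_registered_plugins` is a hypothetical helper, and
it branches on the available API because the return type of `entry_points()` differs between Python versions.

```python
import sys

# importlib.metadata is in the standard library from Python 3.8 onwards;
# older versions need the importlib_metadata backport.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata


def list_registered_plugins(group: str = "mlagents.stats_writer"):
    """Hypothetical helper: return sorted (name, target) pairs for `group`."""
    entry_points = importlib_metadata.entry_points()
    if hasattr(entry_points, "select"):
        # Newer versions of the metadata API support selection by group.
        selected = entry_points.select(group=group)
    else:
        # Older versions return a dict keyed by group name.
        selected = entry_points.get(group, [])
    return sorted((ep.name, ep.value) for ep in selected)


for name, target in list_registered_plugins():
    print(f"{name} -> {target}")
```

If your plugin does not appear in the output, the most likely cause is that it was installed into a different
environment from the one running this check.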

## Plugin Interfaces

### StatsWriter
The StatsWriter class receives various information from the training process, such as the average Agent reward in
each summary period. By default, we log this information to the console and write it to
[TensorBoard](Using-Tensorboard.md).

#### Interface
The `StatsWriter.write_stats()` method must be implemented in any derived classes. It takes a `category` parameter,
which is typically the behavior name of the Agents being trained, a dictionary of `StatsSummary` values with
string keys, and the current training `step`.
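
As a rough illustration of the interface, the sketch below writes one JSON object per call to `write_stats()`. To keep
it runnable without ML-Agents installed, it defines stand-ins for `StatsWriter` and `StatsSummary`; a real plugin
would import both from `mlagents.trainers.stats` instead. The `JsonLinesStatsWriter` class, the assumption that a
summary exposes a `mean` field, and the sample behavior and stat names are all illustrative.

```python
import abc
import io
import json
from typing import Dict, NamedTuple, TextIO


class StatsSummary(NamedTuple):
    """Stand-in for mlagents.trainers.stats.StatsSummary (assumed to expose `mean`)."""

    mean: float


class StatsWriter(abc.ABC):
    """Stand-in that mirrors the write_stats() signature described above."""

    @abc.abstractmethod
    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        ...


class JsonLinesStatsWriter(StatsWriter):
    """Hypothetical writer that emits one JSON line per summary period."""

    def __init__(self, stream: TextIO) -> None:
        self.stream = stream

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        record = {
            "category": category,
            "step": step,
            "stats": {key: summary.mean for key, summary in values.items()},
        }
        self.stream.write(json.dumps(record, sort_keys=True) + "\n")


# Usage: write a single summary to an in-memory buffer.
buffer = io.StringIO()
writer = JsonLinesStatsWriter(buffer)
writer.write_stats(
    "3DBall", {"Environment/Cumulative Reward": StatsSummary(mean=1.5)}, step=1000
)
```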

#### Registration
The `StatsWriter` registration function takes a `RunOptions` argument and returns a list of `StatsWriter`s. An
example implementation is provided in [`mlagents_plugin_examples`](../ml-agents-plugin-examples/mlagents_plugin_examples/example_stats_writer.py).
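
A registration function can also use the run options to configure the writers it returns. The sketch below is one
possible shape, not the reference implementation: `PathAwareStatsWriter`, `get_path_aware_stats_writer`, and the
`custom_stats.jsonl` file name are hypothetical, and a `SimpleNamespace` stands in for `RunOptions` so the snippet runs
without ML-Agents installed. It assumes only that `run_options.checkpoint_settings.write_path` is available.

```python
import os
from types import SimpleNamespace
from typing import Dict, List


class PathAwareStatsWriter:
    """Hypothetical writer; a real plugin would subclass
    mlagents.trainers.stats.StatsWriter."""

    def __init__(self, output_file: str) -> None:
        self.output_file = output_file

    def write_stats(self, category: str, values: Dict[str, object], step: int) -> None:
        # A real implementation would append the stats to self.output_file.
        pass


def get_path_aware_stats_writer(run_options) -> List[PathAwareStatsWriter]:
    """Registration function: takes the run options and returns a list of writers."""
    write_path = run_options.checkpoint_settings.write_path
    return [PathAwareStatsWriter(os.path.join(write_path, "custom_stats.jsonl"))]


# Stand-in for RunOptions that exposes only the attribute used above.
fake_options = SimpleNamespace(
    checkpoint_settings=SimpleNamespace(write_path=os.path.join("results", "my_run"))
)
writers = get_path_aware_stats_writer(fake_options)
print(writers[0].output_file)
```

Returning a list rather than a single object lets one plugin contribute several writers, or none at all when the run
options indicate that it should stay inactive.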
3 changes: 3 additions & 0 deletions ml-agents-plugin-examples/README.md
@@ -0,0 +1,3 @@
# ML-Agents Plugins

See the [Plugins documentation](../docs/Training-Plugins.md) for more information.
Empty file.
@@ -0,0 +1,27 @@
from typing import Dict, List
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import StatsWriter, StatsSummary


class ExampleStatsWriter(StatsWriter):
    """
    Example implementation of the StatsWriter abstract class.
    This doesn't do anything interesting, just prints the stats that it gets.
    """

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        print(f"ExampleStatsWriter category: {category} values: {values}")


def get_example_stats_writer(run_options: RunOptions) -> List[StatsWriter]:
    """
    Registration function. This is referenced in setup.py and will
    be called by mlagents-learn when it starts to determine the
    list of StatsWriters to use.
    It must return a list of StatsWriters.
    """
    print("Creating a new stats writer! This is so exciting!")
    return [ExampleStatsWriter()]
Empty file.
@@ -0,0 +1,13 @@
import pytest

from mlagents.plugins.stats_writer import register_stats_writer_plugins
from mlagents.trainers.settings import RunOptions

from mlagents_plugin_examples.example_stats_writer import ExampleStatsWriter


@pytest.mark.check_environment_trains
def test_register_stats_writers():
    # Make sure that the ExampleStatsWriter gets returned from the list of all StatsWriters
    stats_writers = register_stats_writer_plugins(RunOptions())
    assert any(isinstance(sw, ExampleStatsWriter) for sw in stats_writers)
17 changes: 17 additions & 0 deletions ml-agents-plugin-examples/setup.py
@@ -0,0 +1,17 @@
from setuptools import setup
from mlagents.plugins import ML_AGENTS_STATS_WRITER

setup(
    name="mlagents_plugin_examples",
    version="0.0.1",
    # Example of how to add your own registration functions that will be called
    # by mlagents-learn.
    #
    # Here, the get_example_stats_writer() function in mlagents_plugin_examples/example_stats_writer.py
    # will get registered with the ML_AGENTS_STATS_WRITER plugin interface.
    entry_points={
        ML_AGENTS_STATS_WRITER: [
            "example=mlagents_plugin_examples.example_stats_writer:get_example_stats_writer"
        ]
    },
)
1 change: 1 addition & 0 deletions ml-agents/mlagents/plugins/__init__.py
@@ -0,0 +1 @@
ML_AGENTS_STATS_WRITER = "mlagents.stats_writer"
63 changes: 63 additions & 0 deletions ml-agents/mlagents/plugins/stats_writer.py
@@ -0,0 +1,63 @@
import sys
from typing import List

# importlib.metadata is new in python3.8
# We use the backport for older python versions.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata  # pylint: disable=E0611

from mlagents.trainers.stats import StatsWriter

from mlagents_envs import logging_util
from mlagents.plugins import ML_AGENTS_STATS_WRITER
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter


logger = logging_util.get_logger(__name__)


def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
    """
    The StatsWriters that mlagents-learn always uses:
    * A TensorboardWriter to write information to TensorBoard
    * A GaugeWriter to record our internal stats
    * A ConsoleWriter to output to stdout.
    """
    checkpoint_settings = run_options.checkpoint_settings
    return [
        TensorboardWriter(
            checkpoint_settings.write_path,
            clear_past_data=not checkpoint_settings.resume,
        ),
        GaugeWriter(),
        ConsoleWriter(),
    ]


def register_stats_writer_plugins(run_options: RunOptions) -> List[StatsWriter]:
    """
    Registers all StatsWriter plugins (including the default one),
    and evaluates them, and returns the list of all the StatsWriter implementations.
    """
    all_stats_writers: List[StatsWriter] = []
    entry_points = importlib_metadata.entry_points()[ML_AGENTS_STATS_WRITER]

    for entry_point in entry_points:

        try:
            logger.debug(f"Initializing StatsWriter plugins: {entry_point.name}")
            plugin_func = entry_point.load()
            plugin_stats_writers = plugin_func(run_options)
            logger.debug(
                f"Found {len(plugin_stats_writers)} StatsWriters for plugin {entry_point.name}"
            )
            all_stats_writers += plugin_stats_writers
        except BaseException:
            # Catch all exceptions from setting up the plugin, so that bad user code doesn't break things.
            logger.exception(
                f"Error initializing StatsWriter plugins for {entry_point.name}. This plugin will not be used."
            )
    return all_stats_writers
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/cli_utils.py
@@ -189,6 +189,9 @@ def _create_parser() -> argparse.ArgumentParser:
        action=RaiseRemovedWarning,
        help="(Removed) Use the TensorFlow framework.",
    )
    argparser.add_argument(
        "--results-dir", default="results", help="Results base directory"
    )

    eng_conf = argparser.add_argument_group(title="Engine Configuration")
    eng_conf.add_argument(
41 changes: 13 additions & 28 deletions ml-agents/mlagents/trainers/learn.py
@@ -14,12 +14,7 @@
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.trainer import TrainerFactory
from mlagents.trainers.directory_utils import validate_existing_directories
from mlagents.trainers.stats import (
    TensorboardWriter,
    StatsReporter,
    GaugeWriter,
    ConsoleWriter,
)
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.cli_utils import parser
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.settings import RunOptions
@@ -34,6 +29,7 @@
    add_metadata as add_timer_metadata,
)
from mlagents_envs import logging_util
from mlagents.plugins.stats_writer import register_stats_writer_plugins

logger = logging_util.get_logger(__name__)

@@ -65,21 +61,15 @@ def run_training(run_seed: int, options: RunOptions) -> None:
        checkpoint_settings = options.checkpoint_settings
        env_settings = options.env_settings
        engine_settings = options.engine_settings
        base_path = "results"
        write_path = os.path.join(base_path, checkpoint_settings.run_id)
        maybe_init_path = (
            os.path.join(base_path, checkpoint_settings.initialize_from)
            if checkpoint_settings.initialize_from is not None
            else None
        )
        run_logs_dir = os.path.join(write_path, "run_logs")

        run_logs_dir = checkpoint_settings.run_logs_dir
        port: Optional[int] = env_settings.base_port
        # Check if directory exists
        validate_existing_directories(
            write_path,
            checkpoint_settings.write_path,
            checkpoint_settings.resume,
            checkpoint_settings.force,
            maybe_init_path,
            checkpoint_settings.maybe_init_path,
        )
        # Make run logs directory
        os.makedirs(run_logs_dir, exist_ok=True)
@@ -90,14 +80,9 @@ def run_training(run_seed: int, options: RunOptions) -> None:
            )

        # Configure Tensorboard Writers and StatsReporter
        tb_writer = TensorboardWriter(
            write_path, clear_past_data=not checkpoint_settings.resume
        )
        gauge_write = GaugeWriter()
        console_writer = ConsoleWriter()
        StatsReporter.add_writer(tb_writer)
        StatsReporter.add_writer(gauge_write)
        StatsReporter.add_writer(console_writer)
        stats_writers = register_stats_writer_plugins(options)
        for sw in stats_writers:
            StatsReporter.add_writer(sw)

        if env_settings.env_path is None:
            port = None
@@ -117,18 +102,18 @@ def run_training(run_seed: int, options: RunOptions) -> None:

        trainer_factory = TrainerFactory(
            trainer_config=options.behaviors,
            output_path=write_path,
            output_path=checkpoint_settings.write_path,
            train_model=not checkpoint_settings.inference,
            load_model=checkpoint_settings.resume,
            seed=run_seed,
            param_manager=env_parameter_manager,
            init_path=maybe_init_path,
            init_path=checkpoint_settings.maybe_init_path,
            multi_gpu=False,
        )
        # Create controller and begin training.
        tc = TrainerController(
            trainer_factory,
            write_path,
            checkpoint_settings.write_path,
            checkpoint_settings.run_id,
            env_parameter_manager,
            not checkpoint_settings.inference,
@@ -140,7 +125,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
        tc.start_learning(env_manager)
    finally:
        env_manager.close()
        write_run_options(write_path, options)
        write_run_options(checkpoint_settings.write_path, options)
        write_timing_tree(run_logs_dir)
        write_training_status(run_logs_dir)

18 changes: 18 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
@@ -1,3 +1,4 @@
import os.path
import warnings

import attr
@@ -706,6 +707,23 @@ class CheckpointSettings:
    force: bool = parser.get_default("force")
    train_model: bool = parser.get_default("train_model")
    inference: bool = parser.get_default("inference")
    results_dir: str = parser.get_default("results_dir")

    @property
    def write_path(self) -> str:
        return os.path.join(self.results_dir, self.run_id)

    @property
    def maybe_init_path(self) -> Optional[str]:
        return (
            os.path.join(self.results_dir, self.initialize_from)
            if self.initialize_from is not None
            else None
        )

    @property
    def run_logs_dir(self) -> str:
        return os.path.join(self.write_path, "run_logs")


@attr.s(auto_attribs=True)
7 changes: 7 additions & 0 deletions ml-agents/mlagents/trainers/stats.py
@@ -87,6 +87,13 @@ class StatsWriter(abc.ABC):
    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        """
        Callback to record training information
        :param category: Category of the statistics. Usually this is the behavior name.
        :param values: Dictionary of statistics.
        :param step: The current training step.
        :return:
        """
        pass

    def add_property(
8 changes: 7 additions & 1 deletion ml-agents/setup.py
@@ -3,6 +3,7 @@

from setuptools import setup, find_packages
from setuptools.command.install import install
from mlagents.plugins import ML_AGENTS_STATS_WRITER
import mlagents.trainers

VERSION = mlagents.trainers.__version__
@@ -71,13 +72,18 @@ def run(self):
        "cattrs>=1.0.0,<1.1.0",
        "attrs>=19.3.0",
        'pypiwin32==223;platform_system=="Windows"',
        "importlib_metadata; python_version<'3.8'",
    ],
    python_requires=">=3.6.1",
    entry_points={
        "console_scripts": [
            "mlagents-learn=mlagents.trainers.learn:main",
            "mlagents-run-experiment=mlagents.trainers.run_experiment:main",
        ]
        ],
        # Plugins - each plugin type should have an entry here for the default behavior
        ML_AGENTS_STATS_WRITER: [
            "default=mlagents.plugins.stats_writer:get_default_stats_writers"
        ],
    },
    cmdclass={"verify": VerifyVersionCommand},
)
