Skip to content

Commit

Permalink
Add usage examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischute committed Oct 8, 2018
1 parent 845403c commit 8647062
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 89 deletions.
49 changes: 47 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,50 @@

This is a Python implementation of population-based training, as described in
[Population Based Training of Neural Networks](https://arxiv.org/abs/1711.09846) by
Max Jaderberg, Valentin Dalibard, Simon Osindero, Wojciech M. Czarnecki, Jeff Donahue, Ali Razavi, Oriol Vinyals,
Tim Green, Iain Dunning, Karen Simonyan, Chrisantha Fernando, Koray Kavukcuoglu.
Jaderberg et al.

![TensorBoard Plot of Metrics during PBT](img/kappa.png)
![TensorBoard Plot of Hyperparameters during PBT](img/lrs.png)
*Example training run: Evaluation metric (top) and hyperparameter values (bottom) over time
during population-based training (population size 10).*

### Usage
Clone this repository and add it to your project's source tree. Then add PBT to your project
with the following commands:

1. Start a PBT server.
```python
server = PBTServer(args.port, args.auth_key, args.maximize_metric)
```

2. Create a PBT client:
```python
pbt_client = PBTClient(args.pbt_server_url, args.pbt_server_port, args.pbt_server_key, args.pbt_config_path)
```

3. Exploit and explore: Suppose we've just written a checkpoint to `ckpt_path` and
evaluated our model, producing a score `metric_val` (*e.g.,* validation accuracy).
Then we might do the following:
```python
pbt_client.save(ckpt_path, metric_val)
if pbt_client.should_exploit():
# Exploit
pbt_client.exploit()

# Load model and optimizer parameters from exploited network
model = load_model(pbt_client.parameters_path(), args.gpu_ids)
model.train()
load_optimizer(pbt_client.parameters_path(), gpu_ids, optimizer)

# Explore
pbt_client.explore()
```
Note each step performed in the block above:
1. `pbt_client.save`: Tell the PBT server that this client just saved a checkpoint to `ckpt_path`
with evaluation score `metric_val`.
2. `pbt_client.should_exploit`: Ask the PBT server if this client should exploit another model. *E.g.,* when
using truncation selection, this is true when the client's performance ranks in the bottom 20% of the population.
3. `pbt_client.exploit`: Ask the PBT server for a checkpoint path of a model to exploit.
4. `load_model` and `load_optimizer`: Load the parameters and hyperparameters of the exploited model.
5. `pbt_client.explore`: Explore the hyperparameter space. *E.g.,* when using the perturb strategy,
multiply each hyperparameter by 0.8 or 1.2.
File renamed without changes.
78 changes: 78 additions & 0 deletions examples/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Example train loop for Population-Based Training."""

# These imports are not expected to work as-is.
# This is just meant to show how pbt might look in a train loop.
import models
import optim
import torch
import torch.nn as nn
import torch.nn.functional as F

from args import TrainArgParser
from data_loader import DataLoader
from evaluator import ModelEvaluator
from pbt.client import PBTClient
from saver import ModelSaver


def train(args):

if args.ckpt_path:
model, ckpt_info = ModelSaver.load_model(args.ckpt_path, args.gpu_ids)
args.start_epoch = ckpt_info['epoch'] + 1
else:
model_fn = models.__dict__[args.model]
model = model_fn(**vars(args))
model = nn.DataParallel(model, args.gpu_ids)
model = model.to(args.device)
model.train()

# Set up population-based training client
pbt_client = PBTClient(args.pbt_server_url, args.pbt_server_port, args.pbt_server_key, args.pbt_config_path)

# Get optimizer and scheduler
parameters = model.module.parameters()
optimizer = optim.get_optimizer(parameters, args, pbt_client)
ModelSaver.load_optimizer(args.ckpt_path, args.gpu_ids, optimizer)

# Get logger, evaluator, saver
train_loader = DataLoader(args, 'train', is_training_set=True)
eval_loaders = [DataLoader(args, 'valid', is_training_set=False)]
evaluator = ModelEvaluator(eval_loaders, args.epochs_per_eval,
args.max_eval, args.num_visuals, use_ten_crop=args.use_ten_crop)
saver = ModelSaver(**vars(args))

for _ in range(args.num_epochs):
optim.update_hyperparameters(model.module, optimizer, pbt_client.hyperparameters())

for inputs, targets in train_loader:
with torch.set_grad_enabled(True):
logits = model.forward(inputs.to(args.device))
loss = F.binary_cross_entropy_with_logits(logits, targets.to(args.device))

optimizer.zero_grad()
loss.backward()
optimizer.step()

metrics = evaluator.evaluate(model, args.device)
metric_val = metrics.get(args.metric_name, None)
ckpt_path = saver.save(model, args.model, optimizer, args.device, metric_val)

pbt_client.save(ckpt_path, metric_val)
if pbt_client.should_exploit():
# Exploit
pbt_client.exploit()

# Load model and optimizer parameters from exploited network
model, ckpt_info = ModelSaver.load_model(pbt_client.parameters_path(), args.gpu_ids)
model = model.to(args.device)
model.train()
ModelSaver.load_optimizer(pbt_client.parameters_path(), args.gpu_ids, optimizer)

# Explore
pbt_client.explore()


if __name__ == '__main__':
parser = TrainArgParser()
train(parser.parse_args())
Binary file added img/kappa.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added img/lrs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions scripts/run_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def update_hyperparameters(hyperparameters):
help='Port on which the server listens for clients.')
parser.add_argument('--auth_key', type=str, default='insecure',
help='Key for clients to authenticate with server.')
parser.add_argument('--config_path', type=str, default='templates/hyperparameters.csv',
help='Path to configuration file defining hyperparameter search space (see templates).')
parser.add_argument('--config_path', type=str, default='examples/hyperparameters.csv',
help='Path to configuration file defining hyperparameter search space (see examples).')

# Training Settings
parser.add_argument('--num_epochs', type=int, default=10,
Expand Down
85 changes: 0 additions & 85 deletions scripts/test_pyro.py

This file was deleted.

0 comments on commit 8647062

Please sign in to comment.