Skip to content

Commit

Permalink
Clean up after SLURM test.
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischute committed Aug 21, 2018
1 parent 52d2bcd commit 37c097c
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 63 deletions.
10 changes: 5 additions & 5 deletions checkpoint/pbt_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
class PBTCheckpoint(object):
"""Checkpoint for saving model performance during training."""

def __init__(self, member_id, metric_value, hyperparameters, parameters_path):
def __init__(self, client_id, metric_value, hyperparameters, parameters_path):
"""
Args:
member_id (int): ID for population member.
client_id (int): ID for population member.
metric_value (float): Value of metric for determining which checkpoints are best.
hyperparameters (dict): Dictionary of hyperparameters.
parameters_path (str): Path to saved network parameters.
"""
self._member_id = member_id
self._client_id = client_id
self._metric_value = metric_value
self._hyperparameters = hyperparameters.copy()
self._parameters_path = parameters_path

def member_id(self):
return self._member_id
def client_id(self):
return self._client_id

def metric_value(self):
    """Return the metric value stored on this checkpoint."""
    value = self._metric_value
    return value
Expand Down
11 changes: 2 additions & 9 deletions client/pbt_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import math
import pandas as pd
import random
import time

from ast import literal_eval
from checkpoint import PBTCheckpoint
from pbt.checkpoint import PBTCheckpoint
from multiprocessing.managers import SyncManager


Expand All @@ -28,7 +27,6 @@ class PBTClientManager(SyncManager):
self._client_id = int(str(self._client.get_id()))
self._hyperparameters = self._read_config(config_path)
self._parameters_path = None
print(json.dumps(self._hyperparameters, indent=2))

@staticmethod
def step():
Expand All @@ -38,20 +36,15 @@ def step():
def exploit(self):
"""Exploit another member of the population, i.e. copy their parameters and hyperparameters."""
checkpoint = self._client.exploit()
print('{}: EXPLOIT({})'.format(self._client_id, checkpoint.member_id()))
self._hyperparameters = checkpoint.hyperparameters().copy()
self._parameters_path = checkpoint.parameters_path()
print(json.dumps(self._hyperparameters, indent=2))

def explore(self):
"""Explore the hyperparameter space, i.e. randomly mutate each hyperparameter."""
print('{}: EXPLORE'.format(self._client_id))
for k, v in self._hyperparameters.items():
mutation = random.choice([0.8, 1.2])
self._hyperparameters[k] = mutation * v

print(json.dumps(self._hyperparameters, indent=2))

def save(self, parameters_path, metric_value):
"""Save a checkpoint by sending information to the server.
Expand All @@ -74,7 +67,7 @@ def should_exploit(self):

return should_exploit

def checkpoint_path(self):
def parameters_path(self):
"""Get the client's current checkpoint path."""
return self._parameters_path

Expand Down
47 changes: 31 additions & 16 deletions server/pbt_server.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import json
import random

from multiprocessing.managers import SyncManager


class PBTServer(object):
"""Manager for a population based training session."""
def __init__(self, port, auth_key='', maximize_metric=True):
def __init__(self, port, auth_key='', maximize_metric=True, verbose=True):
"""
Args:
port: Port on which to run the manager server.
auth_key: Authorization key for the manager server.
maximize_metric: Whether the manager should maximize the metric values,
as opposed to minimizing them.
verbose: Log to console if verbose.
"""
auth_key = auth_key.encode('UTF-8')

Expand All @@ -27,6 +29,7 @@ class PBTServerManager(SyncManager):
self._port = port
self._auth_key = auth_key
self._maximize_metric = maximize_metric
self._verbose = verbose

self._checkpoints = {} # Maps member ID -> list of PBTCheckpoints
self._truncation_ratio = 0.2 # Ratio of population for truncation selection
Expand All @@ -39,6 +42,8 @@ def get_id(self):
client_id = self._num_clients
self._num_clients += 1

self._write('New client: {}'.format(client_id))

return client_id

def save(self, checkpoint):
Expand All @@ -47,23 +52,27 @@ def save(self, checkpoint):
Args:
checkpoint: PBTCheckpoint containing population member's performance values.
"""
if checkpoint.member_id() not in self._checkpoints:
self._checkpoints[checkpoint.member_id()] = []
self._checkpoints[checkpoint.member_id()].append(checkpoint)
if checkpoint.client_id() not in self._checkpoints:
self._checkpoints[checkpoint.client_id()] = []
self._checkpoints[checkpoint.client_id()].append(checkpoint)

self._write('{}: Saved checkpoint (performance: {})'.format(checkpoint.client_id(), checkpoint.metric_value()))
self._write(json.dumps(checkpoint.hyperparameters(), indent=2))

def should_exploit(self, member_id):
"""Check whether a member should exploit another member of the population
def should_exploit(self, client_id):
"""Check whether a client should exploit another member of the population
Args:
member_id: ID of member asking whether it should exploit another member.
client_id: ID of client asking whether it should exploit another member.
Returns:
True if member is under-performing and should exploit another member.
True if client is under-performing and should exploit another member.
"""
first_surviving_idx = max(1, int(self._truncation_ratio * len(self._checkpoints)))
checkpoints = self._sorted_best_checkpoints(best_first=False)
for checkpoint in checkpoints[:first_surviving_idx]:
if checkpoint.member_id() == member_id:
if checkpoint.client_id() == client_id:
self._write('{}: Should exploit'.format(client_id))
return True

return False
Expand All @@ -78,6 +87,8 @@ def exploit(self):
checkpoints = self._sorted_best_checkpoints(best_first=True)
exploited_checkpoint = random.choice(checkpoints[:first_ineligible_idx])

self._write('{}: Got exploited'.format(exploited_checkpoint.client_id()))

return exploited_checkpoint

def port(self):
Expand All @@ -88,23 +99,23 @@ def shut_down(self):
"""Shut down the server."""
self._server.shutdown()

def _get_best_checkpoint(self, member_id):
def _get_best_checkpoint(self, client_id):
"""Get the best checkpoint for a member of the population,
as rated by checkpoint metric values.
Args:
member_id: ID of the member whose checkpoints will be considered.
client_id: ID of the client whose checkpoints will be considered.
Returns:
Best PBTCheckpoint for this
Best PBTCheckpoint for the specified client.
"""
if member_id not in self._checkpoints or len(self._checkpoints[member_id]) == 0:
if client_id not in self._checkpoints or len(self._checkpoints[client_id]) == 0:
raise ValueError('_get_best_checkpoint called on a member with no registered checkpoints.')

if self._maximize_metric:
best_checkpoint = max(self._checkpoints[member_id])
best_checkpoint = max(self._checkpoints[client_id])
else:
best_checkpoint = min(self._checkpoints[member_id])
best_checkpoint = min(self._checkpoints[client_id])

return best_checkpoint

Expand All @@ -117,7 +128,11 @@ def _sorted_best_checkpoints(self, best_first=True):
Returns:
List of best checkpoints for all members.
"""
best_checkpoints = [self._get_best_checkpoint(member_id) for member_id in self._checkpoints]
best_checkpoints = [self._get_best_checkpoint(client_id) for client_id in self._checkpoints]
sorted_best_checkpoints = list(sorted(best_checkpoints, reverse=(best_first == self._maximize_metric)))

return sorted_best_checkpoints

def _write(self, s):
    """Print a message to the console, but only when verbose mode is on.

    Args:
        s: Message to print.
    """
    if not self._verbose:
        return
    print(s)
4 changes: 2 additions & 2 deletions templates/config.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
hyperparameter,min_value,max_value,search_scale
lr,0.0001,0.1,log
regular_lr,0.0001,0.1,log
fine_tuning_lr,0.00001,0.01,log
weight_decay,0.00001,0.001,log
dropout,0.8,0.2,linear
41 changes: 41 additions & 0 deletions templates/run_clients_slurm.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/bin/bash

# Example script to start 12 PBT clients using SLURM (one client per node).
# Submit with `sbatch`. The number of clients launched below follows the
# --ntasks directive, so change --nodes/--ntasks together to resize the run.

#SBATCH --partition=deep
#SBATCH --time=72:00:00
#SBATCH --nodes=12
#SBATCH --ntasks=12
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=8G

# GPUs per node
#SBATCH --gres=gpu:1

#SBATCH --job-name="pbt"
#SBATCH --output=%j_pbt.out

# Send email when the job begins, fails, or ends
# (replace the placeholder below with your own address)
#SBATCH --mail-user=YOUR_EMAIL_ADDRESS
#SBATCH --mail-type=ALL

# Print some useful job information.
# Expansions are quoted so node lists such as "deep[1-12]" are printed
# verbatim instead of being word-split or treated as glob patterns.
echo "SLURM_JOBID=${SLURM_JOBID}"
echo "SLURM_JOB_NODELIST=${SLURM_JOB_NODELIST}"
echo "SLURM_PWD ${SLURM_SUBMIT_DIR}"

# Launch one client per allocated task as a background job step.
# SLURM_NTASKS mirrors --ntasks above; fall back to 12 if it is unset.
num_clients="${SLURM_NTASKS:-12}"
for ((i = 0; i < num_clients; i++)); do
    srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
done

# Keep the batch job alive until every background client has exited.
wait
31 changes: 0 additions & 31 deletions templates/slurm_clients.sh

This file was deleted.

0 comments on commit 37c097c

Please sign in to comment.