From 37c097c0afe7fc4213906d25164e246f5b99728c Mon Sep 17 00:00:00 2001
From: Christopher Chute <chute@stanford.edu>
Date: Mon, 20 Aug 2018 19:15:44 -0700
Subject: [PATCH] Clean up after SLURM test.

---
 checkpoint/pbt_checkpoint.py   | 10 ++++----
 client/pbt_client.py           | 11 ++------
 server/pbt_server.py           | 47 ++++++++++++++++++++++------------
 templates/config.csv           |  4 +--
 templates/run_clients_slurm.sh | 41 +++++++++++++++++++++++++++++
 templates/slurm_clients.sh     | 31 ----------------------
 6 files changed, 81 insertions(+), 63 deletions(-)
 create mode 100644 templates/run_clients_slurm.sh
 delete mode 100644 templates/slurm_clients.sh

diff --git a/checkpoint/pbt_checkpoint.py b/checkpoint/pbt_checkpoint.py
index 42736e5..8474709 100644
--- a/checkpoint/pbt_checkpoint.py
+++ b/checkpoint/pbt_checkpoint.py
@@ -1,21 +1,21 @@
 class PBTCheckpoint(object):
     """Checkpoint for saving model performance during training."""
 
-    def __init__(self, member_id, metric_value, hyperparameters, parameters_path):
+    def __init__(self, client_id, metric_value, hyperparameters, parameters_path):
         """
         Args:
-            member_id (int): ID for population member.
+            client_id (int): ID for population member.
             metric_value (float): Value of metric for determining which checkpoints are best.
             hyperparameters (dict): Dictionary of hyperparameters.
             parameters_path (str): Path to saved network parameters.
         """
-        self._member_id = member_id
+        self._client_id = client_id
         self._metric_value = metric_value
         self._hyperparameters = hyperparameters.copy()
         self._parameters_path = parameters_path
 
-    def member_id(self):
-        return self._member_id
+    def client_id(self):
+        return self._client_id
 
     def metric_value(self):
         return self._metric_value
diff --git a/client/pbt_client.py b/client/pbt_client.py
index f1e6cf3..7b994ee 100644
--- a/client/pbt_client.py
+++ b/client/pbt_client.py
@@ -1,11 +1,10 @@
-import json
 import math
 import pandas as pd
 import random
 import time
 
 from ast import literal_eval
-from checkpoint import PBTCheckpoint
+from pbt.checkpoint import PBTCheckpoint
 from multiprocessing.managers import SyncManager
 
 
@@ -28,7 +27,6 @@ class PBTClientManager(SyncManager):
         self._client_id = int(str(self._client.get_id()))
         self._hyperparameters = self._read_config(config_path)
         self._parameters_path = None
-        print(json.dumps(self._hyperparameters, indent=2))
 
     @staticmethod
     def step():
@@ -38,20 +36,15 @@ def step():
     def exploit(self):
         """Exploit another member of the population, i.e. copy their parameters
         and hyperparameters."""
         checkpoint = self._client.exploit()
-        print('{}: EXPLOIT({})'.format(self._client_id, checkpoint.member_id()))
         self._hyperparameters = checkpoint.hyperparameters().copy()
         self._parameters_path = checkpoint.parameters_path()
-        print(json.dumps(self._hyperparameters, indent=2))
 
     def explore(self):
         """Explore the hyperparameter space, i.e. randomly mutate each hyperparameter."""
-        print('{}: EXPLORE'.format(self._client_id))
         for k, v in self._hyperparameters.items():
             mutation = random.choice([0.8, 1.2])
             self._hyperparameters[k] = mutation * v
-        print(json.dumps(self._hyperparameters, indent=2))
-
 
     def save(self, parameters_path, metric_value):
         """Save a checkpoint by sending information to the server.
@@ -74,7 +67,7 @@ def should_exploit(self):
 
         return should_exploit
 
-    def checkpoint_path(self):
+    def parameters_path(self):
         """Get the client's current checkpoint path."""
         return self._parameters_path
 
diff --git a/server/pbt_server.py b/server/pbt_server.py
index 9c33400..ffeaddf 100644
--- a/server/pbt_server.py
+++ b/server/pbt_server.py
@@ -1,3 +1,4 @@
+import json
 import random
 
 from multiprocessing.managers import SyncManager
@@ -5,13 +6,14 @@
 class PBTServer(object):
     """Manager for a population based training session."""
 
-    def __init__(self, port, auth_key='', maximize_metric=True):
+    def __init__(self, port, auth_key='', maximize_metric=True, verbose=True):
         """
         Args:
            port: Port on which to run the manager server.
            auth_key: Authorization key for the manager server.
            maximize_metric: Whether the manager should maximize the metric values,
                as opposed to minimizing them.
+           verbose: Log to console if verbose.
        """
        auth_key = auth_key.encode('UTF-8')
 
@@ -27,6 +29,7 @@ class PBTServerManager(SyncManager):
         self._port = port
         self._auth_key = auth_key
         self._maximize_metric = maximize_metric
+        self._verbose = verbose
 
         self._checkpoints = {}  # Maps member ID -> list of PBTCheckpoints
         self._truncation_ratio = 0.2  # Ratio of population for truncation selection
@@ -39,6 +42,8 @@ def get_id(self):
         client_id = self._num_clients
         self._num_clients += 1
 
+        self._write('New client: {}'.format(client_id))
+
         return client_id
 
@@ -47,23 +52,27 @@ def save(self, checkpoint):
         """Save a checkpoint for a member of the population.
         Args:
             checkpoint: PBCheckpoint containing population member's performance values.
         """
-        if checkpoint.member_id() not in self._checkpoints:
-            self._checkpoints[checkpoint.member_id()] = []
-        self._checkpoints[checkpoint.member_id()].append(checkpoint)
+        if checkpoint.client_id() not in self._checkpoints:
+            self._checkpoints[checkpoint.client_id()] = []
+        self._checkpoints[checkpoint.client_id()].append(checkpoint)
+
+        self._write('{}: Saved checkpoint (performance: {})'.format(checkpoint.client_id(), checkpoint.metric_value()))
+        self._write(json.dumps(checkpoint.hyperparameters(), indent=2))
 
-    def should_exploit(self, member_id):
-        """Check whether a member should exploit another member of the population
+    def should_exploit(self, client_id):
+        """Check whether a client should exploit another member of the population
         Args:
-            member_id: ID of member asking whether it should exploit another member.
+            client_id: ID of client asking whether it should exploit another member.
 
         Returns:
-            True if member is under-performing and should exploit another member.
+            True if client is under-performing and should exploit another member.
         """
         first_surviving_idx = max(1, int(self._truncation_ratio * len(self._checkpoints)))
         checkpoints = self._sorted_best_checkpoints(best_first=False)
         for checkpoint in checkpoints[:first_surviving_idx]:
-            if checkpoint.member_id() == member_id:
+            if checkpoint.client_id() == client_id:
+                self._write('{}: Should exploit'.format(client_id))
                 return True
 
         return False
@@ -78,6 +87,8 @@ def exploit(self):
         checkpoints = self._sorted_best_checkpoints(best_first=True)
         exploited_checkpoint = random.choice(checkpoints[:first_ineligible_idx])
 
+        self._write('{}: Got exploited'.format(exploited_checkpoint.client_id()))
+
         return exploited_checkpoint
 
     def port(self):
@@ -88,23 +99,23 @@ def shut_down(self):
         """Shut down the server."""
         self._server.shutdown()
 
-    def _get_best_checkpoint(self, member_id):
+    def _get_best_checkpoint(self, client_id):
         """Get the best checkpoint for a member of the population,
         as rated by checkpoint metric values.
 
         Args:
-            member_id: ID of the member whose checkpoints will be considered.
+            client_id: ID of the client whose checkpoints will be considered.
 
         Returns:
-            Best PBTCheckpoint for this
+            Best PBTCheckpoint for the specified client.
         """
-        if member_id not in self._checkpoints or len(self._checkpoints[member_id]) == 0:
+        if client_id not in self._checkpoints or len(self._checkpoints[client_id]) == 0:
             raise ValueError('_get_best_checkpoint called on a member with no registered checkpoints.')
 
         if self._maximize_metric:
-            best_checkpoint = max(self._checkpoints[member_id])
+            best_checkpoint = max(self._checkpoints[client_id])
         else:
-            best_checkpoint = min(self._checkpoints[member_id])
+            best_checkpoint = min(self._checkpoints[client_id])
 
         return best_checkpoint
 
@@ -117,7 +128,11 @@ def _sorted_best_checkpoints(self, best_first=True):
 
         Returns:
             List of best checkpoints for all members.
         """
-        best_checkpoints = [self._get_best_checkpoint(member_id) for member_id in self._checkpoints]
+        best_checkpoints = [self._get_best_checkpoint(client_id) for client_id in self._checkpoints]
         sorted_best_checkpoints = list(sorted(best_checkpoints, reverse=(best_first == self._maximize_metric)))
 
         return sorted_best_checkpoints
+
+    def _write(self, s):
+        if self._verbose:
+            print(s)
diff --git a/templates/config.csv b/templates/config.csv
index 73a7792..43c3bda 100644
--- a/templates/config.csv
+++ b/templates/config.csv
@@ -1,4 +1,4 @@
 hyperparameter,min_value,max_value,search_scale
-lr,0.0001,0.1,log
+regular_lr,0.0001,0.1,log
+fine_tuning_lr,0.00001,0.01,log
 weight_decay,0.00001,0.001,log
-dropout,0.8,0.2,linear
diff --git a/templates/run_clients_slurm.sh b/templates/run_clients_slurm.sh
new file mode 100644
index 0000000..e61b971
--- /dev/null
+++ b/templates/run_clients_slurm.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+# Example script to start 12 clients using SLURM
+
+#SBATCH --partition=deep
+#SBATCH --time=72:00:00
+#SBATCH --nodes=12
+#SBATCH --ntasks=12
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=4
+#SBATCH --mem=8G
+
+# GPUs per node
+#SBATCH --gres=gpu:1
+
+#SBATCH --job-name="pbt"
+#SBATCH --output=%j_pbt.out
+
+# Send email when the job begins, fails, or ends
+#SBATCH --mail-user=chute@stanford.edu
+#SBATCH --mail-type=ALL
+
+# Print some useful job information
+echo "SLURM_JOBID="$SLURM_JOBID
+echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST
+echo "SLURM_PWD "$SLURM_SUBMIT_DIR
+
+# Command to run on each node
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+srun --gres=gpu:1 --exclusive -n1 -N1-1 python train.py --name=inception_PBT --use_pbt=True --pbt_server_url=deep24 &
+wait
diff --git a/templates/slurm_clients.sh b/templates/slurm_clients.sh
deleted file mode 100644
index dbbcae0..0000000
--- a/templates/slurm_clients.sh
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/bin/bash
-
-# Example script to start 16 clients using SLURM
-
-#SBATCH --partition=deep
-#SBATCH --time=24:00:00
-#SBATCH --nodes=16
-#SBATCH --ntasks=16
-#SBATCH --ntasks-per-node=1
-#SBATCH --cpus-per-task=2
-#SBATCH --mem=8G
-
-# One GPU per worker
-#SBATCH --gres=gpu:1
-
-#SBATCH --job-name="pbt"
-#SBATCH --output=pbt%j.out
-
-# Send email when the job begins, fails, or ends
-#SBATCH --mail-user=chute@stanford.edu
-#SBATCH --mail-type=ALL
-
-# Print some useful job information
-echo "SLURM_JOBID="$SLURM_JOBID
-echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST
-echo "SLURM_NNODES"=$SLURM_NNODES
-echo "SLURMTMPDIR="$SLURMTMPDIR
-echo "working directory = "$SLURM_SUBMIT_DIR
-
-# Command to run on each node
-srun python scripts/run_client.py --server_url=deep6 --config_path=templates/config.csv
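
Illustrative note (not part of the commit): a minimal sketch of constructing a checkpoint with the renamed client_id field. The constructor and accessor names come from the patch above; the metric value, hyperparameter values, and parameters path are made-up placeholders.

from pbt.checkpoint import PBTCheckpoint

# Hypothetical values; only the constructor/accessor names come from the patch.
checkpoint = PBTCheckpoint(client_id=0,
                           metric_value=0.87,
                           hyperparameters={'regular_lr': 3e-3,
                                            'fine_tuning_lr': 3e-4,
                                            'weight_decay': 1e-4},
                           parameters_path='checkpoints/client_0/step_1000.pth')

assert checkpoint.client_id() == 0
assert checkpoint.metric_value() == 0.87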
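A sketch of how a training script might drive the client side of the API touched here. Only the method names (step, save, should_exploit, exploit, explore, and the renamed parameters_path) appear in this patch; the PBTClient constructor arguments, module path, and the stand-in metric and paths below are assumptions for illustration.

import random

from pbt.client import PBTClient  # module path assumed; the diff only shows pbt.checkpoint

client = PBTClient(server_url='deep24', server_port=7777,
                   auth_key='', config_path='templates/config.csv')  # args assumed

for step_num in range(100):
    client.step()                                 # one unit of training
    metric_value = random.random()                # stand-in for a real validation metric
    parameters_path = 'params/step_{}.pkl'.format(step_num)  # wherever weights were saved
    client.save(parameters_path, metric_value)

    # Truncation selection: clients in the bottom 20% (truncation_ratio = 0.2)
    # copy a top performer, then perturb each hyperparameter by 0.8x or 1.2x.
    if client.should_exploit():
        client.exploit()
        print('Loading parameters from', client.parameters_path())
        client.explore()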
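A sketch of how a launcher script might construct the server with the new verbose flag. The repository's real entry point is not part of this diff, so the module path, argument parsing, and default port below are assumptions; only the PBTServer signature (port, auth_key, maximize_metric, verbose) comes from the patch.

import argparse

from pbt.server import PBTServer  # module path assumed, mirroring pbt.checkpoint

parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=7777, help='Port for the PBT manager server.')
parser.add_argument('--auth_key', type=str, default='', help='Authorization key for the manager server.')
args = parser.parse_args()

# maximize_metric=True ranks checkpoints so higher metric values survive truncation;
# verbose=True enables the _write logging added in this commit.
server = PBTServer(args.port, auth_key=args.auth_key, maximize_metric=True, verbose=True)

The SLURM template above would then be submitted with sbatch templates/run_clients_slurm.sh, with --pbt_server_url pointing at the machine running the server process (deep24 in the template).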
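The config.csv change replaces the single lr row with regular_lr and fine_tuning_lr, both searched on a log scale. The client's _read_config is not included in this diff, so the sampling rule below (uniform in log10-space when search_scale is 'log', uniform otherwise) is an assumption; only the column names come from the template file.

import math
import random

import pandas as pd


def sample_hyperparameters(config_path):
    """Draw one initial value per row of a PBT config CSV."""
    config = pd.read_csv(config_path)
    values = {}
    for _, row in config.iterrows():
        lo, hi = float(row['min_value']), float(row['max_value'])
        if row['search_scale'] == 'log':
            exponent = random.uniform(math.log10(lo), math.log10(hi))
            values[row['hyperparameter']] = 10.0 ** exponent
        else:  # 'linear'
            values[row['hyperparameter']] = random.uniform(lo, hi)
    return values


print(sample_hyperparameters('templates/config.csv'))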