Skip to content

Commit

Permalink
ResNet eval_only mode (tensorflow#5186)
Browse files Browse the repository at this point in the history
* Make ResNet robust to the case that epochs_between_evals does not divide train_epochs, and add an --eval_only option

* add some comments to make the control flow easier to follow

* address PR comments
  • Loading branch information
Taylor Robie authored Aug 27, 2018
1 parent 9bf586d commit d1c48af
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions official/resnet/resnet_run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from __future__ import division
from __future__ import print_function

import math
import os

# pylint: disable=g-bad-import-order
Expand Down Expand Up @@ -429,12 +430,12 @@ def resnet_main(
model_dir=flags_obj.model_dir,
batch_size=flags_obj.batch_size)

def input_fn_train(num_epochs):
  """Build the training input pipeline for a single training cycle.

  Args:
    num_epochs: Number of epochs the dataset should repeat for in this
      cycle. Varies per cycle: all but the last cycle use
      `epochs_between_evals`, and the final cycle may be shorter.

  Returns:
    Whatever `input_function` returns configured for training (presumably a
    `tf.data.Dataset` — confirm against the concrete input_function).
  """
  # NOTE(review): the diff rendering duplicated the old
  # `num_epochs=flags_obj.epochs_between_evals` line next to the new one,
  # which would be a duplicate-keyword SyntaxError; only the new,
  # parameterized value is kept here.
  return input_function(
      is_training=True, data_dir=flags_obj.data_dir,
      batch_size=distribution_utils.per_device_batch_size(
          flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
      num_epochs=num_epochs,
      num_gpus=flags_core.get_num_gpus(flags_obj))

def input_fn_eval():
Expand All @@ -444,14 +445,28 @@ def input_fn_eval():
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=1)

# Build the per-cycle training schedule. (The diff rendering duplicated the
# superseded `total_training_cycle` loop alongside this code; only the new
# schedule-based loop is kept.)
if flags_obj.eval_only or not flags_obj.train_epochs:
  # If --eval_only is set, perform a single loop with zero train epochs.
  schedule, n_loops = [0], 1
else:
  # Compute the number of times to loop while training. All but the last
  # pass will train for `epochs_between_evals` epochs, while the last will
  # train for the number needed to reach `train_epochs`. For instance, if
  # train_epochs = 25 and epochs_between_evals = 10, the schedule will be
  # [10, 10, 5]: train 10 epochs and evaluate, train another 10 and
  # evaluate, then train the final 5 (to reach 25) and evaluate.
  # `/` is true division here (`from __future__ import division`), so
  # math.ceil rounds the cycle count up; int() handles Python 2's float
  # return value from math.ceil.
  n_loops = math.ceil(flags_obj.train_epochs / flags_obj.epochs_between_evals)
  schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
  # The last cycle trains only the remainder, correcting the over-count
  # when epochs_between_evals does not divide train_epochs evenly.
  schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1])

for cycle_index, num_train_epochs in enumerate(schedule):
  tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

  # Skip training entirely for a zero-epoch cycle (the --eval_only path),
  # falling straight through to evaluation.
  if num_train_epochs:
    classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
                     hooks=train_hooks, max_steps=flags_obj.max_train_steps)

  tf.logging.info('Starting to evaluate.')

tf.logging.info('Starting to evaluate.')

Expand Down Expand Up @@ -499,6 +514,10 @@ def define_resnet_flags(resnet_size_choices=None):
help=flags_core.help_wrap(
'If not None initialize all the network except the final layer with '
'these values'))
# --eval_only: skip all training cycles and run a single evaluation pass
# against the latest checkpoint in model_dir.
# Quote style normalized to single quotes for consistency with the rest of
# the file's flag definitions; the flag name string is unchanged.
flags.DEFINE_boolean(
    name='eval_only', default=False,
    help=flags_core.help_wrap('Skip training and only perform evaluation on '
                              'the latest checkpoint.'))

choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50',
Expand Down

0 comments on commit d1c48af

Please sign in to comment.