Add 5 epoch warmup to resnet (tensorflow#5176)
* Add 5 epoch warmup

* get_lr with warm_up only for imagenet

* Add base_lr, remove fp16 unittest arg validation

* Remove validation check stopping v1 and FP16
tfboyd authored and Taylor Robie committed Aug 27, 2018
1 parent 981c003 commit 9bf586d
Showing 4 changed files with 25 additions and 29 deletions.
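In schedule terms, the commit prepends a 5-epoch linear warmup to the existing piecewise-constant decay and makes the base learning rate a parameter. A minimal pure-Python sketch of the resulting schedule, following the semantics of the hunks below (the function and argument names here are illustrative, not from the repo):

```python
def lr_schedule(global_step, base_lr, batch_size, batch_denom,
                batches_per_epoch, boundary_epochs, decay_rates, warmup=True):
  """Sketch of the schedule this commit implements (no TensorFlow)."""
  initial_lr = base_lr * batch_size / batch_denom
  warmup_steps = int(batches_per_epoch * 5)
  if warmup and global_step < warmup_steps:
    # Linear ramp from 0 to the full initial rate over the first 5 epochs.
    return initial_lr * global_step / warmup_steps
  # Step decays, mirroring tf.train.piecewise_constant's boundary semantics.
  for epoch, decay in zip(boundary_epochs, decay_rates):
    if global_step <= int(batches_per_epoch * epoch):
      return initial_lr * decay
  return initial_lr * decay_rates[-1]
```

With the ImageNet defaults introduced below (base_lr=.128, batch_denom=256, boundary_epochs=[30, 60, 80, 90], decay_rates=[1, 0.1, 0.01, 0.001, 1e-4]) and a batch size of 256, the rate ramps from 0 to 0.128 over 5 epochs, then steps down at epochs 30, 60, 80, and 90.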
7 changes: 0 additions & 7 deletions official/resnet/cifar10_test.py
@@ -165,13 +165,6 @@ def test_cifar10_end_to_end_synthetic_v2(self):
         extra_flags=['-resnet_version', '2']
     )
 
-  def test_flag_restriction(self):
-    with self.assertRaises(SystemExit):
-      integration.run_synthetic(
-          main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
-          extra_flags=['-resnet_version', '1', "-dtype", "fp16"]
-      )
-
 
 if __name__ == '__main__':
   tf.test.main()
12 changes: 11 additions & 1 deletion official/resnet/imagenet_main.py
@@ -285,10 +285,20 @@ def _get_block_sizes(resnet_size):
 
 def imagenet_model_fn(features, labels, mode, params):
   """Our model_fn for ResNet to be used with our Estimator."""
+
+  # Warmup and higher lr may not be valid for fine tuning with small batches
+  # and smaller numbers of training images.
+  if params['fine_tune']:
+    warmup = False
+    base_lr = .1
+  else:
+    warmup = True
+    base_lr = .128
+
   learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=256,
       num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
-      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])
+      decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
 
   return resnet_run_loop.resnet_model_fn(
       features=features,
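For scale: with the standard ImageNet train split of 1,281,167 images (the _NUM_IMAGES['train'] value this file uses) and a batch size of 256, the warmup runs for roughly 25,000 steps. A quick back-of-envelope check in plain Python, with the values assumed from this file:

```python
num_images = 1281167          # _NUM_IMAGES['train'] in imagenet_main.py
batch_size = 256
batches_per_epoch = num_images / batch_size   # ~5004.6 batches
warmup_steps = int(batches_per_epoch * 5)     # 25022 steps of linear ramp
initial_lr = 0.128 * batch_size / 256         # 0.128 at the reference batch size
print(warmup_steps, initial_lr)
```

Fine tuning skips the ramp and drops back to the older base rate of 0.1, since, as the comment in the hunk notes, warmup and a higher rate may not be valid for small batches and small training sets.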
7 changes: 0 additions & 7 deletions official/resnet/imagenet_test.py
@@ -304,13 +304,6 @@ def test_imagenet_end_to_end_synthetic_v2_huge(self):
         extra_flags=['-resnet_version', '2', '-resnet_size', '200']
     )
 
-  def test_flag_restriction(self):
-    with self.assertRaises(SystemExit):
-      integration.run_synthetic(
-          main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
-          extra_flags=['-resnet_version', '1', '-dtype', 'fp16']
-      )
-
 
 if __name__ == '__main__':
   tf.test.main()
28 changes: 14 additions & 14 deletions official/resnet/resnet_run_loop.py
@@ -138,7 +138,8 @@ def input_fn(is_training, data_dir, batch_size, *args, **kwargs):  # pylint: dis
 # Functions for running training/eval/validation loops for the model.
 ################################################################################
 def learning_rate_with_decay(
-    batch_size, batch_denom, num_images, boundary_epochs, decay_rates):
+    batch_size, batch_denom, num_images, boundary_epochs, decay_rates,
+    base_lr=0.1, warmup=False):
   """Get a learning rate that decays step-wise as training progresses.
 
   Args:
@@ -152,13 +153,14 @@ def learning_rate_with_decay(
     decay_rates: list of floats representing the decay rates to be used
       for scaling the learning rate. It should have one more element
       than `boundary_epochs`, and all elements should have the same type.
+    base_lr: Initial learning rate scaled based on batch_denom.
+    warmup: Run a 5 epoch warmup to the initial lr.
 
   Returns:
     Returns a function that takes a single argument - the number of batches
     trained so far (global_step)- and returns the learning rate to be used
     for training the next batch.
   """
-  initial_learning_rate = 0.1 * batch_size / batch_denom
+  initial_learning_rate = base_lr * batch_size / batch_denom
   batches_per_epoch = num_images / batch_size
 
   # Reduce the learning rate at certain epochs.
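The initial_learning_rate line above is the usual linear-scaling rule: base_lr is defined at the reference batch size batch_denom and scaled in proportion to the actual batch size. A tiny sanity check with the new ImageNet base rate (the helper name is illustrative):

```python
def scaled_lr(base_lr, batch_size, batch_denom=256):
  # Linear-scaling rule: the initial LR grows in proportion to batch size.
  return base_lr * batch_size / batch_denom

for bs in (64, 256, 1024):
  print(bs, scaled_lr(0.128, bs))   # 0.032, 0.128, 0.512
```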
@@ -168,8 +170,15 @@ def learning_rate_with_decay(
   vals = [initial_learning_rate * decay for decay in decay_rates]
 
   def learning_rate_fn(global_step):
-    global_step = tf.cast(global_step, tf.int32)
-    return tf.train.piecewise_constant(global_step, boundaries, vals)
+    """Builds scaled learning rate function with 5 epoch warm up."""
+    lr = tf.train.piecewise_constant(global_step, boundaries, vals)
+    if warmup:
+      warmup_steps = int(batches_per_epoch * 5)
+      warmup_lr = (
+          initial_learning_rate * tf.cast(global_step, tf.float32) / tf.cast(
+              warmup_steps, tf.float32))
+      return tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
+    return lr
 
   return learning_rate_fn
 
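To see the tf.cond handoff in action, here is a small, self-contained TF 1.x sketch with deliberately tiny, made-up numbers (100 batches per epoch rather than the ImageNet values):

```python
import tensorflow as tf  # TF 1.x, matching the repo at this commit

batches_per_epoch = 100
warmup_steps = int(batches_per_epoch * 5)   # 500
initial_lr = 0.128
boundaries = [3000, 6000]
vals = [initial_lr, initial_lr * 0.1, initial_lr * 0.01]

global_step = tf.placeholder(tf.int32, [])
piecewise = tf.train.piecewise_constant(global_step, boundaries, vals)
warmup_lr = initial_lr * tf.cast(global_step, tf.float32) / float(warmup_steps)
lr = tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: piecewise)

with tf.Session() as sess:
  for step in (0, 250, 500, 5000):
    print(step, sess.run(lr, feed_dict={global_step: step}))
# 0 -> 0.0, 250 -> 0.064 (mid-ramp), 500 -> 0.128 (ramp done), 5000 -> 0.0128
```

Note the ramp reaches exactly initial_lr at warmup_steps, where tf.cond hands off to the piecewise schedule, so there is no discontinuity at the boundary.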
@@ -499,12 +508,3 @@ def define_resnet_flags(resnet_size_choices=None):
     flags.DEFINE_string(**choice_kwargs)
   else:
     flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)
-
-  # The current implementation of ResNet v1 is numerically unstable when run
-  # with fp16 and will produce NaN errors soon after training begins.
-  msg = ('ResNet version 1 is not currently supported with fp16. '
-         'Please use version 2 instead.')
-  @flags.multi_flags_validator(['dtype', 'resnet_version'], message=msg)
-  def _forbid_v1_fp16(flag_values):  # pylint: disable=unused-variable
-    return (flags_core.DTYPE_MAP[flag_values['dtype']][0] != tf.float16 or
-            flag_values['resnet_version'] != '1')