Resnet transfer learning (tensorflow#5047)
* warm start a resnet with all but the dense layer and only update the final layer weights when fine tuning

* Update README for Transfer Learning

* make lint happy and fix a variable naming error related to scaled gradients

* edit the test cases for cifar10 and imagenet to reflect the default case of no fine tuning
zacwellmer authored and Taylor Robie committed Aug 14, 2018
1 parent c2f755c commit 7bffd37
Showing 6 changed files with 63 additions and 6 deletions.
16 changes: 16 additions & 0 deletions official/resnet/README.md
@@ -75,3 +75,19 @@ ResNet-50 v2 (fp16, Accuracy 75.56%):
ResNet-50 v1 (Accuracy 75.91%):
* [Checkpoint](http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_checkpoint.tar.gz)
* [SavedModel](http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_savedmodel.tar.gz)

### Transfer Learning
You can use a pretrained model to initialize training. In addition, you can freeze all but the final fully connected layer to fine-tune your model. Transfer learning is useful when training on your own small datasets. For a brief look at transfer learning in the context of convolutional neural networks, we recommend reading these [short notes](http://cs231n.github.io/transfer-learning/).


To fine-tune a pretrained ResNet, you must make three changes to your training procedure:

1) Build the same model as before, but change the number of labels in the final classification layer.

2) Restore all weights from the pre-trained ResNet except for the final classification layer, which is randomly initialized instead.

3) Freeze the earlier layers of the network.

We can perform these three operations by specifying two flags: ```--pretrained_model_checkpoint_path``` and ```--fine_tune```. The first flag is a string pointing to the path of a pre-trained ResNet checkpoint. If this flag is specified, all weights except those of the final classification layer are loaded from the checkpoint. A key thing to note: if both ```--pretrained_model_checkpoint_path``` and a non-empty ```model_dir``` directory are passed, the TensorFlow Estimator loads only from ```model_dir``` and ignores the pretrained checkpoint. For more on this, please see [WarmStartSettings](https://www.tensorflow.org/versions/master/api_docs/python/tf/estimator/WarmStartSettings) and [Estimators](https://www.tensorflow.org/guide/estimators).
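
As a rough sketch of what the first flag does (mirroring the warm-start code this commit adds to `resnet_run_loop.py`; the checkpoint path below is a placeholder), the flag is translated into a `WarmStartSettings` whose regular expression skips every variable whose name contains `dense`:

```python
import tensorflow as tf  # TensorFlow 1.x, as used by these scripts

# Placeholder path for illustration only.
PRETRAINED_CKPT = '/path/to/pretrained_resnet/model.ckpt'

# Warm-start every variable except those whose names contain 'dense'
# (the final classification layer), which remain randomly initialized.
warm_start_settings = tf.estimator.WarmStartSettings(
    PRETRAINED_CKPT,
    vars_to_warm_start='^(?!.*dense)')  # regex with a negative lookahead

# The settings object is handed to the Estimator via `warm_start_from`.
# If the Estimator's `model_dir` already contains a checkpoint, that
# checkpoint takes precedence and the warm-start settings are ignored.
```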

The second flag, ```--fine_tune```, is a boolean that indicates whether the earlier layers of the network should be frozen. Set it to false to continue training a pre-trained model from a checkpoint with all layers trainable; set it to true to train only the new classification layer from scratch.
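
The effect of ```--fine_tune``` on the training op is roughly the following (a minimal sketch with toy variable names, not the repository code verbatim): gradients are computed for every trainable variable, but only the pairs belonging to the final dense layer are applied.

```python
import tensorflow as tf  # TensorFlow 1.x

# Toy graph with a 'conv' variable and a 'dense' variable.
x = tf.placeholder(tf.float32, shape=[None, 4])
conv_w = tf.get_variable('conv/kernel', shape=[4, 4])
dense_w = tf.get_variable('dense/kernel', shape=[4, 2])
loss = tf.reduce_mean(tf.matmul(tf.matmul(x, conv_w), dense_w))

optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
grad_vars = optimizer.compute_gradients(loss)

# Fine-tuning: drop every gradient except those for the final dense layer,
# so only 'dense/kernel' is updated.
grad_vars = [(g, v) for g, v in grad_vars if 'dense' in v.name]
minimize_op = optimizer.apply_gradients(grad_vars)
```
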
3 changes: 2 additions & 1 deletion official/resnet/cifar10_main.py
@@ -221,7 +221,8 @@ def loss_filter_fn(_):
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn,
dtype=params['dtype']
dtype=params['dtype'],
fine_tune=params['fine_tune']
)


1 change: 1 addition & 0 deletions official/resnet/cifar10_test.py
@@ -89,6 +89,7 @@ def cifar10_model_fn_helper(self, mode, resnet_version, dtype):
'batch_size': _BATCH_SIZE,
'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1,
'fine_tune': False,
})

predictions = spec.predictions
3 changes: 2 additions & 1 deletion official/resnet/imagenet_main.py
@@ -303,7 +303,8 @@ def imagenet_model_fn(features, labels, mode, params):
resnet_version=params['resnet_version'],
loss_scale=params['loss_scale'],
loss_filter_fn=None,
dtype=params['dtype']
dtype=params['dtype'],
fine_tune=params['fine_tune']
)


1 change: 1 addition & 0 deletions official/resnet/imagenet_test.py
@@ -203,6 +203,7 @@ def resnet_model_fn_helper(self, mode, resnet_version, dtype):
'batch_size': _BATCH_SIZE,
'resnet_version': resnet_version,
'loss_scale': 128 if dtype == tf.float16 else 1,
'fine_tune': False,
})

predictions = spec.predictions
45 changes: 41 additions & 4 deletions official/resnet/resnet_run_loop.py
@@ -177,7 +177,8 @@ def learning_rate_fn(global_step):
def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, resnet_version, loss_scale,
loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE):
loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE,
fine_tune=False):
"""Shared functionality for different resnet model_fns.
Initializes the ResnetModel representing the model layers
@@ -210,6 +211,7 @@ def resnet_model_fn(features, labels, mode, model_class,
otherwise. If None, batch_normalization variables will be excluded
from the loss.
dtype: the TensorFlow dtype to use for calculations.
fine_tune: If True, only train the dense layers (the final layers).
Returns:
EstimatorSpec parameterized according to the input params and the
@@ -281,19 +283,36 @@ def exclude_batch_norm(name):
momentum=momentum
)

def _dense_grad_filter(gvs):
"""Only apply gradient updates to the final (dense) layer.
This function is used for fine tuning.
Args:
gvs: list of tuples with gradients and variable info
Returns:
filtered gradients so that only the dense layer remains
"""
return [(g, v) for g, v in gvs if 'dense' in v.name]

if loss_scale != 1:
# When computing fp16 gradients, often intermediate tensor values are
# so small, they underflow to 0. To avoid this, we multiply the loss by
# loss_scale to make these tensor values loss_scale times bigger.
scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)

if fine_tune:
scaled_grad_vars = _dense_grad_filter(scaled_grad_vars)

# Once the gradient computation is complete we can scale the gradients
# back to the correct scale before passing them to the optimizer.
unscaled_grad_vars = [(grad / loss_scale, var)
for grad, var in scaled_grad_vars]
minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
else:
minimize_op = optimizer.minimize(loss, global_step)
grad_vars = optimizer.compute_gradients(loss)
if fine_tune:
grad_vars = _dense_grad_filter(grad_vars)
minimize_op = optimizer.apply_gradients(grad_vars, global_step)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.group(minimize_op, update_ops)
@@ -354,15 +373,24 @@ def resnet_main(
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, session_config=session_config)

# Initialize our model with all but the dense layer from the pretrained ResNet.
if flags_obj.pretrained_model_checkpoint_path is not None:
warm_start_settings = tf.estimator.WarmStartSettings(
flags_obj.pretrained_model_checkpoint_path,
vars_to_warm_start='^(?!.*dense)')
else:
warm_start_settings = None

classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
params={
warm_start_from=warm_start_settings, params={
'resnet_size': int(flags_obj.resnet_size),
'data_format': flags_obj.data_format,
'batch_size': flags_obj.batch_size,
'resnet_version': int(flags_obj.resnet_version),
'loss_scale': flags_core.get_loss_scale(flags_obj),
'dtype': flags_core.get_tf_dtype(flags_obj)
'dtype': flags_core.get_tf_dtype(flags_obj),
'fine_tune': flags_obj.fine_tune
})

run_params = {
@@ -446,6 +474,15 @@ def define_resnet_flags(resnet_size_choices=None):
enum_values=['1', '2'],
help=flags_core.help_wrap(
'Version of ResNet. (1 or 2) See README.md for details.'))
flags.DEFINE_bool(
name='fine_tune', short_name='ft', default=False,
help=flags_core.help_wrap(
'If True, do not train any parameters except for the final layer.'))
flags.DEFINE_string(
name='pretrained_model_checkpoint_path', short_name='pmcp', default=None,
help=flags_core.help_wrap(
'If not None, initialize all of the network except the final layer with '
'these values.'))

choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50',
