diff --git a/official/resnet/cifar10_main.py b/official/resnet/cifar10_main.py index de4a58e16a8..6f22fdf7233 100644 --- a/official/resnet/cifar10_main.py +++ b/official/resnet/cifar10_main.py @@ -140,7 +140,7 @@ class Cifar10Model(resnet_model.Model): """Model class with appropriate defaults for CIFAR-10 data.""" def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, - version=resnet_model.DEFAULT_VERSION, + resnet_version=resnet_model.DEFAULT_VERSION, dtype=resnet_model.DEFAULT_DTYPE): """These are the parameters that work for CIFAR-10 data. @@ -150,8 +150,8 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, data format to use when setting up the model. num_classes: The number of output classes needed from the model. This enables users to extend the same model to their own datasets. - version: Integer representing which version of the ResNet network to use. - See README for details. Valid values: [1, 2] + resnet_version: Integer representing which version of the ResNet network + to use. See README for details. Valid values: [1, 2] dtype: The TensorFlow dtype to use for calculations. Raises: @@ -174,7 +174,7 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, block_sizes=[num_blocks] * 3, block_strides=[1, 2, 2], final_size=64, - version=version, + resnet_version=resnet_version, data_format=data_format, dtype=dtype ) @@ -211,7 +211,7 @@ def loss_filter_fn(_): learning_rate_fn=learning_rate_fn, momentum=0.9, data_format=params['data_format'], - version=params['version'], + resnet_version=params['resnet_version'], loss_scale=params['loss_scale'], loss_filter_fn=loss_filter_fn, dtype=params['dtype'] diff --git a/official/resnet/cifar10_test.py b/official/resnet/cifar10_test.py index 13adf46ecee..74d4694858d 100644 --- a/official/resnet/cifar10_test.py +++ b/official/resnet/cifar10_test.py @@ -76,7 +76,7 @@ def test_dataset_input_fn(self): for pixel in row: self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3) - def cifar10_model_fn_helper(self, mode, version, dtype): + def cifar10_model_fn_helper(self, mode, resnet_version, dtype): input_fn = cifar10_main.get_synth_input_fn() dataset = input_fn(True, '', _BATCH_SIZE) iterator = dataset.make_one_shot_iterator() @@ -87,7 +87,7 @@ def cifar10_model_fn_helper(self, mode, version, dtype): 'resnet_size': 32, 'data_format': 'channels_last', 'batch_size': _BATCH_SIZE, - 'version': version, + 'resnet_version': resnet_version, 'loss_scale': 128 if dtype == tf.float16 else 1, }) @@ -111,56 +111,57 @@ def cifar10_model_fn_helper(self, mode, version, dtype): self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) def test_cifar10_model_fn_train_mode_v1(self): - self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, + self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1, dtype=tf.float32) def test_cifar10_model_fn_trainmode__v2(self): - self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, + self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2, dtype=tf.float32) def test_cifar10_model_fn_eval_mode_v1(self): - self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1, + self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1, dtype=tf.float32) def test_cifar10_model_fn_eval_mode_v2(self): - self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2, + self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2, dtype=tf.float32) def test_cifar10_model_fn_predict_mode_v1(self): - self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1, - dtype=tf.float32) + self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, + resnet_version=1, dtype=tf.float32) def test_cifar10_model_fn_predict_mode_v2(self): - self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2, - dtype=tf.float32) + self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, + resnet_version=2, dtype=tf.float32) - def _test_cifar10model_shape(self, version): + def _test_cifar10model_shape(self, resnet_version): batch_size = 135 num_classes = 246 model = cifar10_main.Cifar10Model(32, data_format='channels_last', - num_classes=num_classes, version=version) + num_classes=num_classes, + resnet_version=resnet_version) fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS]) output = model(fake_input, training=True) self.assertAllEqual(output.shape, (batch_size, num_classes)) def test_cifar10model_shape_v1(self): - self._test_cifar10model_shape(version=1) + self._test_cifar10model_shape(resnet_version=1) def test_cifar10model_shape_v2(self): - self._test_cifar10model_shape(version=2) + self._test_cifar10model_shape(resnet_version=2) def test_cifar10_end_to_end_synthetic_v1(self): integration.run_synthetic( main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(), - extra_flags=['-v', '1'] + extra_flags=['-resnet_version', '1'] ) def test_cifar10_end_to_end_synthetic_v2(self): integration.run_synthetic( main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(), - extra_flags=['-v', '2'] + extra_flags=['-resnet_version', '2'] ) diff --git a/official/resnet/imagenet_main.py b/official/resnet/imagenet_main.py index 6dd19707c15..263f41e34ee 100644 --- a/official/resnet/imagenet_main.py +++ b/official/resnet/imagenet_main.py @@ -197,7 +197,7 @@ class ImagenetModel(resnet_model.Model): """Model class with appropriate defaults for Imagenet data.""" def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, - version=resnet_model.DEFAULT_VERSION, + resnet_version=resnet_model.DEFAULT_VERSION, dtype=resnet_model.DEFAULT_DTYPE): """These are the parameters that work for Imagenet data. @@ -207,8 +207,8 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, data format to use when setting up the model. num_classes: The number of output classes needed from the model. This enables users to extend the same model to their own datasets. - version: Integer representing which version of the ResNet network to use. - See README for details. Valid values: [1, 2] + resnet_version: Integer representing which version of the ResNet network + to use. See README for details. Valid values: [1, 2] dtype: The TensorFlow dtype to use for calculations. """ @@ -232,7 +232,7 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, block_sizes=_get_block_sizes(resnet_size), block_strides=[1, 2, 2, 2], final_size=final_size, - version=version, + resnet_version=resnet_version, data_format=data_format, dtype=dtype ) @@ -289,7 +289,7 @@ def imagenet_model_fn(features, labels, mode, params): learning_rate_fn=learning_rate_fn, momentum=0.9, data_format=params['data_format'], - version=params['version'], + resnet_version=params['resnet_version'], loss_scale=params['loss_scale'], loss_filter_fn=None, dtype=params['dtype'] diff --git a/official/resnet/imagenet_test.py b/official/resnet/imagenet_test.py index 046252bc4c7..7394c93ef4d 100644 --- a/official/resnet/imagenet_test.py +++ b/official/resnet/imagenet_test.py @@ -41,7 +41,7 @@ def tearDown(self): super(BaseTest, self).tearDown() tf.gfile.DeleteRecursively(self.get_temp_dir()) - def _tensor_shapes_helper(self, resnet_size, version, dtype, with_gpu): + def _tensor_shapes_helper(self, resnet_size, resnet_version, dtype, with_gpu): """Checks the tensor shapes after each phase of the ResNet model.""" def reshape(shape): """Returns the expected dimensions depending on if a GPU is being used.""" @@ -59,7 +59,7 @@ def reshape(shape): model = imagenet_main.ImagenetModel( resnet_size=resnet_size, data_format='channels_first' if with_gpu else 'channels_last', - version=version, + resnet_version=resnet_version, dtype=dtype ) inputs = tf.random_uniform([1, 224, 224, 3]) @@ -95,97 +95,99 @@ def reshape(shape): self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES)) self.assertAllEqual(output.shape, (1, _LABEL_CLASSES)) - def tensor_shapes_helper(self, resnet_size, version, with_gpu=False): - self._tensor_shapes_helper(resnet_size=resnet_size, version=version, + def tensor_shapes_helper(self, resnet_size, resnet_version, with_gpu=False): + self._tensor_shapes_helper(resnet_size=resnet_size, + resnet_version=resnet_version, dtype=tf.float32, with_gpu=with_gpu) - self._tensor_shapes_helper(resnet_size=resnet_size, version=version, + self._tensor_shapes_helper(resnet_size=resnet_size, + resnet_version=resnet_version, dtype=tf.float16, with_gpu=with_gpu) def test_tensor_shapes_resnet_18_v1(self): - self.tensor_shapes_helper(18, version=1) + self.tensor_shapes_helper(18, resnet_version=1) def test_tensor_shapes_resnet_18_v2(self): - self.tensor_shapes_helper(18, version=2) + self.tensor_shapes_helper(18, resnet_version=2) def test_tensor_shapes_resnet_34_v1(self): - self.tensor_shapes_helper(34, version=1) + self.tensor_shapes_helper(34, resnet_version=1) def test_tensor_shapes_resnet_34_v2(self): - self.tensor_shapes_helper(34, version=2) + self.tensor_shapes_helper(34, resnet_version=2) def test_tensor_shapes_resnet_50_v1(self): - self.tensor_shapes_helper(50, version=1) + self.tensor_shapes_helper(50, resnet_version=1) def test_tensor_shapes_resnet_50_v2(self): - self.tensor_shapes_helper(50, version=2) + self.tensor_shapes_helper(50, resnet_version=2) def test_tensor_shapes_resnet_101_v1(self): - self.tensor_shapes_helper(101, version=1) + self.tensor_shapes_helper(101, resnet_version=1) def test_tensor_shapes_resnet_101_v2(self): - self.tensor_shapes_helper(101, version=2) + self.tensor_shapes_helper(101, resnet_version=2) def test_tensor_shapes_resnet_152_v1(self): - self.tensor_shapes_helper(152, version=1) + self.tensor_shapes_helper(152, resnet_version=1) def test_tensor_shapes_resnet_152_v2(self): - self.tensor_shapes_helper(152, version=2) + self.tensor_shapes_helper(152, resnet_version=2) def test_tensor_shapes_resnet_200_v1(self): - self.tensor_shapes_helper(200, version=1) + self.tensor_shapes_helper(200, resnet_version=1) def test_tensor_shapes_resnet_200_v2(self): - self.tensor_shapes_helper(200, version=2) + self.tensor_shapes_helper(200, resnet_version=2) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_18_with_gpu_v1(self): - self.tensor_shapes_helper(18, version=1, with_gpu=True) + self.tensor_shapes_helper(18, resnet_version=1, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_18_with_gpu_v2(self): - self.tensor_shapes_helper(18, version=2, with_gpu=True) + self.tensor_shapes_helper(18, resnet_version=2, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_34_with_gpu_v1(self): - self.tensor_shapes_helper(34, version=1, with_gpu=True) + self.tensor_shapes_helper(34, resnet_version=1, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_34_with_gpu_v2(self): - self.tensor_shapes_helper(34, version=2, with_gpu=True) + self.tensor_shapes_helper(34, resnet_version=2, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_50_with_gpu_v1(self): - self.tensor_shapes_helper(50, version=1, with_gpu=True) + self.tensor_shapes_helper(50, resnet_version=1, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_50_with_gpu_v2(self): - self.tensor_shapes_helper(50, version=2, with_gpu=True) + self.tensor_shapes_helper(50, resnet_version=2, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_101_with_gpu_v1(self): - self.tensor_shapes_helper(101, version=1, with_gpu=True) + self.tensor_shapes_helper(101, resnet_version=1, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_101_with_gpu_v2(self): - self.tensor_shapes_helper(101, version=2, with_gpu=True) + self.tensor_shapes_helper(101, resnet_version=2, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_152_with_gpu_v1(self): - self.tensor_shapes_helper(152, version=1, with_gpu=True) + self.tensor_shapes_helper(152, resnet_version=1, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_152_with_gpu_v2(self): - self.tensor_shapes_helper(152, version=2, with_gpu=True) + self.tensor_shapes_helper(152, resnet_version=2, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_200_with_gpu_v1(self): - self.tensor_shapes_helper(200, version=1, with_gpu=True) + self.tensor_shapes_helper(200, resnet_version=1, with_gpu=True) @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') def test_tensor_shapes_resnet_200_with_gpu_v2(self): - self.tensor_shapes_helper(200, version=2, with_gpu=True) + self.tensor_shapes_helper(200, resnet_version=2, with_gpu=True) - def resnet_model_fn_helper(self, mode, version, dtype): + def resnet_model_fn_helper(self, mode, resnet_version, dtype): """Tests that the EstimatorSpec is given the appropriate arguments.""" tf.train.create_global_step() @@ -199,7 +201,7 @@ def resnet_model_fn_helper(self, mode, version, dtype): 'resnet_size': 50, 'data_format': 'channels_last', 'batch_size': _BATCH_SIZE, - 'version': version, + 'resnet_version': resnet_version, 'loss_scale': 128 if dtype == tf.float16 else 1, }) @@ -223,36 +225,36 @@ def resnet_model_fn_helper(self, mode, version, dtype): self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) def test_resnet_model_fn_train_mode_v1(self): - self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, + self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=1, dtype=tf.float32) def test_resnet_model_fn_train_mode_v2(self): - self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, + self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, resnet_version=2, dtype=tf.float32) def test_resnet_model_fn_eval_mode_v1(self): - self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1, + self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=1, dtype=tf.float32) def test_resnet_model_fn_eval_mode_v2(self): - self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2, + self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, resnet_version=2, dtype=tf.float32) def test_resnet_model_fn_predict_mode_v1(self): - self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1, + self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=1, dtype=tf.float32) def test_resnet_model_fn_predict_mode_v2(self): - self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2, + self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, resnet_version=2, dtype=tf.float32) - def _test_imagenetmodel_shape(self, version): + def _test_imagenetmodel_shape(self, resnet_version): batch_size = 135 num_classes = 246 model = imagenet_main.ImagenetModel( 50, data_format='channels_last', num_classes=num_classes, - version=version) + resnet_version=resnet_version) fake_input = tf.random_uniform([batch_size, 224, 224, 3]) output = model(fake_input, training=True) @@ -260,10 +262,10 @@ def _test_imagenetmodel_shape(self, version): self.assertAllEqual(output.shape, (batch_size, num_classes)) def test_imagenetmodel_shape_v1(self): - self._test_imagenetmodel_shape(version=1) + self._test_imagenetmodel_shape(resnet_version=1) def test_imagenetmodel_shape_v2(self): - self._test_imagenetmodel_shape(version=2) + self._test_imagenetmodel_shape(resnet_version=2) def test_imagenet_end_to_end_synthetic_v1(self): integration.run_synthetic( @@ -280,25 +282,25 @@ def test_imagenet_end_to_end_synthetic_v2(self): def test_imagenet_end_to_end_synthetic_v1_tiny(self): integration.run_synthetic( main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), - extra_flags=['-v', '1', '-rs', '18'] + extra_flags=['-resnet_version', '1', '-resnet_size', '18'] ) def test_imagenet_end_to_end_synthetic_v2_tiny(self): integration.run_synthetic( main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), - extra_flags=['-v', '2', '-rs', '18'] + extra_flags=['-resnet_version', '2', '-resnet_size', '18'] ) def test_imagenet_end_to_end_synthetic_v1_huge(self): integration.run_synthetic( main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), - extra_flags=['-v', '1', '-rs', '200'] + extra_flags=['-resnet_version', '1', '-resnet_size', '200'] ) def test_imagenet_end_to_end_synthetic_v2_huge(self): integration.run_synthetic( main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), - extra_flags=['-v', '2', '-rs', '200'] + extra_flags=['-resnet_version', '2', '-resnet_size', '200'] ) diff --git a/official/resnet/layer_test.py b/official/resnet/layer_test.py index 2ed04b1aad3..bbef600ef93 100644 --- a/official/resnet/layer_test.py +++ b/official/resnet/layer_test.py @@ -41,14 +41,22 @@ DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first BATCH_SIZE = 32 BLOCK_TESTS = [ - dict(bottleneck=True, projection=True, version=1, width=8, channels=4), - dict(bottleneck=True, projection=True, version=2, width=8, channels=4), - dict(bottleneck=True, projection=False, version=1, width=8, channels=4), - dict(bottleneck=True, projection=False, version=2, width=8, channels=4), - dict(bottleneck=False, projection=True, version=1, width=8, channels=4), - dict(bottleneck=False, projection=True, version=2, width=8, channels=4), - dict(bottleneck=False, projection=False, version=1, width=8, channels=4), - dict(bottleneck=False, projection=False, version=2, width=8, channels=4), + dict(bottleneck=True, projection=True, resnet_version=1, width=8, + channels=4), + dict(bottleneck=True, projection=True, resnet_version=2, width=8, + channels=4), + dict(bottleneck=True, projection=False, resnet_version=1, width=8, + channels=4), + dict(bottleneck=True, projection=False, resnet_version=2, width=8, + channels=4), + dict(bottleneck=False, projection=True, resnet_version=1, width=8, + channels=4), + dict(bottleneck=False, projection=True, resnet_version=2, width=8, + channels=4), + dict(bottleneck=False, projection=False, resnet_version=1, width=8, + channels=4), + dict(bottleneck=False, projection=False, resnet_version=2, width=8, + channels=4), ] @@ -95,7 +103,7 @@ def projection_shortcut(inputs): return projection_shortcut def _resnet_block_ops(self, test, batch_size, bottleneck, projection, - version, width, channels): + resnet_version, width, channels): """Test whether resnet block construction has changed. Args: @@ -104,7 +112,7 @@ def _resnet_block_ops(self, test, batch_size, bottleneck, projection, batch normalization. bottleneck: Whether or not to use bottleneck layers. projection: Whether or not to project the input. - version: Which version of ResNet to test. + resnet_version: Which version of ResNet to test. width: The width of the fake image. channels: The number of channels in the fake image. """ @@ -113,12 +121,12 @@ def _resnet_block_ops(self, test, batch_size, bottleneck, projection, batch_size, "bottleneck" if bottleneck else "building", "_projection" if projection else "", - version, + resnet_version, width, channels ) - if version == 1: + if resnet_version == 1: block_fn = resnet_model._building_block_v1 if bottleneck: block_fn = resnet_model._bottleneck_block_v1 diff --git a/official/resnet/resnet_model.py b/official/resnet/resnet_model.py index 92fefc34f65..128a765a4c8 100644 --- a/official/resnet/resnet_model.py +++ b/official/resnet/resnet_model.py @@ -354,7 +354,7 @@ def __init__(self, resnet_size, bottleneck, num_classes, num_filters, kernel_size, conv_stride, first_pool_size, first_pool_stride, block_sizes, block_strides, - final_size, version=DEFAULT_VERSION, data_format=None, + final_size, resnet_version=DEFAULT_VERSION, data_format=None, dtype=DEFAULT_DTYPE): """Creates a model for classifying an image. @@ -377,8 +377,8 @@ def __init__(self, resnet_size, bottleneck, num_classes, num_filters, block_strides: List of integers representing the desired stride size for each of the sets of block layers. Should be same length as block_sizes. final_size: The expected size of the model after the second pooling. - version: Integer representing which version of the ResNet network to use. - See README for details. Valid values: [1, 2] + resnet_version: Integer representing which version of the ResNet network + to use. See README for details. Valid values: [1, 2] data_format: Input format ('channels_last', 'channels_first', or None). If set to None, the format is dependent on whether a GPU is available. dtype: The TensorFlow dtype to use for calculations. If not specified @@ -393,19 +393,19 @@ def __init__(self, resnet_size, bottleneck, num_classes, num_filters, data_format = ( 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last') - self.resnet_version = version - if version not in (1, 2): + self.resnet_version = resnet_version + if resnet_version not in (1, 2): raise ValueError( 'Resnet version should be 1 or 2. See README for citations.') self.bottleneck = bottleneck if bottleneck: - if version == 1: + if resnet_version == 1: self.block_fn = _bottleneck_block_v1 else: self.block_fn = _bottleneck_block_v2 else: - if version == 1: + if resnet_version == 1: self.block_fn = _building_block_v1 else: self.block_fn = _building_block_v2 diff --git a/official/resnet/resnet_run_loop.py b/official/resnet/resnet_run_loop.py index 748515e1cb0..12493b15976 100644 --- a/official/resnet/resnet_run_loop.py +++ b/official/resnet/resnet_run_loop.py @@ -155,8 +155,8 @@ def learning_rate_fn(global_step): def resnet_model_fn(features, labels, mode, model_class, resnet_size, weight_decay, learning_rate_fn, momentum, - data_format, version, loss_scale, loss_filter_fn=None, - dtype=resnet_model.DEFAULT_DTYPE): + data_format, resnet_version, loss_scale, + loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE): """Shared functionality for different resnet model_fns. Initializes the ResnetModel representing the model layers @@ -180,8 +180,8 @@ def resnet_model_fn(features, labels, mode, model_class, momentum: momentum term used for optimization data_format: Input format ('channels_last', 'channels_first', or None). If set to None, the format is dependent on whether a GPU is available. - version: Integer representing which version of the ResNet network to use. - See README for details. Valid values: [1, 2] + resnet_version: Integer representing which version of the ResNet network to + use. See README for details. Valid values: [1, 2] loss_scale: The factor to scale the loss for numerical stability. A detailed summary is present in the arg parser help text. loss_filter_fn: function that takes a string variable name and returns @@ -200,7 +200,8 @@ def resnet_model_fn(features, labels, mode, model_class, features = tf.cast(features, dtype) - model = model_class(resnet_size, data_format, version=version, dtype=dtype) + model = model_class(resnet_size, data_format, resnet_version=resnet_version, + dtype=dtype) logits = model(features, mode == tf.estimator.ModeKeys.TRAIN) @@ -379,7 +380,7 @@ def resnet_main( 'resnet_size': int(flags_obj.resnet_size), 'data_format': flags_obj.data_format, 'batch_size': flags_obj.batch_size, - 'version': int(flags_obj.version), + 'resnet_version': int(flags_obj.resnet_version), 'loss_scale': flags_core.get_loss_scale(flags_obj), 'dtype': flags_core.get_tf_dtype(flags_obj) }) @@ -388,7 +389,7 @@ def resnet_main( 'batch_size': flags_obj.batch_size, 'dtype': flags_core.get_tf_dtype(flags_obj), 'resnet_size': flags_obj.resnet_size, - 'resnet_version': flags_obj.version, + 'resnet_version': flags_obj.resnet_version, 'synthetic_data': flags_obj.use_synthetic_data, 'train_epochs': flags_obj.train_epochs, } @@ -456,7 +457,8 @@ def define_resnet_flags(resnet_size_choices=None): flags.adopt_module_key_flags(flags_core) flags.DEFINE_enum( - name='version', short_name='rv', default='2', enum_values=['1', '2'], + name='resnet_version', short_name='rv', default='2', + enum_values=['1', '2'], help=flags_core.help_wrap( 'Version of ResNet. (1 or 2) See README.md for details.'))