diff --git a/configs/_base_/models/vgg11.py b/configs/_base_/models/vgg11.py
new file mode 100644
index 00000000000..2b6ee1426aa
--- /dev/null
+++ b/configs/_base_/models/vgg11.py
@@ -0,0 +1,10 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='VGG', depth=11, num_classes=1000),
+    neck=None,
+    head=dict(
+        type='ClsHead',
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vgg11bn.py b/configs/_base_/models/vgg11bn.py
new file mode 100644
index 00000000000..cb4c64e95a8
--- /dev/null
+++ b/configs/_base_/models/vgg11bn.py
@@ -0,0 +1,11 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VGG', depth=11, norm_cfg=dict(type='BN'), num_classes=1000),
+    neck=None,
+    head=dict(
+        type='ClsHead',
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vgg13.py b/configs/_base_/models/vgg13.py
new file mode 100644
index 00000000000..a9389100a61
--- /dev/null
+++ b/configs/_base_/models/vgg13.py
@@ -0,0 +1,10 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='VGG', depth=13, num_classes=1000),
+    neck=None,
+    head=dict(
+        type='ClsHead',
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vgg13bn.py b/configs/_base_/models/vgg13bn.py
new file mode 100644
index 00000000000..b12173b51b8
--- /dev/null
+++ b/configs/_base_/models/vgg13bn.py
@@ -0,0 +1,11 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VGG', depth=13, norm_cfg=dict(type='BN'), num_classes=1000),
+    neck=None,
+    head=dict(
+        type='ClsHead',
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vgg16.py b/configs/_base_/models/vgg16.py
new file mode 100644
index 00000000000..93ce864fac2
--- /dev/null
+++ b/configs/_base_/models/vgg16.py
@@ -0,0 +1,10 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='VGG', depth=16, num_classes=1000),
+    neck=None,
+    head=dict(
+        type='ClsHead',
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vgg16bn.py b/configs/_base_/models/vgg16bn.py
new file mode 100644
index 00000000000..765e34f6367
--- /dev/null
+++ b/configs/_base_/models/vgg16bn.py
@@ -0,0 +1,11 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VGG', depth=16, norm_cfg=dict(type='BN'), num_classes=1000),
+    neck=None,
+    head=dict(
+        type='ClsHead',
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vgg19.py b/configs/_base_/models/vgg19.py
new file mode 100644
index 00000000000..6f4ab061b2c
--- /dev/null
+++ b/configs/_base_/models/vgg19.py
@@ -0,0 +1,10 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(type='VGG', depth=19, num_classes=1000),
+    neck=None,
+    head=dict(
+        type='ClsHead',
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/_base_/models/vgg19bn.py b/configs/_base_/models/vgg19bn.py
new file mode 100644
index 00000000000..c468b5dea2c
--- /dev/null
+++ b/configs/_base_/models/vgg19bn.py
@@ -0,0 +1,11 @@
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VGG', depth=19, norm_cfg=dict(type='BN'), num_classes=1000),
+    neck=None,
+    head=dict(
+        type='ClsHead',
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
diff --git a/configs/imagenet/vgg11.py b/configs/imagenet/vgg11.py
new file mode 100644
index 00000000000..348c63bedad
--- /dev/null
+++ b/configs/imagenet/vgg11.py
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/vgg11.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/default_runtime.py'
+]
diff --git a/configs/imagenet/vgg11bn.py b/configs/imagenet/vgg11bn.py
new file mode 100644
index 00000000000..7462e559f10
--- /dev/null
+++ b/configs/imagenet/vgg11bn.py
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/vgg11bn.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/default_runtime.py'
+]
diff --git a/configs/imagenet/vgg13.py b/configs/imagenet/vgg13.py
new file mode 100644
index 00000000000..3e72c8e7bd0
--- /dev/null
+++ b/configs/imagenet/vgg13.py
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/vgg13.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/default_runtime.py'
+]
diff --git a/configs/imagenet/vgg13bn.py b/configs/imagenet/vgg13bn.py
new file mode 100644
index 00000000000..591bc01ae3f
--- /dev/null
+++ b/configs/imagenet/vgg13bn.py
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/vgg13bn.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/default_runtime.py'
+]
diff --git a/configs/imagenet/vgg16.py b/configs/imagenet/vgg16.py
new file mode 100644
index 00000000000..a9302549763
--- /dev/null
+++ b/configs/imagenet/vgg16.py
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/vgg16.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/default_runtime.py'
+]
diff --git a/configs/imagenet/vgg16bn.py b/configs/imagenet/vgg16bn.py
new file mode 100644
index 00000000000..e46a4bec73a
--- /dev/null
+++ b/configs/imagenet/vgg16bn.py
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/vgg16bn.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/default_runtime.py'
+]
diff --git a/configs/imagenet/vgg19.py b/configs/imagenet/vgg19.py
new file mode 100644
index 00000000000..e562224eb62
--- /dev/null
+++ b/configs/imagenet/vgg19.py
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/vgg19.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/default_runtime.py'
+]
diff --git a/configs/imagenet/vgg19bn.py b/configs/imagenet/vgg19bn.py
new file mode 100644
index 00000000000..b040cb74693
--- /dev/null
+++ b/configs/imagenet/vgg19bn.py
@@ -0,0 +1,5 @@
+_base_ = [
+    '../_base_/models/vgg19bn.py',
+    '../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../_base_/default_runtime.py'
+]
diff --git a/docs/model_zoo.md b/docs/model_zoo.md
index e087ac5b38e..700a67ecb13 100644
--- a/docs/model_zoo.md
+++ b/docs/model_zoo.md
@@ -8,6 +8,14 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
 
 | Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
 |:---------------------:|:---------:|:--------:|:---------:|:---------:|:--------:|
+| VGG-11 | 132.86 | 7.63 | 69.03 | 88.63 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg11-01ecd97e.pth)* |
+| VGG-13 | 133.05 | 11.34 | 69.93 | 89.26 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg13-9ad3945d.pth)* |
+| VGG-16 | 138.36 | 15.5 | 71.59 | 90.39 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg16-91b6d117.pth)* |
+| VGG-19 | 143.67 | 19.67 | 72.38 | 90.88 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg19-fee352a8.pth)* |
+| VGG-11-BN | 132.87 | 7.64 | 70.37 | 89.81 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg11_bn-6fbbbf3f.pth)* |
+| VGG-13-BN | 133.05 | 11.36 | 71.55 | 90.37 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg13_bn-4b5f9390.pth)* |
+| VGG-16-BN | 138.37 | 15.53 | 73.36 | 91.5 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg16_bn-3ac6d8fd.pth)* |
+| VGG-19-BN | 143.68 | 19.7 | 74.24 | 91.84 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg19_bn-7c058385.pth)* |
 | ResNet-18 | 11.69 | 1.82 | 70.07 | 89.44 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet18_batch256_20200708-34ab8f90.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet18_batch256_20200708-34ab8f90.log.json) |
 | ResNet-34 | 21.8 | 3.68 | 73.85 | 91.53 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet34_batch256_20200708-32ffb4f7.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet34_batch256_20200708-32ffb4f7.log.json) |
 | ResNet-50 | 25.56 | 4.12 | 76.55 | 93.15 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet50_batch256_20200708-cfb998bf.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet50_batch256_20200708-cfb998bf.log.json) |
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 2dbace9e9e1..d0771fd62e1 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -11,9 +11,10 @@
 from .seresnext import SEResNeXt
 from .shufflenet_v1 import ShuffleNetV1
 from .shufflenet_v2 import ShuffleNetV2
+from .vgg import VGG
 
 __all__ = [
-    'LeNet5', 'AlexNet', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNeSt',
-    'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2',
-    'MobileNetV2', 'MobileNetv3'
+    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
+    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
+    'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
 ]
diff --git a/mmcls/models/backbones/vgg.py b/mmcls/models/backbones/vgg.py
new file mode 100644
index 00000000000..aaa9eb9ebe0
--- /dev/null
+++ b/mmcls/models/backbones/vgg.py
@@ -0,0 +1,194 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from ..builder import BACKBONES
+from .base_backbone import BaseBackbone
+
+
+def make_vgg_layer(in_channels,
+                   out_channels,
+                   num_blocks,
+                   conv_cfg=None,
+                   norm_cfg=None,
+                   act_cfg=dict(type='ReLU'),
+                   dilation=1,
+                   with_norm=False,
+                   ceil_mode=False):
+    layers = []
+    for _ in range(num_blocks):
+        layer = ConvModule(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            dilation=dilation,
+            padding=dilation,
+            bias=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        layers.append(layer)
+        in_channels = out_channels
+    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
+
+    return layers
+
+
+@BACKBONES.register_module()
+class VGG(BaseBackbone):
+    """VGG backbone.
+
+    Args:
+        depth (int): Depth of vgg, from {11, 13, 16, 19}.
+        norm_cfg (dict | None): Config dict for the normalization layers,
+            e.g. dict(type='BN'). If None, no normalization layer is added
+            after the conv layers. Default: None.
+        num_classes (int): Number of classes for classification. Default: -1.
+        num_stages (int): Number of VGG stages, normally 5. Default: 5.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages. If only one
+            stage is specified, a single tensor (feature map) is returned;
+            if multiple stages are specified, a tuple of tensors will be
+            returned. When it is None, the default behavior depends on
+            whether num_classes is specified: if num_classes <= 0, the
+            default value is (4, ), outputting the last feature map before
+            the classifier; if num_classes > 0, the default value is (5, ),
+            outputting the classification score. Default: None.
+        frozen_stages (int): Stages to be frozen (all parameters fixed).
+            -1 means not freezing any parameters. Default: -1.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False.
+        with_last_pool (bool): Whether to keep the last pooling before
+            classifier. Default: True.
+    """
+
+    # Parameters to build layers. Each element specifies the number of conv
+    # layers in each stage. For example, VGG11 contains 11 layers with
+    # learnable parameters: 11 = (1 + 1 + 2 + 2 + 2) + 3, where the trailing
+    # 3 counts the last three fully-connected layers.
+    arch_settings = {
+        11: (1, 1, 2, 2, 2),
+        13: (2, 2, 2, 2, 2),
+        16: (2, 2, 3, 3, 3),
+        19: (2, 2, 4, 4, 4)
+    }
+
+    def __init__(self,
+                 depth,
+                 num_classes=-1,
+                 num_stages=5,
+                 dilations=(1, 1, 1, 1, 1),
+                 out_indices=None,
+                 frozen_stages=-1,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 norm_eval=False,
+                 ceil_mode=False,
+                 with_last_pool=True):
+        super(VGG, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for vgg')
+        assert num_stages >= 1 and num_stages <= 5
+        stage_blocks = self.arch_settings[depth]
+        self.stage_blocks = stage_blocks[:num_stages]
+        assert len(dilations) == num_stages
+
+        self.num_classes = num_classes
+        self.frozen_stages = frozen_stages
+        self.norm_eval = norm_eval
+        with_norm = norm_cfg is not None
+
+        if out_indices is None:
+            out_indices = (5, ) if num_classes > 0 else (4, )
+        assert max(out_indices) <= num_stages
+        self.out_indices = out_indices
+
+        self.in_channels = 3
+        start_idx = 0
+        vgg_layers = []
+        self.range_sub_modules = []
+        for i, num_blocks in enumerate(self.stage_blocks):
+            num_modules = num_blocks + 1
+            end_idx = start_idx + num_modules
+            dilation = dilations[i]
+            out_channels = 64 * 2**i if i < 4 else 512
+            vgg_layer = make_vgg_layer(
+                self.in_channels,
+                out_channels,
+                num_blocks,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+                dilation=dilation,
+                with_norm=with_norm,
+                ceil_mode=ceil_mode)
+            vgg_layers.extend(vgg_layer)
+            self.in_channels = out_channels
+            self.range_sub_modules.append([start_idx, end_idx])
+            start_idx = end_idx
+        if not with_last_pool:
+            vgg_layers.pop(-1)
+            self.range_sub_modules[-1][1] -= 1
+        self.module_name = 'features'
+        self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Linear(512 * 7 * 7, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained=None):
+        super(VGG, self).init_weights(pretrained)
+        if pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, _BatchNorm):
+                    constant_init(m, 1)
+                elif isinstance(m, nn.Linear):
+                    normal_init(m, std=0.01)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        outs = []
+        vgg_layers = getattr(self, self.module_name)
+        for i in range(len(self.stage_blocks)):
+            for j in range(*self.range_sub_modules[i]):
+                vgg_layer = vgg_layers[j]
+                x = vgg_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), -1)
+            x = self.classifier(x)
+            outs.append(x)
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def _freeze_stages(self):
+        vgg_layers = getattr(self, self.module_name)
+        for i in range(self.frozen_stages):
+            for j in range(*self.range_sub_modules[i]):
+                m = vgg_layers[j]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def train(self, mode=True):
+        super(VGG, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
diff --git a/tests/test_backbones/test_vgg.py b/tests/test_backbones/test_vgg.py
new file mode 100644
index 00000000000..fc184214e1f
--- /dev/null
+++ b/tests/test_backbones/test_vgg.py
@@ -0,0 +1,136 @@
+import pytest
+import torch
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmcls.models.backbones import VGG
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layers are in the correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+def test_vgg():
+    """Test VGG backbone."""
+    with pytest.raises(KeyError):
+        # VGG depth should be in [11, 13, 16, 19]
+        VGG(18)
+
+    with pytest.raises(AssertionError):
+        # In VGG: 1 <= num_stages <= 5
+        VGG(11, num_stages=0)
+
+    with pytest.raises(AssertionError):
+        # In VGG: 1 <= num_stages <= 5
+        VGG(11, num_stages=6)
+
+    with pytest.raises(AssertionError):
+        # len(dilations) == num_stages
+        VGG(11, dilations=(1, 1), num_stages=3)
+
+    with pytest.raises(TypeError):
+        # pretrained must be a string path
+        model = VGG(11)
+        model.init_weights(pretrained=0)
+
+    # Test VGG11 norm_eval=True
+    model = VGG(11, norm_eval=True)
+    model.init_weights()
+    model.train()
+    assert check_norm_state(model.modules(), False)
+
+    # Test VGG11 forward without classifiers
+    model = VGG(11, out_indices=(0, 1, 2, 3, 4))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 5
+    assert feat[0].shape == (1, 64, 112, 112)
+    assert feat[1].shape == (1, 128, 56, 56)
+    assert feat[2].shape == (1, 256, 28, 28)
+    assert feat[3].shape == (1, 512, 14, 14)
+    assert feat[4].shape == (1, 512, 7, 7)
+
+    # Test VGG11 forward with classifiers
+    model = VGG(11, num_classes=10, out_indices=(0, 1, 2, 3, 4, 5))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 6
+    assert feat[0].shape == (1, 64, 112, 112)
+    assert feat[1].shape == (1, 128, 56, 56)
+    assert feat[2].shape == (1, 256, 28, 28)
+    assert feat[3].shape == (1, 512, 14, 14)
+    assert feat[4].shape == (1, 512, 7, 7)
+    assert feat[5].shape == (1, 10)
+
+    # Test VGG11BN forward
+    model = VGG(11, norm_cfg=dict(type='BN'), out_indices=(0, 1, 2, 3, 4))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 5
+    assert feat[0].shape == (1, 64, 112, 112)
+    assert feat[1].shape == (1, 128, 56, 56)
+    assert feat[2].shape == (1, 256, 28, 28)
+    assert feat[3].shape == (1, 512, 14, 14)
+    assert feat[4].shape == (1, 512, 7, 7)
+
+    # Test VGG11BN forward with classifiers
+    model = VGG(
+        11,
+        num_classes=10,
+        norm_cfg=dict(type='BN'),
+        out_indices=(0, 1, 2, 3, 4, 5))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 6
+    assert feat[0].shape == (1, 64, 112, 112)
+    assert feat[1].shape == (1, 128, 56, 56)
+    assert feat[2].shape == (1, 256, 28, 28)
+    assert feat[3].shape == (1, 512, 14, 14)
+    assert feat[4].shape == (1, 512, 7, 7)
+    assert feat[5].shape == (1, 10)
+
+    # Test VGG13 with layers 1, 2, 3 out forward
+    model = VGG(13, out_indices=(0, 1, 2))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 3
+    assert feat[0].shape == (1, 64, 112, 112)
+    assert feat[1].shape == (1, 128, 56, 56)
+    assert feat[2].shape == (1, 256, 28, 28)
+
+    # Test VGG16 with top feature maps out forward
+    model = VGG(16)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat.shape == (1, 512, 7, 7)
+
+    # Test VGG19 with classification score out forward
+    model = VGG(19, num_classes=10)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat.shape == (1, 10)
diff --git a/tools/convert_models/vgg_to_mmcls.py b/tools/convert_models/vgg_to_mmcls.py
new file mode 100644
index 00000000000..db93059b684
--- /dev/null
+++ b/tools/convert_models/vgg_to_mmcls.py
@@ -0,0 +1,117 @@
+import argparse
+import os
+from collections import OrderedDict
+
+import torch
+
+
+def get_layer_maps(layer_num, with_bn):
+    layer_maps = {'conv': {}, 'bn': {}}
+    if with_bn:
+        if layer_num == 11:
+            layer_idxs = [0, 4, 8, 11, 15, 18, 22, 25]
+        elif layer_num == 13:
+            layer_idxs = [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]
+        elif layer_num == 16:
+            layer_idxs = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]
+        elif layer_num == 19:
+            layer_idxs = [
+                0, 3, 7, 10, 14, 17, 20, 23, 27, 30, 33, 36, 40, 43, 46, 49
+            ]
+        else:
+            raise ValueError(f'Invalid number of layers: {layer_num}')
+        for i, layer_idx in enumerate(layer_idxs):
+            if i == 0:
+                new_layer_idx = layer_idx
+            else:
+                new_layer_idx += int((layer_idx - layer_idxs[i - 1]) / 2)
+            layer_maps['conv'][layer_idx] = new_layer_idx
+            layer_maps['bn'][layer_idx + 1] = new_layer_idx
+    else:
+        if layer_num == 11:
+            layer_idxs = [0, 3, 6, 8, 11, 13, 16, 18]
+            new_layer_idxs = [0, 2, 4, 5, 7, 8, 10, 11]
+        elif layer_num == 13:
+            layer_idxs = [0, 2, 5, 7, 10, 12, 15, 17, 20, 22]
+            new_layer_idxs = [0, 1, 3, 4, 6, 7, 9, 10, 12, 13]
+        elif layer_num == 16:
+            layer_idxs = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
+            new_layer_idxs = [0, 1, 3, 4, 6, 7, 8, 10, 11, 12, 14, 15, 16]
+        elif layer_num == 19:
+            layer_idxs = [
+                0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34
+            ]
+            new_layer_idxs = [
+                0, 1, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19
+            ]
+        else:
+            raise ValueError(f'Invalid number of layers: {layer_num}')
+
+        layer_maps['conv'] = {
+            layer_idx: new_layer_idx
+            for layer_idx, new_layer_idx in zip(layer_idxs, new_layer_idxs)
+        }
+
+    return layer_maps
+
+
+def convert(src, dst, layer_num, with_bn=False):
+    """Convert keys in torchvision pretrained VGG models to mmcls
+    style."""
+
+    # load pytorch model
+    assert os.path.isfile(src), f'no checkpoint found at {src}'
+    blobs = torch.load(src, map_location='cpu')
+
+    # convert to mmcls style
+    state_dict = OrderedDict()
+
+    layer_maps = get_layer_maps(layer_num, with_bn)
+
+    prefix = 'backbone'
+    delimiter = '.'
+    for key, weight in blobs.items():
+        if 'features' in key:
+            module, layer_idx, weight_type = key.split(delimiter)
+            new_key = delimiter.join([prefix, key])
+            layer_idx = int(layer_idx)
+            for layer_key, maps in layer_maps.items():
+                if layer_idx in maps:
+                    new_layer_idx = maps[layer_idx]
+                    new_key = delimiter.join([
+                        prefix, 'features',
+                        str(new_layer_idx), layer_key, weight_type
+                    ])
+            state_dict[new_key] = weight
+            print(f'Convert {key} to {new_key}')
+        elif 'classifier' in key:
+            new_key = delimiter.join([prefix, key])
+            state_dict[new_key] = weight
+            print(f'Convert {key} to {new_key}')
+        else:
+            state_dict[key] = weight
+
+    # save checkpoint
+    checkpoint = dict()
+    checkpoint['state_dict'] = state_dict
+    torch.save(checkpoint, dst)
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Convert model keys')
+    parser.add_argument('src', help='src torchvision model path')
+    parser.add_argument('dst', help='save path')
+    parser.add_argument(
+        '--bn', action='store_true', help='whether original vgg has BN')
+    parser.add_argument(
+        '--layer_num',
+        type=int,
+        choices=[11, 13, 16, 19],
+        default=11,
+        help='number of VGG layers')
+    args = parser.parse_args()
+    convert(args.src, args.dst, layer_num=args.layer_num, with_bn=args.bn)
+
+
+if __name__ == '__main__':
+    main()
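
For reference, a minimal usage sketch of the new backbone, mirroring the unit tests in this patch (the checkpoint file names further below are placeholders, not released artifacts):

    import torch
    from mmcls.models.backbones import VGG

    # With the defaults (num_classes=-1, out_indices=None), forward() returns
    # the last feature map before the classifier as a single tensor.
    model = VGG(16)
    model.init_weights()
    model.eval()

    with torch.no_grad():
        feat = model(torch.randn(1, 3, 224, 224))
    print(feat.shape)  # torch.Size([1, 512, 7, 7])

To reuse torchvision weights with the configs above, convert the checkpoint first with the added script, e.g. for VGG-16 with BN:

    python tools/convert_models/vgg_to_mmcls.py vgg16_bn.pth vgg16_bn_mmcls.pth --layer_num 16 --bn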