Add VGG and pretrained models (open-mmlab#27)
* add vgg

* add vgg model conversion tool

* fix out_indices and docstr

* add vgg models in configs

* add params, flops and accuracy in docs

* add pretrained models url

* use ConvModule and refine var names

* update vgg conversion tool

* modify bn config

* add docs for arch_setting

* add unit test for vgg

* rm debug code

* update vgg pretrained models
yl-1993 authored Sep 29, 2020
1 parent 99115fd commit bc1b08b
Showing 21 changed files with 583 additions and 3 deletions.
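The commit message mentions a VGG checkpoint conversion tool, which is among the files not expanded below. As a rough illustration of the general idea — copying weights from a torchvision VGG into this backbone by matching conv/BN/linear layers in order — a minimal sketch could look like the following; the helper name and the layer-order assumption are illustrative, not the tool added in this commit.

# Illustrative sketch only, NOT the conversion tool added in this commit.
# Assumption: the torchvision VGG and the mmcls VGG backbone (built with
# num_classes=1000) contain the same Conv2d/BatchNorm2d/Linear layers in
# the same order, so parameters can be copied module by module.
import torch.nn as nn


def transfer_vgg_weights(src: nn.Module, dst: nn.Module) -> None:
    kinds = (nn.Conv2d, nn.BatchNorm2d, nn.Linear)
    src_layers = [m for m in src.modules() if isinstance(m, kinds)]
    dst_layers = [m for m in dst.modules() if isinstance(m, kinds)]
    assert len(src_layers) == len(dst_layers), 'layer counts must match'
    for s, d in zip(src_layers, dst_layers):
        d.load_state_dict(s.state_dict())


# e.g. transfer_vgg_weights(torchvision.models.vgg16_bn(pretrained=True),
#                           VGG(depth=16, norm_cfg=dict(type='BN'), num_classes=1000))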
10 changes: 10 additions & 0 deletions configs/_base_/models/vgg11.py
@@ -0,0 +1,10 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VGG', depth=11, num_classes=1000),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
11 changes: 11 additions & 0 deletions configs/_base_/models/vgg11bn.py
@@ -0,0 +1,11 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VGG', depth=11, norm_cfg=dict(type='BN'), num_classes=1000),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
10 changes: 10 additions & 0 deletions configs/_base_/models/vgg13.py
@@ -0,0 +1,10 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VGG', depth=13, num_classes=1000),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
11 changes: 11 additions & 0 deletions configs/_base_/models/vgg13bn.py
@@ -0,0 +1,11 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VGG', depth=13, norm_cfg=dict(type='BN'), num_classes=1000),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
10 changes: 10 additions & 0 deletions configs/_base_/models/vgg16.py
@@ -0,0 +1,10 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VGG', depth=16, num_classes=1000),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
11 changes: 11 additions & 0 deletions configs/_base_/models/vgg16bn.py
@@ -0,0 +1,11 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VGG', depth=16, norm_cfg=dict(type='BN'), num_classes=1000),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
10 changes: 10 additions & 0 deletions configs/_base_/models/vgg19.py
@@ -0,0 +1,10 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='VGG', depth=19, num_classes=1000),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
11 changes: 11 additions & 0 deletions configs/_base_/models/vgg19bn.py
@@ -0,0 +1,11 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VGG', depth=19, norm_cfg=dict(type='BN'), num_classes=1000),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
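The eight base model configs above share the same ImageClassifier/ClsHead skeleton and differ only in backbone depth and whether norm_cfg enables BN. A minimal usage sketch, assuming mmcv's Config loader and mmcls's build_classifier factory (as used by the mmcls tools), would be:

# Usage sketch (not part of this commit); Config and build_classifier are
# assumed to be available from mmcv / mmcls as in the existing tools.
from mmcv import Config
from mmcls.models import build_classifier

cfg = Config.fromfile('configs/_base_/models/vgg16bn.py')
model = build_classifier(cfg.model)  # VGG backbone + ClsHead, neck=None
model.init_weights()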
5 changes: 5 additions & 0 deletions configs/imagenet/vgg11.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/vgg11.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/imagenet/vgg11bn.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/vgg11bn.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/imagenet/vgg13.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/vgg13.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/imagenet/vgg13bn.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/vgg13bn.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/imagenet/vgg16.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/vgg16.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/imagenet/vgg16bn.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/vgg16bn.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/imagenet/vgg19.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/vgg19.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/imagenet/vgg19bn.py
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/vgg19bn.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/default_runtime.py'
]
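Each top-level ImageNet config composes a base model, the bs32 PIL-resize data pipeline, and the default runtime via _base_. Loading one with mmcv resolves that inheritance into a single merged config; a short sketch (attribute names assumed from the mmcls defaults):

# Sketch (not part of this commit): inspect how the _base_ files are merged.
from mmcv import Config

cfg = Config.fromfile('configs/imagenet/vgg16bn.py')
print(cfg.model.backbone)  # comes from ../_base_/models/vgg16bn.py
print(cfg.data)            # comes from the imagenet_bs32_pil_resize base (assumed key)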
8 changes: 8 additions & 0 deletions docs/model_zoo.md
@@ -8,6 +8,14 @@ The ResNet family models below are trained by standard data augmentations, i.e.,

| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:--------:|
| VGG-11 | 132.86 | 7.63 | 69.03 | 88.63 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg11-01ecd97e.pth)* |
| VGG-13 | 133.05 | 11.34 | 69.93 | 89.26 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg13-9ad3945d.pth)*|
| VGG-16 | 138.36 | 15.5 | 71.59 | 90.39 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg16-91b6d117.pth)*|
| VGG-19 | 143.67 | 19.67 | 72.38 | 90.88 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg19-fee352a8.pth)*|
| VGG-11-BN | 132.87 | 7.64 | 70.37 | 89.81 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg11_bn-6fbbbf3f.pth)*|
| VGG-13-BN | 133.05 | 11.36 | 71.55 | 90.37 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg13_bn-4b5f9390.pth)*|
| VGG-16-BN | 138.37 | 15.53 | 73.36 | 91.5 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg16_bn-3ac6d8fd.pth)*|
| VGG-19-BN | 143.68 | 19.7 | 74.24 | 91.84 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/vgg19_bn-7c058385.pth)*|
| ResNet-18 | 11.69 | 1.82 | 70.07 | 89.44 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet18_batch256_20200708-34ab8f90.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet18_batch256_20200708-34ab8f90.log.json) |
| ResNet-34 | 21.8 | 3.68 | 73.85 | 91.53 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet34_batch256_20200708-32ffb4f7.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet34_batch256_20200708-32ffb4f7.log.json) |
| ResNet-50 | 25.56 | 4.12 | 76.55 | 93.15 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet50_batch256_20200708-cfb998bf.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnet50_batch256_20200708-cfb998bf.log.json) |
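The Params(M) column for the new VGG entries can be sanity-checked directly, since with num_classes=1000 the backbone carries the classifier and therefore essentially all of the parameters; a rough check:

# Rough sanity check of the Params(M) column (sketch, not from this commit).
from mmcls.models.backbones import VGG

model = VGG(depth=11, num_classes=1000)
params_m = sum(p.numel() for p in model.parameters()) / 1e6
print(f'{params_m:.2f} M parameters')  # expected to be close to 132.86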
7 changes: 4 additions & 3 deletions mmcls/models/backbones/__init__.py
@@ -11,9 +11,10 @@
from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .vgg import VGG

__all__ = [
    'LeNet5', 'AlexNet', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNeSt',
    'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2',
    'MobileNetV2', 'MobileNetv3'
    'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
    'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
    'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
]
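Importing VGG in this __init__ is what triggers its @BACKBONES.register_module() registration, so the backbone can be built from the plain dicts used in the configs above; a minimal sketch, assuming mmcls's build_backbone helper:

# Sketch (not from this commit); build_backbone is assumed to be exposed by
# mmcls.models, as it is for the other registered backbones.
from mmcls.models import build_backbone

backbone = build_backbone(dict(type='VGG', depth=16, norm_cfg=dict(type='BN')))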
194 changes: 194 additions & 0 deletions mmcls/models/backbones/vgg.py
@@ -0,0 +1,194 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


def make_vgg_layer(in_channels,
                   out_channels,
                   num_blocks,
                   conv_cfg=None,
                   norm_cfg=None,
                   act_cfg=dict(type='ReLU'),
                   dilation=1,
                   with_norm=False,
                   ceil_mode=False):
    layers = []
    for _ in range(num_blocks):
        layer = ConvModule(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        layers.append(layer)
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))

    return layers


@BACKBONES.register_module()
class VGG(BaseBackbone):
"""VGG backbone.
Args:
depth (int): Depth of vgg, from {11, 13, 16, 19}.
with_norm (bool): Use BatchNorm or not.
num_classes (int): number of classes for classification.
num_stages (int): VGG stages, normally 5.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages. If only one
stage is specified, a single tensor (feature map) is returned,
otherwise multiple stages are specified, a tuple of tensors will
be returned. When it is None, the default behavior depends on
whether num_classes is specified. If num_classes <= 0, the default
value is (4, ), outputing the last feature map before classifier.
If num_classes > 0, the default value is (5, ), outputing the
classification score. Default: None.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False.
with_last_pool (bool): Whether to keep the last pooling before
classifier. Default: True.
"""

    # Parameters to build layers. Each element specifies the number of conv
    # layers in each stage. For example, VGG11 contains 11 layers with
    # learnable parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3,
    # where 3 indicates the last three fully-connected layers.
    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=None,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 ceil_mode=False,
                 with_last_pool=True):
        super(VGG, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for vgg')
        assert num_stages >= 1 and num_stages <= 5
        stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        assert len(dilations) == num_stages

        self.num_classes = num_classes
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        with_norm = norm_cfg is not None

        if out_indices is None:
            out_indices = (5, ) if num_classes > 0 else (4, )
        assert max(out_indices) <= num_stages
        self.out_indices = out_indices

        self.in_channels = 3
        start_idx = 0
        vgg_layers = []
        self.range_sub_modules = []
        for i, num_blocks in enumerate(self.stage_blocks):
            num_modules = num_blocks + 1
            end_idx = start_idx + num_modules
            dilation = dilations[i]
            out_channels = 64 * 2**i if i < 4 else 512
            vgg_layer = make_vgg_layer(
                self.in_channels,
                out_channels,
                num_blocks,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                dilation=dilation,
                with_norm=with_norm,
                ceil_mode=ceil_mode)
            vgg_layers.extend(vgg_layer)
            self.in_channels = out_channels
            self.range_sub_modules.append([start_idx, end_idx])
            start_idx = end_idx
        if not with_last_pool:
            vgg_layers.pop(-1)
            self.range_sub_modules[-1][1] -= 1
        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*vgg_layers))

        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

    def init_weights(self, pretrained=None):
        super(VGG, self).init_weights(pretrained)
        if pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        outs = []
        vgg_layers = getattr(self, self.module_name)
        for i in range(len(self.stage_blocks)):
            for j in range(*self.range_sub_modules[i]):
                vgg_layer = vgg_layers[j]
                x = vgg_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if self.num_classes > 0:
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        vgg_layers = getattr(self, self.module_name)
        for i in range(self.frozen_stages):
            for j in range(*self.range_sub_modules[i]):
                m = vgg_layers[j]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(VGG, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() has an effect only on BatchNorm and variants
                if isinstance(m, _BatchNorm):
                    m.eval()
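The out_indices behaviour described in the docstring, and the layer-count arithmetic behind arch_settings (e.g. VGG-16: 2 + 2 + 3 + 3 + 3 = 13 conv layers plus 3 fully-connected layers), can be checked with a quick forward pass; a small sketch assuming 224x224 inputs:

# Quick check of the default out_indices behaviour (sketch, not part of the commit).
import torch
from mmcls.models.backbones import VGG

x = torch.randn(1, 3, 224, 224)

# num_classes <= 0: default out_indices is (4, ), the last conv feature map.
print(VGG(depth=16)(x).shape)                    # torch.Size([1, 512, 7, 7])

# num_classes > 0: default out_indices is (5, ), the classification scores.
print(VGG(depth=16, num_classes=1000)(x).shape)  # torch.Size([1, 1000])

# Multiple out_indices return a tuple of feature maps, one per requested stage.
feats = VGG(depth=16, out_indices=(3, 4))(x)
print([f.shape for f in feats])                  # 14x14 and 7x7 maps, 512 channels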
