forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add VGG and pretrained models (open-mmlab#27)
* add vgg * add vgg model conversion tool * fix out_indices and docstr * add vgg models in configs * add params, flops and accuracy in docs * add pretrained models url * use ConvModule and refine var names * update vgg conversion tool * modify bn config * add docs for arch_setting * add unit test for vgg * rm debug code * update vgg pretrained models
- Loading branch information
Showing
21 changed files
with
583 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Model settings: VGG-11 classifier without normalization layers.
model = dict(
    type='ImageClassifier',
    # num_classes > 0 makes the VGG backbone append its own
    # fully-connected classifier, so it outputs class scores directly.
    backbone=dict(
        type='VGG',
        depth=11,
        num_classes=1000,
    ),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(
            type='CrossEntropyLoss',
            loss_weight=1.0,
        ),
        topk=(1, 5),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Model settings: VGG-11 classifier with BatchNorm in the backbone.
model = dict(
    type='ImageClassifier',
    # norm_cfg is forwarded to every conv block of the backbone;
    # num_classes > 0 makes the backbone append its own classifier.
    backbone=dict(
        type='VGG',
        depth=11,
        norm_cfg=dict(type='BN'),
        num_classes=1000,
    ),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(
            type='CrossEntropyLoss',
            loss_weight=1.0,
        ),
        topk=(1, 5),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Model settings: VGG-13 classifier without normalization layers.
model = dict(
    type='ImageClassifier',
    # num_classes > 0 makes the VGG backbone append its own
    # fully-connected classifier, so it outputs class scores directly.
    backbone=dict(
        type='VGG',
        depth=13,
        num_classes=1000,
    ),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(
            type='CrossEntropyLoss',
            loss_weight=1.0,
        ),
        topk=(1, 5),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Model settings: VGG-13 classifier with BatchNorm in the backbone.
model = dict(
    type='ImageClassifier',
    # norm_cfg is forwarded to every conv block of the backbone;
    # num_classes > 0 makes the backbone append its own classifier.
    backbone=dict(
        type='VGG',
        depth=13,
        norm_cfg=dict(type='BN'),
        num_classes=1000,
    ),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(
            type='CrossEntropyLoss',
            loss_weight=1.0,
        ),
        topk=(1, 5),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Model settings: VGG-16 classifier without normalization layers.
model = dict(
    type='ImageClassifier',
    # num_classes > 0 makes the VGG backbone append its own
    # fully-connected classifier, so it outputs class scores directly.
    backbone=dict(
        type='VGG',
        depth=16,
        num_classes=1000,
    ),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(
            type='CrossEntropyLoss',
            loss_weight=1.0,
        ),
        topk=(1, 5),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Model settings: VGG-16 classifier with BatchNorm in the backbone.
model = dict(
    type='ImageClassifier',
    # norm_cfg is forwarded to every conv block of the backbone;
    # num_classes > 0 makes the backbone append its own classifier.
    backbone=dict(
        type='VGG',
        depth=16,
        norm_cfg=dict(type='BN'),
        num_classes=1000,
    ),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(
            type='CrossEntropyLoss',
            loss_weight=1.0,
        ),
        topk=(1, 5),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Model settings: VGG-19 classifier without normalization layers.
model = dict(
    type='ImageClassifier',
    # num_classes > 0 makes the VGG backbone append its own
    # fully-connected classifier, so it outputs class scores directly.
    backbone=dict(
        type='VGG',
        depth=19,
        num_classes=1000,
    ),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(
            type='CrossEntropyLoss',
            loss_weight=1.0,
        ),
        topk=(1, 5),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Model settings: VGG-19 classifier with BatchNorm in the backbone.
model = dict(
    type='ImageClassifier',
    # norm_cfg is forwarded to every conv block of the backbone;
    # num_classes > 0 makes the backbone append its own classifier.
    backbone=dict(
        type='VGG',
        depth=19,
        norm_cfg=dict(type='BN'),
        num_classes=1000,
    ),
    neck=None,
    head=dict(
        type='ClsHead',
        loss=dict(
            type='CrossEntropyLoss',
            loss_weight=1.0,
        ),
        topk=(1, 5),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Base configs this config is composed from (paths are relative to this file).
_base_ = [
    '../_base_/models/vgg11.py',  # model definition
    '../_base_/datasets/imagenet_bs32_pil_resize.py',  # dataset settings
    '../_base_/default_runtime.py',  # runtime defaults
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Base configs this config is composed from (paths are relative to this file).
_base_ = [
    '../_base_/models/vgg11bn.py',  # model definition
    '../_base_/datasets/imagenet_bs32_pil_resize.py',  # dataset settings
    '../_base_/default_runtime.py',  # runtime defaults
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Base configs this config is composed from (paths are relative to this file).
_base_ = [
    '../_base_/models/vgg13.py',  # model definition
    '../_base_/datasets/imagenet_bs32_pil_resize.py',  # dataset settings
    '../_base_/default_runtime.py',  # runtime defaults
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Base configs this config is composed from (paths are relative to this file).
_base_ = [
    '../_base_/models/vgg13bn.py',  # model definition
    '../_base_/datasets/imagenet_bs32_pil_resize.py',  # dataset settings
    '../_base_/default_runtime.py',  # runtime defaults
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Base configs this config is composed from (paths are relative to this file).
_base_ = [
    '../_base_/models/vgg16.py',  # model definition
    '../_base_/datasets/imagenet_bs32_pil_resize.py',  # dataset settings
    '../_base_/default_runtime.py',  # runtime defaults
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Base configs this config is composed from (paths are relative to this file).
_base_ = [
    '../_base_/models/vgg16bn.py',  # model definition
    '../_base_/datasets/imagenet_bs32_pil_resize.py',  # dataset settings
    '../_base_/default_runtime.py',  # runtime defaults
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Base configs this config is composed from (paths are relative to this file).
_base_ = [
    '../_base_/models/vgg19.py',  # model definition
    '../_base_/datasets/imagenet_bs32_pil_resize.py',  # dataset settings
    '../_base_/default_runtime.py',  # runtime defaults
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Base configs this config is composed from (paths are relative to this file).
_base_ = [
    '../_base_/models/vgg19bn.py',  # model definition
    '../_base_/datasets/imagenet_bs32_pil_resize.py',  # dataset settings
    '../_base_/default_runtime.py',  # runtime defaults
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import torch.nn as nn | ||
from mmcv.cnn import ConvModule, constant_init, kaiming_init, normal_init | ||
from mmcv.utils.parrots_wrapper import _BatchNorm | ||
|
||
from ..builder import BACKBONES | ||
from .base_backbone import BaseBackbone | ||
|
||
|
||
def make_vgg_layer(in_channels,
                   out_channels,
                   num_blocks,
                   conv_cfg=None,
                   norm_cfg=None,
                   act_cfg=dict(type='ReLU'),
                   dilation=1,
                   with_norm=False,
                   ceil_mode=False):
    """Build the modules of one VGG stage.

    A stage is ``num_blocks`` 3x3 ``ConvModule`` blocks followed by a single
    2x2, stride-2 max pooling layer.

    Args:
        in_channels (int): Input channels of the first conv block.
        out_channels (int): Output channels of every conv block; also the
            input channels of every block after the first.
        num_blocks (int): Number of conv blocks in the stage.
        conv_cfg (dict | None): Forwarded to ``ConvModule`` as ``conv_cfg``.
        norm_cfg (dict | None): Forwarded to ``ConvModule`` as ``norm_cfg``.
        act_cfg (dict | None): Forwarded to ``ConvModule`` as ``act_cfg``.
        dilation (int): Dilation of the convs. It is also used as the
            padding so that a 3x3 conv keeps the spatial size.
        with_norm (bool): Accepted for interface compatibility; this function
            does not read it (normalization is controlled by ``norm_cfg``).
        ceil_mode (bool): ``ceil_mode`` of the trailing ``nn.MaxPool2d``.

    Returns:
        list[nn.Module]: The conv blocks in order, then the pooling layer.
    """
    # Arguments shared by every conv block in the stage; only the input
    # width differs between the first block and the following ones.
    shared_conv_kwargs = dict(
        out_channels=out_channels,
        kernel_size=3,
        dilation=dilation,
        padding=dilation,
        bias=True,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        act_cfg=act_cfg)

    stage_modules = []
    block_in_channels = in_channels
    for _ in range(num_blocks):
        stage_modules.append(
            ConvModule(in_channels=block_in_channels, **shared_conv_kwargs))
        block_in_channels = out_channels

    pool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)
    stage_modules.append(pool)
    return stage_modules
|
||
|
||
@BACKBONES.register_module()
class VGG(BaseBackbone):
    """VGG backbone.

    Args:
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        num_classes (int): Number of classes for classification. The
            fully-connected classifier is only built when this is positive.
            Default: -1.
        num_stages (int): VGG stages, normally 5. Must be in [1, 5].
            Default: 5.
        dilations (Sequence[int]): Dilation of each stage. Its length must
            equal ``num_stages``. Default: (1, 1, 1, 1, 1).
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. When it is None, the default behavior depends on
            whether num_classes is specified. If num_classes <= 0, the default
            value is (4, ), outputting the last feature map before classifier.
            If num_classes > 0, the default value is (5, ), outputting the
            classification score. Default: None.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Default: -1.
        conv_cfg (dict | None): Forwarded to every ``ConvModule`` as
            ``conv_cfg``. Default: None.
        norm_cfg (dict | None): Forwarded to every ``ConvModule`` as
            ``norm_cfg``. Default: None.
        act_cfg (dict | None): Forwarded to every ``ConvModule`` as
            ``act_cfg``. Default: dict(type='ReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False.
        with_last_pool (bool): Whether to keep the last pooling before
            classifier. Default: True.
    """

    # Parameters to build layers. Each element specifies the number of conv in
    # each stage. For example, VGG11 contains 11 layers with learnable
    # parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3,
    # where 3 indicates the last three fully-connected layers.
    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=None,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 ceil_mode=False,
                 with_last_pool=True):
        super(VGG, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for vgg')
        assert num_stages >= 1 and num_stages <= 5
        stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        assert len(dilations) == num_stages

        self.num_classes = num_classes
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        with_norm = norm_cfg is not None

        if out_indices is None:
            # Index 5 never matches a stage in forward(); with a classifier
            # the only output is then the classification score.
            out_indices = (5, ) if num_classes > 0 else (4, )
        assert max(out_indices) <= num_stages
        self.out_indices = out_indices

        self.in_channels = 3
        start_idx = 0
        vgg_layers = []
        # range_sub_modules[i] is the [start, end) index range of stage i
        # inside the flat ``features`` Sequential.
        self.range_sub_modules = []
        for i, num_blocks in enumerate(self.stage_blocks):
            # Each stage contributes its conv blocks plus one pooling layer.
            num_modules = num_blocks + 1
            end_idx = start_idx + num_modules
            dilation = dilations[i]
            # Stage widths: 64, 128, 256, 512, 512.
            out_channels = 64 * 2**i if i < 4 else 512
            vgg_layer = make_vgg_layer(
                self.in_channels,
                out_channels,
                num_blocks,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                dilation=dilation,
                with_norm=with_norm,
                ceil_mode=ceil_mode)
            vgg_layers.extend(vgg_layer)
            self.in_channels = out_channels
            self.range_sub_modules.append([start_idx, end_idx])
            start_idx = end_idx
        if not with_last_pool:
            # Drop the trailing pooling layer and shrink the last stage's
            # range accordingly.
            vgg_layers.pop(-1)
            self.range_sub_modules[-1][1] -= 1
        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*vgg_layers))

        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

    def init_weights(self, pretrained=None):
        """Initialize the weights of the backbone.

        Args:
            pretrained (str | None): Forwarded to the parent class'
                ``init_weights``. When it is None, conv layers get Kaiming
                init, norm layers are set to constant 1 and linear layers
                get a normal init with std 0.01. Default: None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        super(VGG, self).init_weights(pretrained)
        if pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        elif not isinstance(pretrained, str):
            # A str is a valid value handled by the parent class; only
            # reject the types the error message actually rules out.
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stages (and the classifier, if built) on ``x``.

        Returns:
            A single output when exactly one output was collected, otherwise
            a tuple of all collected outputs: the feature maps of the stages
            listed in ``out_indices`` followed by the classifier output when
            ``num_classes > 0``.
        """
        outs = []
        vgg_layers = getattr(self, self.module_name)
        for i, (start, end) in enumerate(self.range_sub_modules):
            for j in range(start, end):
                x = vgg_layers[j](x)
            if i in self.out_indices:
                outs.append(x)
        if self.num_classes > 0:
            # Flatten everything but the batch dimension for the classifier.
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            outs.append(x)
        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages in eval mode and stop their
        parameters from receiving gradients."""
        vgg_layers = getattr(self, self.module_name)
        for i in range(self.frozen_stages):
            for j in range(*self.range_sub_modules[i]):
                m = vgg_layers[j]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode while keeping frozen stages frozen.

        Args:
            mode (bool): Forwarded to the parent ``train``. Default: True.
        """
        super(VGG, self).train(mode)
        # The parent call above resets every sub-module's mode, so the
        # frozen stages must be put back into eval mode afterwards.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
Oops, something went wrong.