forked from open-mmlab/mmdetection
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add Faster rcnn c5 (open-mmlab#4043)
* add faster rcnn C5 * add trained faster rcnn C5 model url * fix typos * update based on 1-st comments * update unittest for ChannelMapper * small changes of ChannelMapper's docstring * small changes of ChannelMapper's docstring
- Loading branch information
Showing
8 changed files
with
308 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# model settings | ||
norm_cfg = dict(type='BN', requires_grad=False) | ||
model = dict( | ||
type='FasterRCNN', | ||
pretrained='open-mmlab://detectron2/resnet50_caffe', | ||
backbone=dict( | ||
type='ResNet', | ||
depth=50, | ||
num_stages=4, | ||
strides=(1, 2, 2, 1), | ||
dilations=(1, 1, 1, 2), | ||
out_indices=(3, ), | ||
frozen_stages=1, | ||
norm_cfg=norm_cfg, | ||
norm_eval=True, | ||
style='caffe'), | ||
rpn_head=dict( | ||
type='RPNHead', | ||
in_channels=2048, | ||
feat_channels=2048, | ||
anchor_generator=dict( | ||
type='AnchorGenerator', | ||
scales=[2, 4, 8, 16, 32], | ||
ratios=[0.5, 1.0, 2.0], | ||
strides=[16]), | ||
bbox_coder=dict( | ||
type='DeltaXYWHBBoxCoder', | ||
target_means=[.0, .0, .0, .0], | ||
target_stds=[1.0, 1.0, 1.0, 1.0]), | ||
loss_cls=dict( | ||
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), | ||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)), | ||
roi_head=dict( | ||
type='StandardRoIHead', | ||
bbox_roi_extractor=dict( | ||
type='SingleRoIExtractor', | ||
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), | ||
out_channels=2048, | ||
featmap_strides=[16]), | ||
bbox_head=dict( | ||
type='Shared2FCBBoxHead', | ||
in_channels=2048, | ||
fc_out_channels=1024, | ||
roi_feat_size=7, | ||
num_classes=80, | ||
bbox_coder=dict( | ||
type='DeltaXYWHBBoxCoder', | ||
target_means=[0., 0., 0., 0.], | ||
target_stds=[0.1, 0.1, 0.2, 0.2]), | ||
reg_class_agnostic=False, | ||
loss_cls=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), | ||
loss_bbox=dict(type='L1Loss', loss_weight=1.0)))) | ||
# model training and testing settings | ||
train_cfg = dict( | ||
rpn=dict( | ||
assigner=dict( | ||
type='MaxIoUAssigner', | ||
pos_iou_thr=0.7, | ||
neg_iou_thr=0.3, | ||
min_pos_iou=0.3, | ||
match_low_quality=True, | ||
ignore_iof_thr=-1), | ||
sampler=dict( | ||
type='RandomSampler', | ||
num=256, | ||
pos_fraction=0.5, | ||
neg_pos_ub=-1, | ||
add_gt_as_proposals=False), | ||
allowed_border=0, | ||
pos_weight=-1, | ||
debug=False), | ||
rpn_proposal=dict( | ||
nms_across_levels=False, | ||
nms_pre=12000, | ||
nms_post=2000, | ||
max_num=2000, | ||
nms_thr=0.7, | ||
min_bbox_size=0), | ||
rcnn=dict( | ||
assigner=dict( | ||
type='MaxIoUAssigner', | ||
pos_iou_thr=0.5, | ||
neg_iou_thr=0.5, | ||
min_pos_iou=0.5, | ||
match_low_quality=False, | ||
ignore_iof_thr=-1), | ||
sampler=dict( | ||
type='RandomSampler', | ||
num=512, | ||
pos_fraction=0.25, | ||
neg_pos_ub=-1, | ||
add_gt_as_proposals=True), | ||
pos_weight=-1, | ||
debug=False)) | ||
test_cfg = dict( | ||
rpn=dict( | ||
nms_across_levels=False, | ||
nms_pre=6000, | ||
nms_post=1000, | ||
max_num=1000, | ||
nms_thr=0.7, | ||
min_bbox_size=0), | ||
rcnn=dict( | ||
score_thr=0.05, | ||
nms=dict(type='nms', iou_threshold=0.5), | ||
max_per_img=100)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
_base_ = [ | ||
'../_base_/models/faster_rcnn_r50_caffe_dc5.py', | ||
'../_base_/datasets/coco_detection.py', | ||
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' | ||
] | ||
# use caffe img_norm | ||
img_norm_cfg = dict( | ||
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', with_bbox=True), | ||
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), | ||
dict(type='RandomFlip', flip_ratio=0.5), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='MultiScaleFlipAug', | ||
img_scale=(1333, 800), | ||
flip=False, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
]) | ||
] | ||
data = dict( | ||
train=dict(pipeline=train_pipeline), | ||
val=dict(pipeline=test_pipeline), | ||
test=dict(pipeline=test_pipeline)) |
42 changes: 42 additions & 0 deletions
42
configs/faster_rcnn/faster_rcnn_r50_caffe_dc5_mstrain_1x_coco.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
_base_ = [ | ||
'../_base_/models/faster_rcnn_r50_caffe_dc5.py', | ||
'../_base_/datasets/coco_detection.py', | ||
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' | ||
] | ||
# use caffe img_norm | ||
img_norm_cfg = dict( | ||
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', with_bbox=True), | ||
dict( | ||
type='Resize', | ||
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), | ||
(1333, 768), (1333, 800)], | ||
multiscale_mode='value', | ||
keep_ratio=True), | ||
dict(type='RandomFlip', flip_ratio=0.5), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='MultiScaleFlipAug', | ||
img_scale=(1333, 800), | ||
flip=False, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
]) | ||
] | ||
data = dict( | ||
train=dict(pipeline=train_pipeline), | ||
val=dict(pipeline=test_pipeline), | ||
test=dict(pipeline=test_pipeline)) |
4 changes: 4 additions & 0 deletions
4
configs/faster_rcnn/faster_rcnn_r50_caffe_dc5_mstrain_3x_coco.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
_base_ = './faster_rcnn_r50_caffe_dc5_mstrain_1x_coco.py' | ||
# learning policy | ||
lr_config = dict(step=[28, 34]) | ||
total_epochs = 36 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import torch.nn as nn | ||
from mmcv.cnn import ConvModule, xavier_init | ||
|
||
from ..builder import NECKS | ||
|
||
|
||
@NECKS.register_module() | ||
class ChannelMapper(nn.Module): | ||
r"""Channel Mapper to reduce/increase channels of backbone features. | ||
This is used to reduce/increase channels of backbone features. | ||
Args: | ||
in_channels (List[int]): Number of input channels per scale. | ||
out_channels (int): Number of output channels (used at each scale). | ||
kernel_size (int, optional): kernel_size for reducing channels (used | ||
at each scale). Default: 3. | ||
conv_cfg (dict, optional): Config dict for convolution layer. | ||
Default: None. | ||
norm_cfg (dict, optional): Config dict for normalization layer. | ||
Default: None. | ||
act_cfg (dict, optional): Config dict for activation layer in | ||
ConvModule. Default: dict(type='ReLU'). | ||
Example: | ||
>>> import torch | ||
>>> in_channels = [2, 3, 5, 7] | ||
>>> scales = [340, 170, 84, 43] | ||
>>> inputs = [torch.rand(1, c, s, s) | ||
... for c, s in zip(in_channels, scales)] | ||
>>> self = ChannelMapper(in_channels, 11, 3).eval() | ||
>>> outputs = self.forward(inputs) | ||
>>> for i in range(len(outputs)): | ||
... print(f'outputs[{i}].shape = {outputs[i].shape}') | ||
outputs[0].shape = torch.Size([1, 11, 340, 340]) | ||
outputs[1].shape = torch.Size([1, 11, 170, 170]) | ||
outputs[2].shape = torch.Size([1, 11, 84, 84]) | ||
outputs[3].shape = torch.Size([1, 11, 43, 43]) | ||
""" | ||
|
||
def __init__(self, | ||
in_channels, | ||
out_channels, | ||
kernel_size=3, | ||
conv_cfg=None, | ||
norm_cfg=None, | ||
act_cfg=dict(type='ReLU')): | ||
super(ChannelMapper, self).__init__() | ||
assert isinstance(in_channels, list) | ||
|
||
self.convs = nn.ModuleList() | ||
for in_channel in in_channels: | ||
self.convs.append( | ||
ConvModule( | ||
in_channel, | ||
out_channels, | ||
kernel_size, | ||
padding=(kernel_size - 1) // 2, | ||
conv_cfg=conv_cfg, | ||
norm_cfg=norm_cfg, | ||
act_cfg=act_cfg)) | ||
|
||
# default init_weights for conv(msra) and norm in ConvModule | ||
def init_weights(self): | ||
"""Initialize the weights of ChannelMapper module.""" | ||
for m in self.modules(): | ||
if isinstance(m, nn.Conv2d): | ||
xavier_init(m, distribution='uniform') | ||
|
||
def forward(self, inputs): | ||
"""Forward function.""" | ||
assert len(inputs) == len(self.convs) | ||
outs = [self.convs[i](inputs[i]) for i in range(len(inputs))] | ||
return tuple(outs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters