[Feature] Support SOLOv2 (open-mmlab#7441)
* solov2 init

* solov2 r18 lightweight

* add model docstrings and reformat the code

* add docstrings to model method

* add solov2 big model config and correct some errors in the docstring

* fix linting issues

* refactor code and configs

* rename variables according to the convention

* add and enhance solov2 logic

* add doc strings

* update solov2 config files

* fix norm_cfg in mask head

* minor fix

* update configs

Co-authored-by: BIGWangYuDong <[email protected]>
zhanggefan and BIGWangYuDong authored May 25, 2022
1 parent 151a803 commit d18cdb1
Showing 16 changed files with 1,226 additions and 4 deletions.
59 changes: 59 additions & 0 deletions configs/solov2/README.md
@@ -0,0 +1,59 @@
# SOLOv2

> [SOLOv2: Dynamic and Fast Instance Segmentation](https://arxiv.org/abs/2003.10152)
<!-- [ALGORITHM] -->

## Abstract

In this work, we aim at building a simple, direct, and fast instance segmentation
framework with strong performance. We follow the principle of the SOLO method of
Wang et al. "SOLO: segmenting objects by locations". Importantly, we take one
step further by dynamically learning the mask head of the object segmenter such
that the mask head is conditioned on the location. Specifically, the mask branch
is decoupled into a mask kernel branch and mask feature branch, which are
responsible for learning the convolution kernel and the convolved features
respectively. Moreover, we propose Matrix NMS (non maximum suppression) to
significantly reduce the inference time overhead due to NMS of masks. Our
Matrix NMS performs NMS with parallel matrix operations in one shot, and
yields better results. We demonstrate a simple direct instance segmentation
system, outperforming a few state-of-the-art methods in both speed and accuracy.
A light-weight version of SOLOv2 executes at 31.3 FPS and yields 37.1% AP.
Moreover, our state-of-the-art results in object detection (from our mask byproduct)
and panoptic segmentation show the potential to serve as a new strong baseline
for many instance-level recognition tasks besides instance segmentation.
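
The decoupling of the mask branch into a kernel branch and a feature branch can be illustrated with a small NumPy sketch. It assumes 1x1 dynamic kernels and toy shapes; the variable names (`mask_feats`, `kernels`) are illustrative and are not taken from the mmdet code.

```python
import numpy as np

rng = np.random.default_rng(0)

E, H, W = 8, 16, 16   # feature channels and mask-feature resolution (toy values)
num_locations = 3     # grid cells predicted to contain an object (toy value)

# Mask feature branch: one feature map shared by all instances.
mask_feats = rng.standard_normal((E, H, W))

# Mask kernel branch: one 1x1 kernel (E weights) per positive grid location.
kernels = rng.standard_normal((num_locations, E))

# A 1x1 dynamic convolution is a per-pixel weighted sum over channels.
mask_logits = np.einsum('ne,ehw->nhw', kernels, mask_feats)

# Sigmoid followed by a threshold gives one binary mask per location.
mask_probs = 1.0 / (1.0 + np.exp(-mask_logits))
masks = mask_probs > 0.5

print(masks.shape)  # (3, 16, 16)
```

Because the kernels are predicted per location, each grid cell effectively gets its own mask head, while the expensive feature map is computed once.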

<div align=center>
<img src="https://user-images.githubusercontent.com/48282753/167235090-f20dab74-43a5-44ed-9f11-4e5f08866f45.png"/>
</div>
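
The Matrix NMS idea from the abstract can be sketched in a few lines of NumPy. This is a minimal sketch of the Gaussian variant only, assuming the decay form `exp(-sigma * iou**2)` with `sigma=2.0` as in the `test_cfg` of the configs in this commit; the function name and signature are illustrative, not the library API.

```python
import numpy as np


def matrix_nms(masks, scores, sigma=2.0):
    """Decay scores of overlapping masks in one shot (Gaussian kernel).

    masks: (N, H, W) boolean array; scores: (N,) float array.
    Returns decayed scores in the original input order.
    """
    order = np.argsort(-scores)                      # highest score first
    flat = masks[order].reshape(len(scores), -1).astype(np.float64)

    inter = flat @ flat.T                            # pairwise intersections
    areas = flat.sum(axis=1)
    union = areas[:, None] + areas[None, :] - inter
    # Keep only pairs (i, j) where i has the higher score than j.
    iou = np.triu(inter / np.maximum(union, 1e-12), k=1)

    # comp[i]: largest IoU of mask i with any higher-scoring mask,
    # i.e. how likely mask i is itself suppressed.
    comp = iou.max(axis=0)[:, None]

    decay = np.exp(-sigma * iou ** 2) / np.exp(-sigma * comp ** 2)
    coeff = decay.min(axis=0)                        # strongest decay per mask

    out = np.empty(len(scores), dtype=np.float64)
    out[order] = scores[order] * coeff
    return out


masks = np.zeros((3, 4, 4), dtype=bool)
masks[0, :2, :2] = True
masks[1, :2, :2] = True    # duplicate of mask 0
masks[2, 2:, 2:] = True    # disjoint from both
scores = np.array([0.9, 0.8, 0.7])
decayed = matrix_nms(masks, scores)
```

The duplicate mask has its score multiplied by `exp(-2)`, while the top-scoring mask and the disjoint mask keep their scores. No sequential loop over candidates is needed, which is what makes the method cheap.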

## Results and Models

### SOLOv2

| Backbone | Style | MS train | Lr schd | Mem (GB) | mask AP | Config | Download |
| :--------: | :-----: | :------: | :-----: | :------: | :-----: | :-----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| R-50 | pytorch | N | 1x | 5.1 | 34.8 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/solov2/solov2_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r50_fpn_1x_coco/solov2_r50_fpn_1x_coco_20220512_125858-a357fa23.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r50_fpn_1x_coco/solov2_r50_fpn_1x_coco_20220512_125858.log.json) |
| R-50 | pytorch | Y | 3x | 5.1 | 37.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/solov2/solov2_r50_fpn_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r50_fpn_3x_coco/solov2_r50_fpn_3x_coco_20220512_125856-fed092d4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r50_fpn_3x_coco/solov2_r50_fpn_3x_coco_20220512_125856.log.json) |
| R-101 | pytorch | Y | 3x | 6.9 | 39.1 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/solov2/solov2_r101_fpn_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r101_fpn_3x_coco/solov2_r101_fpn_3x_coco_20220511_095119-c559a076.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r101_fpn_3x_coco/solov2_r101_fpn_3x_coco_20220511_095119.log.json) |
| R-101(DCN) | pytorch | Y | 3x | 7.1 | 41.2 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/solov2/solov2_r101_dcn_fpn_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r101_dcn_fpn_3x_coco/solov2_r101_dcn_fpn_3x_coco_20220513_214734-16c966cb.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r101_dcn_fpn_3x_coco/solov2_r101_dcn_fpn_3x_coco_20220513_214734.log.json) |
| X-101(DCN) | pytorch | Y | 3x | 11.3 | 42.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/solov2/solov2_x101_dcn_fpn_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_x101_dcn_fpn_3x_coco/solov2_x101_dcn_fpn_3x_coco_20220513_214337-aef41095.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_x101_dcn_fpn_3x_coco/solov2_x101_dcn_fpn_3x_coco_20220513_214337.log.json) |

### Light SOLOv2

| Backbone | Style | MS train | Lr schd | Mem (GB) | mask AP | Config | Download |
| :------: | :-----: | :------: | :-----: | :------: | :-----: | :------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| R-18 | pytorch | Y | 3x | 9.1 | 29.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/solov2/solov2_light_r18_fpn_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r18_fpn_3x_coco/solov2_light_r18_fpn_3x_coco_20220511_083717-75fa355b.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r18_fpn_3x_coco/solov2_light_r18_fpn_3x_coco_20220511_083717.log.json) |
| R-34 | pytorch | Y | 3x | 9.3 | 31.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/solov2/solov2_light_r34_fpn_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r34_fpn_3x_coco/solov2_light_r34_fpn_3x_coco_20220511_091839-e51659d3.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r34_fpn_3x_coco/solov2_light_r34_fpn_3x_coco_20220511_091839.log.json) |
| R-50 | pytorch | Y | 3x | 9.9 | 33.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/solov2/solov2_light_r50_fpn_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r50_fpn_3x_coco/solov2_light_r50_fpn_3x_coco_20220512_165256-c93a6074.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r50_fpn_3x_coco/solov2_light_r50_fpn_3x_coco_20220512_165256.log.json) |

## Citation

```latex
@article{wang2020solov2,
  title={SOLOv2: Dynamic and Fast Instance Segmentation},
  author={Wang, Xinlong and Zhang, Rufeng and Kong, Tao and Li, Lei and Shen, Chunhua},
  journal={Proc. Advances in Neural Information Processing Systems (NeurIPS)},
  year={2020}
}
```
119 changes: 119 additions & 0 deletions configs/solov2/metafile.yml
@@ -0,0 +1,119 @@
Collections:
  - Name: SOLOv2
    Metadata:
      Training Data: COCO
      Training Techniques:
        - SGD with Momentum
        - Weight Decay
      Training Resources: 8x A100 GPUs
      Architecture:
        - FPN
        - Convolution
        - ResNet
    Paper: https://arxiv.org/abs/2003.10152
    README: configs/solov2/README.md

Models:
  - Name: solov2_r50_fpn_1x_coco
    In Collection: SOLOv2
    Config: configs/solov2/solov2_r50_fpn_1x_coco.py
    Metadata:
      Training Memory (GB): 5.1
      Epochs: 12
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 34.8
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r50_fpn_1x_coco/solov2_r50_fpn_1x_coco_20220512_125858-a357fa23.pth

  - Name: solov2_r50_fpn_3x_coco
    In Collection: SOLOv2
    Config: configs/solov2/solov2_r50_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 5.1
      Epochs: 36
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 37.5
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r50_fpn_3x_coco/solov2_r50_fpn_3x_coco_20220512_125856-fed092d4.pth

  - Name: solov2_r101_fpn_3x_coco
    In Collection: SOLOv2
    Config: configs/solov2/solov2_r101_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 6.9
      Epochs: 36
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 39.1
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r101_fpn_3x_coco/solov2_r101_fpn_3x_coco_20220511_095119-c559a076.pth

  - Name: solov2_r101_dcn_fpn_3x_coco
    In Collection: SOLOv2
    Config: configs/solov2/solov2_r101_dcn_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 7.1
      Epochs: 36
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 41.2
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_r101_dcn_fpn_3x_coco/solov2_r101_dcn_fpn_3x_coco_20220513_214734-16c966cb.pth

  - Name: solov2_x101_dcn_fpn_3x_coco
    In Collection: SOLOv2
    Config: configs/solov2/solov2_x101_dcn_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 11.3
      Epochs: 36
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 42.4
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_x101_dcn_fpn_3x_coco/solov2_x101_dcn_fpn_3x_coco_20220513_214337-aef41095.pth

  - Name: solov2_light_r18_fpn_3x_coco
    In Collection: SOLOv2
    Config: configs/solov2/solov2_light_r18_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 9.1
      Epochs: 36
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 29.7
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r18_fpn_3x_coco/solov2_light_r18_fpn_3x_coco_20220511_083717-75fa355b.pth

  - Name: solov2_light_r34_fpn_3x_coco
    In Collection: SOLOv2
    Config: configs/solov2/solov2_light_r34_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 9.3
      Epochs: 36
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 31.9
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r34_fpn_3x_coco/solov2_light_r34_fpn_3x_coco_20220511_091839-e51659d3.pth

  - Name: solov2_light_r50_fpn_3x_coco
    In Collection: SOLOv2
    Config: configs/solov2/solov2_light_r50_fpn_3x_coco.py
    Metadata:
      Training Memory (GB): 9.9
      Epochs: 36
    Results:
      - Task: Instance Segmentation
        Dataset: COCO
        Metrics:
          mask AP: 33.7
    Weights: https://download.openmmlab.com/mmdetection/v2.0/solov2/solov2_light_r50_fpn_3x_coco/solov2_light_r50_fpn_3x_coco_20220512_165256-c93a6074.pth
7 changes: 7 additions & 0 deletions configs/solov2/solov2_light_r18_fpn_3x_coco.py
@@ -0,0 +1,7 @@
_base_ = 'solov2_light_r50_fpn_3x_coco.py'

# model settings
model = dict(
    backbone=dict(
        depth=18, init_cfg=dict(checkpoint='torchvision://resnet18')),
    neck=dict(in_channels=[64, 128, 256, 512]))
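
The config above only lists the keys that differ from its `_base_` file; everything else is inherited. A rough sketch of that merge behaviour, assuming a plain recursive dict merge where nested dicts are merged key by key and lists are replaced wholesale (the helper `merge_cfg` is hypothetical, and the special `_delete_` key is deliberately not handled here):

```python
import copy


def merge_cfg(base, override):
    """Return a copy of `base` with `override` merged in recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)   # merge nested dicts
        else:
            merged[key] = copy.deepcopy(value)            # replace scalars/lists
    return merged


# Trimmed-down stand-in for the inherited R-50 model settings.
base_model = dict(
    backbone=dict(
        type='ResNet',
        depth=50,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256))

# The override from the R-18 config.
r18_override = dict(
    backbone=dict(
        depth=18, init_cfg=dict(checkpoint='torchvision://resnet18')),
    neck=dict(in_channels=[64, 128, 256, 512]))

model = merge_cfg(base_model, r18_override)
print(model['backbone']['depth'], model['neck']['in_channels'])
```

Note that `init_cfg` keeps `type='Pretrained'` from the base, which is why the override only needs to name the new checkpoint.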
7 changes: 7 additions & 0 deletions configs/solov2/solov2_light_r34_fpn_3x_coco.py
@@ -0,0 +1,7 @@
_base_ = 'solov2_light_r50_fpn_3x_coco.py'

# model settings
model = dict(
    backbone=dict(
        depth=34, init_cfg=dict(checkpoint='torchvision://resnet34')),
    neck=dict(in_channels=[64, 128, 256, 512]))
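
Both the R-18 and R-34 configs change the neck's `in_channels` because these backbones use basic residual blocks (channel expansion 1) rather than the bottleneck blocks (expansion 4) of R-50 and R-101. A small sketch of that arithmetic, with a hypothetical helper name:

```python
def resnet_stage_channels(depth, base_channels=64, num_stages=4):
    """Output channels of each ResNet stage, as consumed by the FPN neck."""
    # ResNet-18/34 use BasicBlock (expansion 1); deeper variants use Bottleneck (expansion 4).
    expansion = 1 if depth in (18, 34) else 4
    return [base_channels * 2 ** i * expansion for i in range(num_stages)]


print(resnet_stage_channels(18))  # [64, 128, 256, 512]
print(resnet_stage_channels(50))  # [256, 512, 1024, 2048]
```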
62 changes: 62 additions & 0 deletions configs/solov2/solov2_light_r50_dcn_fpn_3x_coco.py
@@ -0,0 +1,62 @@
_base_ = 'solov2_r50_fpn_3x_coco.py'

# model settings
model = dict(
    backbone=dict(
        dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    mask_head=dict(
        feat_channels=256,
        stacked_convs=3,
        scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
        mask_feature_head=dict(out_channels=128),
        dcn_cfg=dict(type='DCNv2'),
        dcn_apply_to_all_conv=False))  # light solov2 head

# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

# data
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(768, 512), (768, 480), (768, 448), (768, 416), (768, 384),
                   (768, 352)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(448, 768),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
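
The multi-scale `Resize` step in the training pipeline above picks one of the listed scales per image and resizes while keeping the aspect ratio; `Pad` then rounds each side up to a multiple of 32. The following sketch shows the arithmetic under the assumption that each `(768, s)` tuple bounds the long and short edge respectively and that sizes are rounded to the nearest integer. The helper names are hypothetical and this is not the library implementation.

```python
import random

TRAIN_SCALES = [(768, 512), (768, 480), (768, 448), (768, 416), (768, 384),
                (768, 352)]


def rescale_keep_ratio(width, height, scale):
    """Largest aspect-preserving size that fits inside the scale bounds."""
    long_edge, short_edge = max(scale), min(scale)
    factor = min(long_edge / max(width, height),
                 short_edge / min(width, height))
    return int(width * factor + 0.5), int(height * factor + 0.5)


def pad_to_divisor(size, divisor=32):
    """Round a side length up to the next multiple of `divisor`."""
    return -(-size // divisor) * divisor


def sample_train_size(width, height, rng=random):
    """multiscale_mode='value': choose one listed scale at random."""
    return rescale_keep_ratio(width, height, rng.choice(TRAIN_SCALES))


w, h = rescale_keep_ratio(640, 480, (768, 512))
print((w, h), (pad_to_divisor(w), pad_to_divisor(h)))  # (683, 512) (704, 512)
```

With these scales a typical COCO image is trained at a short side between 352 and 512 pixels, which is smaller than the default schedule and matches the lighter head.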
57 changes: 57 additions & 0 deletions configs/solov2/solov2_light_r50_fpn_3x_coco.py
@@ -0,0 +1,57 @@
_base_ = 'solov2_r50_fpn_1x_coco.py'

# model settings
model = dict(
    mask_head=dict(
        stacked_convs=2,
        feat_channels=256,
        scale_ranges=((1, 56), (28, 112), (56, 224), (112, 448), (224, 896)),
        mask_feature_head=dict(out_channels=128)))

# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

# data
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=[(768, 512), (768, 480), (768, 448), (768, 416), (768, 384),
                   (768, 352)],
        multiscale_mode='value',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(448, 768),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
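
The `lr_config` above combines a linear warmup with step decay at epochs 27 and 33. A minimal sketch of the resulting learning rate follows. It assumes the base rate `lr=0.01` from the SGD optimizer in `solov2_r50_fpn_1x_coco.py`, a decay factor of 0.1 per step (not set in the config, so taken here as the default), and 0-based epoch and iteration counters. The function `lr_at` is hypothetical.

```python
def lr_at(epoch, iteration, base_lr=0.01, steps=(27, 33), gamma=0.1,
          warmup_iters=500, warmup_ratio=1.0 / 3):
    """Learning rate for a 0-based epoch index and 0-based global iteration."""
    # Step decay: multiply by gamma once for every milestone already reached.
    regular_lr = base_lr * gamma ** sum(epoch >= s for s in steps)
    if iteration < warmup_iters:
        # Linear warmup from warmup_ratio * lr up to the regular lr.
        k = (1 - iteration / warmup_iters) * (1 - warmup_ratio)
        return regular_lr * (1 - k)
    return regular_lr


for epoch, iteration in [(0, 0), (0, 250), (0, 500), (27, 200000), (33, 250000)]:
    print(epoch, iteration, lr_at(epoch, iteration))
```

Training therefore starts at one third of the base rate, reaches 0.01 after 500 iterations, and drops by a factor of ten at each of the two milestones before stopping at epoch 36.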
13 changes: 13 additions & 0 deletions configs/solov2/solov2_r101_dcn_fpn_3x_coco.py
@@ -0,0 +1,13 @@
_base_ = 'solov2_r50_fpn_3x_coco.py'

# model settings
model = dict(
    backbone=dict(
        depth=101,
        init_cfg=dict(checkpoint='torchvision://resnet101'),
        dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, True, True, True)),
    mask_head=dict(
        mask_feature_head=dict(conv_cfg=dict(type='DCNv2')),
        dcn_cfg=dict(type='DCNv2'),
        dcn_apply_to_all_conv=True))
6 changes: 6 additions & 0 deletions configs/solov2/solov2_r101_fpn_3x_coco.py
@@ -0,0 +1,6 @@
_base_ = 'solov2_r50_fpn_3x_coco.py'

# model settings
model = dict(
    backbone=dict(
        depth=101, init_cfg=dict(checkpoint='torchvision://resnet101')))
61 changes: 61 additions & 0 deletions configs/solov2/solov2_r50_fpn_1x_coco.py
@@ -0,0 +1,61 @@
_base_ = [
    '../_base_/datasets/coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='SOLOv2',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    mask_head=dict(
        type='SOLOV2Head',
        num_classes=80,
        in_channels=256,
        feat_channels=512,
        stacked_convs=4,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        pos_scale=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cls_down_index=0,
        mask_feature_head=dict(
            feat_channels=128,
            start_level=0,
            end_level=3,
            out_channels=256,
            mask_stride=4,
            norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
        loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0)),
    # model training and testing settings
    test_cfg=dict(
        nms_pre=500,
        score_thr=0.1,
        mask_thr=0.5,
        filter_thr=0.05,
        kernel='gaussian',  # gaussian/linear
        sigma=2.0,
        max_per_img=100))

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
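
The `scale_ranges` and `num_grids` settings in the mask head above control which pyramid level, and which grid cell on that level, is responsible for an object. The sketch below is a simplified illustration, assuming an object's scale is the square root of its box area and that only the cell containing the object centre is marked (the real head also uses `pos_scale` to mark a small centre region). Both helper functions are hypothetical.

```python
import math

SCALE_RANGES = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
NUM_GRIDS = [40, 36, 24, 16, 12]


def assigned_levels(box_w, box_h):
    """Pyramid levels whose scale range contains sqrt(w * h)."""
    scale = math.sqrt(box_w * box_h)
    return [lvl for lvl, (lo, hi) in enumerate(SCALE_RANGES) if lo <= scale <= hi]


def center_cell(cx, cy, img_w, img_h, num_grid):
    """(row, col) of the grid cell that contains the object centre."""
    col = min(int(cx / img_w * num_grid), num_grid - 1)
    row = min(int(cy / img_h * num_grid), num_grid - 1)
    return row, col


print(assigned_levels(60, 60))    # [0, 1]
print(assigned_levels(400, 400))  # [3, 4]
for lvl in assigned_levels(60, 60):
    print(lvl, center_cell(400, 300, 800, 600, NUM_GRIDS[lvl]))
```

Because neighbouring ranges overlap, a medium-sized object can be a positive sample on two levels at once, each with its own grid resolution.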