Skip to content

Commit

Permalink
Add CIFAR10 configs and models (open-mmlab#38)
Browse files Browse the repository at this point in the history
* rename cifar configs

* add cifar10 configs

* fix linting

* add params, flops and accuracy in docs

* add cifar10 pretrained models and logs

* use oss-accelerate url

* del unused cifar10 configs
  • Loading branch information
yl-1993 authored Aug 26, 2020
1 parent b2fc837 commit 4b46fd6
Show file tree
Hide file tree
Showing 13 changed files with 86 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
dict(type='Collect', keys=['img', 'gt_label'])
]
data = dict(
samples_per_gpu=128,
samples_per_gpu=16,
workers_per_gpu=2,
train=dict(
type=dataset_type, data_prefix='data/cifar10',
Expand Down
16 changes: 16 additions & 0 deletions configs/_base_/models/resnet101_cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet_CIFAR',
depth=101,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))
16 changes: 16 additions & 0 deletions configs/_base_/models/resnet152_cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet_CIFAR',
depth=152,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=2048,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))
16 changes: 16 additions & 0 deletions configs/_base_/models/resnet34_cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='ResNet_CIFAR',
depth=34,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=10,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
))
File renamed without changes.
5 changes: 5 additions & 0 deletions configs/cifar10/resnet101_b16x8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/resnet101_cifar.py',
'../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
5 changes: 5 additions & 0 deletions configs/cifar10/resnet152_b16x8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/resnet152_cifar.py',
'../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
4 changes: 0 additions & 4 deletions configs/cifar10/resnet18.py

This file was deleted.

4 changes: 4 additions & 0 deletions configs/cifar10/resnet18_b16x8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = [
'../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
4 changes: 4 additions & 0 deletions configs/cifar10/resnet34_b16x8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = [
'../_base_/models/resnet34_cifar.py', '../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
4 changes: 0 additions & 4 deletions configs/cifar10/resnet50.py

This file was deleted.

4 changes: 4 additions & 0 deletions configs/cifar10/resnet50_b16x8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = [
'../_base_/models/resnet50_cifar.py', '../_base_/datasets/cifar10_bs16.py',
'../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]
19 changes: 15 additions & 4 deletions docs/model_zoo.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,21 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| ResNeXt-32x4d-101 | 44.18 | 8.03 | 78.7 | 94.34 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext101_32x4d_batch256_20200708-87f2d1c9.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext101_32x4d_batch256_20200708-87f2d1c9.log.json) |
| ResNeXt-32x8d-101 | 88.79 | 16.5 | 79.22 | 94.52 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext101_32x8d_batch256_20200708-1ec34aa7.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext101_32x8d_batch256_20200708-1ec34aa7.log.json) |
| ResNeXt-32x4d-152 | 59.95 | 11.8 | 79.06 | 94.47 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext152_32x4d_batch256_20200708-aab5034c.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/resnext152_32x4d_batch256_20200708-aab5034c.log.json) |
| SE-ResNet-50 | 28.09 | 4.13 | 77.74 | 93.84 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/se-resnet50_batch256_20200804-ae206104.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/se-resnet50_batch256_20200708-657b3c36.log.json) |
| SE-ResNet-101 | 49.33 | 7.86 | 78.26 | 94.07 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/se-resnet101_batch256_20200804-ba5b51d4.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/se-resnet101_batch256_20200708-038a4d04.log.json) |
| ShuffleNetV1 1.0x (group=3) | 1.87 | 0.146 | 68.13 | 87.81 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v1_batch1024_20200804-5d6cec73.pth) | [log](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v1_batch1024_20200804-5d6cec73.log.json) |
| ShuffleNetV2 1.0x | 2.28 | 0.149 | 69.55 | 88.92 | [model](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v2_batch1024_20200812-5bf4721e.pth) | [log](https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v2_batch1024_20200804-8860eec9.log.json) |
| SE-ResNet-50 | 28.09 | 4.13 | 77.74 | 93.84 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/se-resnet50_batch256_20200804-ae206104.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/se-resnet50_batch256_20200708-657b3c36.log.json) |
| SE-ResNet-101 | 49.33 | 7.86 | 78.26 | 94.07 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/se-resnet101_batch256_20200804-ba5b51d4.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/se-resnet101_batch256_20200708-038a4d04.log.json) |
| ShuffleNetV1 1.0x (group=3) | 1.87 | 0.146 | 68.13 | 87.81 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v1_batch1024_20200804-5d6cec73.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v1_batch1024_20200804-5d6cec73.log.json) |
| ShuffleNetV2 1.0x | 2.28 | 0.149 | 69.55 | 88.92 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v2_batch1024_20200812-5bf4721e.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/shufflenet_v2_batch1024_20200804-8860eec9.log.json) |
| MobileNet V2 | 3.5 | 0.319 | 71.86 | 90.42 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/mobilenet_v2_batch256_20200708-3b2dc3af.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/imagenet/mobilenet_v2_batch256_20200708-3b2dc3af.log.json) |

Models with * are converted from other repos, others are trained by ourselves.


## CIFAR10

| Model | Params(M) | Flops(G) | Top-1 (%) | Download |
|:---------------------:|:---------:|:--------:|:---------:|:--------:|
| ResNet-18-b16x8 | 11.17 | 0.56 | 94.72 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet18_b16x8_20200823-f906fa4e.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet18_b16x8_20200823-f906fa4e.log.json) |
| ResNet-34-b16x8 | 21.28 | 1.16 | 95.34 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet34_b16x8_20200823-52d5d832.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet34_b16x8_20200823-52d5d832.log.json) |
| ResNet-50-b16x8 | 23.52 | 1.31 | 95.36 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet50_b16x8_20200823-882aa7b1.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet50_b16x8_20200823-882aa7b1.log.json) |
| ResNet-101-b16x8 | 42.51 | 2.52 | 95.66 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet101_b16x8_20200823-d9501bbc.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet101_b16x8_20200823-d9501bbc.log.json) |
| ResNet-152-b16x8 | 58.16 | 3.74 | 95.96 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet152_b16x8_20200823-ad4d5d0c.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmclassification/v0/cifar10/resnet152_b16x8_20200823-ad4d5d0c.log.json) |

0 comments on commit 4b46fd6

Please sign in to comment.