Skip to content

Commit

Permalink
Fix bug and optimize MNIST config (open-mmlab#98)
Browse files Browse the repository at this point in the history
* add simple_test to ClsHead

* optimize lenet training config

* recover path setting
  • Loading branch information
wangruohui authored Nov 26, 2020
1 parent c0e7512 commit 44bbc71
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
11 changes: 6 additions & 5 deletions configs/mnist/lenet5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,23 @@
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label']),
]
test_pipeline = [
dict(type='Resize', size=32),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img']),
]
data = dict(
samples_per_gpu=128,
workers_per_gpu=2,
train=dict(
type=dataset_type, data_prefix='data/mnist', pipeline=train_pipeline),
val=dict(
type=dataset_type, data_prefix='data/mnist', pipeline=test_pipeline),
test=dict(
type=dataset_type, data_prefix='data/mnist', pipeline=test_pipeline))
evaluation = dict(
interval=5, metric='accuracy', metric_options={'topk': (1, )})
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
Expand All @@ -40,14 +41,14 @@
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=100,
interval=150,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=20)
runner = dict(type='EpochBasedRunner', max_epochs=5)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/mnist/'
Expand Down
13 changes: 13 additions & 0 deletions mmcls/models/heads/cls_head.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import torch
import torch.nn.functional as F

from mmcls.models.losses import Accuracy
from ..builder import HEADS, build_loss
from .base_head import BaseHead
Expand Down Expand Up @@ -43,3 +46,13 @@ def loss(self, cls_score, gt_label):
def forward_train(self, cls_score, gt_label):
losses = self.loss(cls_score, gt_label)
return losses

def simple_test(self, cls_score):
"""Test without augmentation."""
if isinstance(cls_score, list):
cls_score = sum(cls_score) / float(len(cls_score))
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
if torch.onnx.is_in_onnx_export():
return pred
pred = list(pred.detach().cpu().numpy())
return pred

0 comments on commit 44bbc71

Please sign in to comment.