From 44bbc71e144dad21940883a645bc48881d4523e7 Mon Sep 17 00:00:00 2001 From: WRH <12756472+wangruohui@users.noreply.github.com> Date: Thu, 26 Nov 2020 15:27:04 +0800 Subject: [PATCH] Fix bug and optimize MNIST config (#98) * add simple_test to ClsHead * optimize lenet training config * recover path setting --- configs/mnist/lenet5.py | 11 ++++++----- mmcls/models/heads/cls_head.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/configs/mnist/lenet5.py b/configs/mnist/lenet5.py index af4d4085e27..28ecb9bdc17 100644 --- a/configs/mnist/lenet5.py +++ b/configs/mnist/lenet5.py @@ -15,12 +15,13 @@ dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']), ] test_pipeline = [ dict(type='Resize', size=32), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), - dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img']), ] data = dict( samples_per_gpu=128, @@ -28,9 +29,9 @@ train=dict( type=dataset_type, data_prefix='data/mnist', pipeline=train_pipeline), val=dict( - type=dataset_type, data_prefix='data/mnist', pipeline=test_pipeline), - test=dict( type=dataset_type, data_prefix='data/mnist', pipeline=test_pipeline)) +evaluation = dict( + interval=5, metric='accuracy', metric_options={'topk': (1, )}) # optimizer optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) optimizer_config = dict(grad_clip=None) @@ -40,14 +41,14 @@ checkpoint_config = dict(interval=1) # yapf:disable log_config = dict( - interval=100, + interval=150, hooks=[ dict(type='TextLoggerHook'), # dict(type='TensorboardLoggerHook') ]) # yapf:enable # runtime settings -runner = dict(type='EpochBasedRunner', max_epochs=20) +runner = dict(type='EpochBasedRunner', max_epochs=5) dist_params = dict(backend='nccl') log_level = 'INFO' work_dir = './work_dirs/mnist/' diff --git a/mmcls/models/heads/cls_head.py b/mmcls/models/heads/cls_head.py index f4fd71b6e13..77d0ba26c22 100644 --- a/mmcls/models/heads/cls_head.py +++ b/mmcls/models/heads/cls_head.py @@ -1,3 +1,6 @@ +import torch +import torch.nn.functional as F + from mmcls.models.losses import Accuracy from ..builder import HEADS, build_loss from .base_head import BaseHead @@ -43,3 +46,13 @@ def loss(self, cls_score, gt_label): def forward_train(self, cls_score, gt_label): losses = self.loss(cls_score, gt_label) return losses + + def simple_test(self, cls_score): + """Test without augmentation.""" + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + if torch.onnx.is_in_onnx_export(): + return pred + pred = list(pred.detach().cpu().numpy()) + return pred