Skip to content

Commit

Permalink
Use build_runner (open-mmlab#54)
Browse files Browse the repository at this point in the history
* Use build_runner in train api

* Support iter in eval_hook

* Add runner section

* Add test_eval_hook

* Pin mmcv version in install docs

* Replace max_iters with max_epochs

* Set by_epoch=True as default

* Remove trailing space

* Replace DeprecationWarning with UserWarning

* pre-commit

* Fix tests
  • Loading branch information
daavoo authored Oct 15, 2020
1 parent f7a916f commit 4f4f295
Show file tree
Hide file tree
Showing 16 changed files with 262 additions and 24 deletions.
2 changes: 1 addition & 1 deletion configs/_base_/schedules/cifar10_bs128.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[100, 150])
total_epochs = 200
runner = dict(type='EpochBasedRunner', max_epochs=200)
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
warmup='constant',
warmup_iters=5000,
)
total_epochs = 300
runner = dict(type='EpochBasedRunner', max_epochs=300)
2 changes: 1 addition & 1 deletion configs/_base_/schedules/imagenet_bs2048.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
warmup_iters=2500,
warmup_ratio=0.25,
step=[30, 60, 90])
total_epochs = 100
runner = dict(type='EpochBasedRunner', max_epochs=100)
2 changes: 1 addition & 1 deletion configs/_base_/schedules/imagenet_bs2048_coslr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
warmup='linear',
warmup_iters=2500,
warmup_ratio=0.25)
total_epochs = 100
runner = dict(type='EpochBasedRunner', max_epochs=100)
2 changes: 1 addition & 1 deletion configs/_base_/schedules/imagenet_bs256.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[30, 60, 90])
total_epochs = 100
runner = dict(type='EpochBasedRunner', max_epochs=100)
2 changes: 1 addition & 1 deletion configs/_base_/schedules/imagenet_bs256_140e.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[40, 80, 120])
total_epochs = 140
runner = dict(type='EpochBasedRunner', max_epochs=140)
2 changes: 1 addition & 1 deletion configs/_base_/schedules/imagenet_bs256_coslr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0)
total_epochs = 100
runner = dict(type='EpochBasedRunner', max_epochs=100)
2 changes: 1 addition & 1 deletion configs/_base_/schedules/imagenet_bs256_epochstep.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', gamma=0.98, step=1)
total_epochs = 300
runner = dict(type='EpochBasedRunner', max_epochs=300)
2 changes: 1 addition & 1 deletion configs/mnist/lenet5.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
])
# yapf:enable
# runtime settings
total_epochs = 20
runner = dict(type='EpochBasedRunner', max_epochs=20)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/mnist/'
Expand Down
2 changes: 1 addition & 1 deletion docs/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

- Python 3.6+
- PyTorch 1.3+
- [mmcv](https://github.com/open-mmlab/mmcv)
- [mmcv](https://github.com/open-mmlab/mmcv) 1.1.4+


### Install mmclassification
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/finetune.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ optimizer_config = dict(grad_clip=None)
lr_config = dict(
policy='step',
step=[15])
total_epochs = 20
runner = dict(type='EpochBasedRunner', max_epochs=20)
log_config = dict(interval=100)
```

Expand Down
32 changes: 24 additions & 8 deletions mmcls/apis/train.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import random
import warnings

import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, EpochBasedRunner, build_optimizer
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner

from mmcls.core import (DistEvalHook, DistOptimizerHook, EvalHook,
Fp16OptimizerHook)
Expand Down Expand Up @@ -70,12 +71,26 @@ def train_model(model,

# build runner
optimizer = build_optimizer(model, cfg.optimizer)
runner = EpochBasedRunner(
model,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta)

if cfg.get('runner') is None:
cfg.runner = {
'type': 'EpochBasedRunner',
'max_epochs': cfg.total_epochs
}
warnings.warn(
'config is now expected to have a `runner` section, '
'please set `runner` in your config.', UserWarning)

runner = build_runner(
cfg.runner,
default_args=dict(
model=model,
batch_processor=None,
optimizer=optimizer,
work_dir=cfg.work_dir,
logger=logger,
meta=meta))

# an ugly workaround to make the .log and .log.json filenames the same
runner.timestamp = timestamp

Expand Down Expand Up @@ -107,11 +122,12 @@ def train_model(model,
shuffle=False,
round_up=False)
eval_cfg = cfg.get('evaluation', {})
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
runner.run(data_loaders, cfg.workflow)
31 changes: 28 additions & 3 deletions mmcls/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,30 @@ class EvalHook(Hook):
interval (int): Evaluation interval (by epochs). Default: 1.
"""

def __init__(self, dataloader, interval=1, **eval_kwargs):
def __init__(self, dataloader, interval=1, by_epoch=True, **eval_kwargs):
    """Initialize the single-GPU evaluation hook.

    Args:
        dataloader (DataLoader): Validation data loader to run
            evaluation on.
        interval (int): Evaluation interval — in epochs when
            ``by_epoch`` is True, otherwise in iterations. Default: 1.
        by_epoch (bool): If True, evaluate in ``after_train_epoch``;
            if False, evaluate in ``after_train_iter``. Default: True.
        **eval_kwargs: Extra keyword arguments forwarded to
            ``dataloader.dataset.evaluate`` by ``self.evaluate``.

    Raises:
        TypeError: If ``dataloader`` is not a ``torch.utils.data.DataLoader``.
    """
    # Fail fast on a wrong loader type instead of erroring mid-training.
    if not isinstance(dataloader, DataLoader):
        raise TypeError('dataloader must be a pytorch DataLoader, but got'
                        f' {type(dataloader)}')
    self.dataloader = dataloader
    self.interval = interval
    self.eval_kwargs = eval_kwargs
    self.by_epoch = by_epoch

def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
return
from mmcls.apis import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader, show=False)
self.evaluate(runner, results)

def after_train_iter(self, runner):
    """Evaluate after a training iteration when in iteration mode.

    Does nothing when ``by_epoch`` is True (the epoch hook handles
    evaluation then) or when the current iteration is not a multiple
    of ``self.interval``.
    """
    if self.by_epoch or not self.every_n_iters(runner, self.interval):
        return
    # Imported lazily to avoid a circular import between mmcls.apis
    # and this hook module — NOTE(review): presumed reason, confirm.
    from mmcls.apis import single_gpu_test
    # Clear pending training-loss entries so the evaluation results are
    # not mixed with them in the same log record.
    runner.log_buffer.clear()
    results = single_gpu_test(runner.model, self.dataloader, show=False)
    self.evaluate(runner, results)

def evaluate(self, runner, results):
eval_res = self.dataloader.dataset.evaluate(
results, logger=runner.logger, **self.eval_kwargs)
Expand All @@ -51,19 +60,35 @@ def __init__(self,
dataloader,
interval=1,
gpu_collect=False,
by_epoch=True,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError('dataloader must be a pytorch DataLoader, but got '
f'{type(dataloader)}')
self.dataloader = dataloader
self.interval = interval
self.gpu_collect = gpu_collect
self.by_epoch = by_epoch
self.eval_kwargs = eval_kwargs

def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
return
from mmcls.apis import multi_gpu_test
results = multi_gpu_test(
runner.model,
self.dataloader,
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
gpu_collect=self.gpu_collect)
if runner.rank == 0:
print('\n')
self.evaluate(runner, results)

def after_train_iter(self, runner):
if self.by_epoch or not self.every_n_iters(runner, self.interval):
return
from mmcls.apis import multi_gpu_test
runner.log_buffer.clear()
results = multi_gpu_test(
runner.model,
self.dataloader,
Expand Down
2 changes: 1 addition & 1 deletion requirements/readthedocs.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mmcv
mmcv>=1.1.4
torch
torchvision
2 changes: 1 addition & 1 deletion requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
matplotlib
mmcv
mmcv>=1.1.4
numpy
Loading

0 comments on commit 4f4f295

Please sign in to comment.