-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathimagenet_runner.py
120 lines (89 loc) · 3.97 KB
/
imagenet_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Dict, Union, Tuple, Optional
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import models, datasets, transforms
from easytorch import Runner
from easytorch.device import to_device
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each requested value of k.

    Args:
        output: Score tensor; the top-k predictions are taken along
            dimension 1, so it is indexed as ``(sample, class)``.
        target: Tensor of ground-truth class indices, one entry per sample.
            ``target.size(0)`` is used as the batch size.
        topk: Values of k to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``.
        Each holds the percentage of samples whose target index is among the
        k highest-scoring classes. For an empty batch every entry is 0.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        if batch_size == 0:
            # Accuracy of an empty batch is undefined; report 0 instead of
            # failing on the division by ``batch_size`` further down.
            return [torch.zeros(1, device=output.device) for _ in topk]

        # ``Tensor.topk`` rejects k larger than the class dimension, so clamp
        # it. A k that exceeds the number of classes then simply considers
        # every class (``correct[:k]`` below selects all available rows).
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        # Transpose so that row i holds the (i + 1)-th best guess per sample.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # A sample counts as correct if any of its first k guesses match.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
class ImagenetRunner(Runner):
    """Runner for image classification on folder-structured datasets.

    The model is looked up by name in ``torchvision.models``, the datasets are
    ``torchvision.datasets.ImageFolder`` instances configured from ``cfg``,
    and the loss is cross-entropy. Loss, top-1 and top-5 accuracy are recorded
    in epoch meters for both training and validation.
    """

    def __init__(self, cfg: Dict):
        """Initialise the base runner and create the loss criterion."""
        super().__init__(cfg)

        self.criterion = to_device(nn.CrossEntropyLoss())

    def init_training(self, cfg: Dict):
        """Run the base training setup and register the training meters."""
        super().init_training(cfg)

        self.register_epoch_meter('train/loss', 'train', '{:.4e}')
        self.register_epoch_meter('train/acc@1', 'train', '{:6.2f}')
        self.register_epoch_meter('train/acc@5', 'train', '{:6.2f}')

    def init_validation(self, cfg: Dict):
        """Run the base validation setup and register the validation meters."""
        super().init_validation(cfg)

        self.register_epoch_meter('val/loss', 'val', '{:.4e}')
        self.register_epoch_meter('val/acc@1', 'val', '{:6.2f}')
        self.register_epoch_meter('val/acc@5', 'val', '{:6.2f}')

    @staticmethod
    def define_model(cfg: Dict) -> nn.Module:
        """Build the model named by ``cfg['MODEL']['NAME']``.

        The name is looked up in ``torchvision.models`` and the resulting
        constructor is called without arguments.
        """
        return models.__dict__[cfg['MODEL']['NAME']]()

    @staticmethod
    def build_train_dataset(cfg: Dict) -> Dataset:
        """Build the training ``ImageFolder`` from ``cfg['TRAIN']['DATA']``.

        Reads ``DIR``, ``CROP_SIZE`` and ``NORMALIZE`` (keyword arguments for
        ``transforms.Normalize``). Images are randomly resized-cropped,
        randomly flipped horizontally, converted to tensors and normalized.
        """
        normalize = transforms.Normalize(**cfg['TRAIN']['DATA']['NORMALIZE'])
        return datasets.ImageFolder(
            cfg['TRAIN']['DATA']['DIR'],
            transforms.Compose([
                transforms.RandomResizedCrop(cfg['TRAIN']['DATA']['CROP_SIZE']),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

    @staticmethod
    def build_val_dataset(cfg: Dict) -> Dataset:
        """Build the validation ``ImageFolder`` from ``cfg['VAL']['DATA']``.

        Reads ``DIR``, ``RESIZE``, ``CROP_SIZE`` and ``NORMALIZE`` (keyword
        arguments for ``transforms.Normalize``). Images are resized,
        center-cropped, converted to tensors and normalized.
        """
        normalize = transforms.Normalize(**cfg['VAL']['DATA']['NORMALIZE'])
        return datasets.ImageFolder(
            cfg['VAL']['DATA']['DIR'],
            transforms.Compose([
                transforms.Resize(cfg['VAL']['DATA']['RESIZE']),
                transforms.CenterCrop(cfg['VAL']['DATA']['CROP_SIZE']),
                transforms.ToTensor(),
                normalize,
            ]))

    def _classification_step(self, data: Union[torch.Tensor, Tuple], prefix: str) -> torch.Tensor:
        """Run one forward pass and record loss / accuracy under ``prefix``.

        Args:
            data: An ``(images, target)`` pair.
            prefix: Meter name prefix, ``'train'`` or ``'val'``.

        Returns:
            The loss tensor for this batch.
        """
        images, target = data
        images = to_device(images)
        target = to_device(target)

        output = self.model(images)
        loss = self.criterion(output, target)

        # measure accuracy and record loss; every meter receives a plain
        # Python float so the loss and accuracy meters are fed consistently
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        batch_size = images.size(0)
        self.update_epoch_meter(prefix + '/loss', loss.item(), batch_size)
        self.update_epoch_meter(prefix + '/acc@1', acc1[0].item(), batch_size)
        self.update_epoch_meter(prefix + '/acc@5', acc5[0].item(), batch_size)
        return loss

    def train_iters(self, epoch: int, iter_index: int, data: Union[torch.Tensor, Tuple]) -> torch.Tensor:
        """Process one training batch and return its loss.

        Args:
            epoch: Current epoch (unused here).
            iter_index: Index of the iteration (unused here).
            data: An ``(images, target)`` pair.

        Returns:
            The cross-entropy loss tensor for the batch.
        """
        return self._classification_step(data, 'train')

    def val_iters(self, iter_index: int, data: Union[torch.Tensor, Tuple]) -> None:
        """Process one validation batch, updating the ``val/*`` meters.

        Args:
            iter_index: Index of the iteration (unused here).
            data: An ``(images, target)`` pair.
        """
        self._classification_step(data, 'val')

    def on_validating_end(self, train_epoch: Optional[int]) -> None:
        """Save the best model by ``val/acc@1`` when validating during training.

        Args:
            train_epoch: The training epoch that triggered this validation,
                or ``None`` in validation mode (nothing is saved then).
        """
        # `None` means validation mode
        if train_epoch is not None:
            self.save_best_model(train_epoch, 'val/acc@1', greater_best=True)