classifier.py
import pytorch_lightning as pl
import torch.nn as nn
import torch


class Classifier(pl.LightningModule):
    """LightningModule wrapping an image-classification model.

    The model, optimizer and scheduler are constructed outside and passed in;
    the loss defaults to cross-entropy when none is given.
    """

    def __init__(self, model, scheduler, optimizer, loss=None, save_every_epoch=False):
        super().__init__()
        self.model = model
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.best_accuracy = 0.0
        self.save_every_epoch = save_every_epoch
        # Fall back to cross-entropy when no loss function is supplied.
        if loss is None:
            self.loss = nn.CrossEntropyLoss()
        else:
            self.loss = loss

    def forward(self, x):
        return self.model(x)

    def training_step(self, train_batch, batch_idx):
        img, label = train_batch
        output = self.forward(img)
        _, preds = torch.max(output, 1)
        loss = self.loss(output, label)
        # Fraction of correctly classified samples in this batch.
        corrects = torch.sum(preds == label.data) / img.shape[0]
        # CosineAnnealingLR is stepped manually once per batch; other
        # schedulers are left to Lightning via configure_optimizers().
        if isinstance(self.scheduler, torch.optim.lr_scheduler.CosineAnnealingLR):
            self.scheduler.step()
            self.log('learning rate', self.scheduler.get_last_lr()[0])
        return {'loss': loss, 'corrects': corrects}

    def training_epoch_end(self, outputs):
        # Aggregate per-batch losses and accuracies into epoch averages.
        losses = [x['loss'] for x in outputs]
        corrects = [x['corrects'] for x in outputs]
        avg_train_loss = sum(losses) / len(losses)
        train_epoch_accuracy = sum(corrects) / len(corrects)
        self.log('train accuracy', train_epoch_accuracy)
        self.log('train loss', avg_train_loss)

    def validation_step(self, val_batch, batch_idx):
        img, label = val_batch
        output = self.forward(img)
        _, preds = torch.max(output, 1)
        loss = self.loss(output, label)
        corrects = torch.sum(preds == label.data) / img.shape[0]
        return {'loss': loss, 'corrects': corrects}

    def validation_epoch_end(self, outputs):
        losses = [x['loss'] for x in outputs]
        corrects = [x['corrects'] for x in outputs]
        avg_val_loss = sum(losses) / len(losses)
        val_epoch_accuracy = sum(corrects) / len(corrects)
        # Keep a checkpoint of the best validation accuracy seen so far.
        if val_epoch_accuracy >= self.best_accuracy:
            self.best_accuracy = val_epoch_accuracy
            torch.save({
                'epoch': self.current_epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss,
            }, './best.ckpt')
        # Optionally also checkpoint at the end of every epoch.
        if self.save_every_epoch:
            torch.save({
                'epoch': self.current_epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss,
            }, './epoch_' + str(self.current_epoch) + '.ckpt')
        self.log('val loss', avg_val_loss)
        self.log('val accuracy', val_epoch_accuracy)

    def test_step(self, batch, batch_idx):
        # Testing reuses the validation logic.
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        # Return the externally constructed optimizer (and scheduler, if any).
        optimizer = self.optimizer
        scheduler = self.scheduler
        if scheduler:
            return [optimizer], [scheduler]
        return [optimizer]
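

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical
# example of training this Classifier with torchvision's ResNet-18, SGD and
# CosineAnnealingLR on FakeData. The dataset, backbone and hyperparameters
# below are illustrative assumptions, not part of the repository.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import datasets, models, transforms

    # Dummy 10-class image dataset so the example is self-contained.
    transform = transforms.ToTensor()
    train_ds = datasets.FakeData(size=256, image_size=(3, 224, 224),
                                 num_classes=10, transform=transform)
    val_ds = datasets.FakeData(size=64, image_size=(3, 224, 224),
                               num_classes=10, transform=transform)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32)

    # Backbone with a 10-class head, plus the optimizer/scheduler pair
    # expected by Classifier.__init__.
    backbone = models.resnet18(num_classes=10)
    optimizer = torch.optim.SGD(backbone.parameters(), lr=0.1, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader))

    classifier = Classifier(backbone, scheduler, optimizer, save_every_epoch=False)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(classifier, train_loader, val_loader)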