-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathBestLossSave.py
49 lines (38 loc) · 1.46 KB
/
BestLossSave.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class BestLossSave(Callback):
    """
    Save the learner's weights whenever an epoch reaches a new best training loss.

    The best loss is kept across several calls of fit(), so the saved file holds
    the weights from the best epoch this callback instance has seen.
    Epochs are ranked by their *average* per-batch loss (the values passed to
    on_batch_end), so epochs with different numbers of batches -- e.g. separate
    fit() calls on different data -- are compared fairly.

    Pass to fit() with:
        cb = BestLossSave(learner, 'filename')
        ...
        fit(..., callbacks=[cb])

    Args:
        learner: object whose ``save(filename)`` method is called to persist
            the weights.
        filename: name passed to ``learner.save``; defaults to "best_loss".
        print_at_end: if True, print a one-line summary in ``on_train_end``.
    """

    def __init__(self, learner, filename=None, print_at_end=True):
        self.learner = learner
        self.filename = filename if filename is not None else "best_loss"
        self.print_at_end = print_at_end
        # Number of finished epochs seen so far, across all fit() calls.
        self.epoch = 0
        # 0-based index of the epoch whose weights were last saved; -1 = none yet.
        self.save_epoch = -1
        # Best average per-batch loss so far. inf (rather than a magic large
        # number) guarantees the first non-empty epoch is always saved.
        self.best_avg_loss = float("inf")
        # Summed loss and batch count of the saved epoch, kept for callers
        # that read these attributes directly.
        self.best_loss = float("inf")
        self.best_num_batch = 0
        # Running totals for the epoch currently in progress.
        self.loss_sum = 0
        self.num_batch = 0
        # Batch count of the most recently finished epoch. Initialized here so
        # get_stats()/print_stats() work even before any epoch has ended.
        self.max_num_batch = 0

    def on_batch_end(self, loss):
        """Add one batch's loss to the running total for the current epoch."""
        self.loss_sum += loss
        self.num_batch += 1

    def on_epoch_end(self, vals):
        """Save the weights if this epoch's average loss beats the best so far.

        ``vals`` is accepted for the callback interface but not used.
        """
        # An epoch with no batches has no loss to judge; without this guard
        # its loss_sum of 0 would look like a perfect epoch and trigger a save.
        if self.num_batch > 0:
            avg_loss = self.loss_sum / self.num_batch
            if avg_loss < self.best_avg_loss:
                self.best_avg_loss = avg_loss
                self.best_loss = self.loss_sum
                self.best_num_batch = self.num_batch
                self.learner.save(self.filename)
                self.save_epoch = self.epoch
        self.epoch += 1
        self.loss_sum = 0
        self.max_num_batch = self.num_batch
        self.num_batch = 0

    def get_stats(self):
        """Return ``(epochs_seen, save_epoch, best_avg_loss, filename)``.

        ``best_avg_loss`` is the average per-batch loss of the saved epoch
        (not of the latest epoch); it is ``inf`` if nothing was saved yet.
        """
        return (self.epoch, self.save_epoch, self.best_avg_loss, self.filename)

    def print_stats(self):
        """Print a one-line summary of the values returned by get_stats()."""
        l = self.best_avg_loss
        print(f"BestLossSave: Total epochs: {self.epoch}, last saved in epoch: {self.save_epoch}, loss: {l}, filename: {self.filename}")

    def on_train_end(self):
        """Print the summary at the end of training if print_at_end is set."""
        if self.print_at_end:
            self.print_stats()