-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainingplot.py
39 lines (32 loc) · 1.62 KB
/
trainingplot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# code to plot nn losses throughout training from
# https://github.com/kapil-varshney/utilities/blob/master/training_plot/training_plot_ex_with_cifar10.ipynb
class TrainingPlot(keras.callbacks.Callback):
    """Keras callback that live-plots training/validation loss in a notebook.

    At the end of every epoch (from the second one onward) it clears the
    current notebook output cell and redraws a loss-vs-epoch figure, giving
    a continuously updating training curve.
    """

    def on_train_begin(self, logs=None):
        """Reset the metric history when training starts.

        Args:
            logs: metrics dict supplied by Keras (unused here).
        """
        # NOTE: `logs=None` (not `logs={}`) — a mutable default argument is
        # shared across calls, and Keras passes None anyway.
        self.losses = []      # per-epoch training loss
        self.val_losses = []  # per-epoch validation loss
        self.logs = []        # raw logs dict for each epoch

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's losses and redraw the loss curves.

        Args:
            epoch: 0-based index of the epoch that just finished.
            logs: metrics dict from Keras; expected to contain 'loss' and,
                when validation data is used, 'val_loss'.
        """
        # Guard against Keras passing logs=None.
        logs = logs or {}
        self.logs.append(logs)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

        # A line plot needs at least 2 points, so wait for 2 epochs.
        if len(self.losses) > 1:
            # Replace the previously drawn figure in the notebook output.
            clear_output(wait=True)
            N = np.arange(0, len(self.losses))
            # You can choose the style of your preference:
            # print(plt.style.available) to see the available options.
            # "seaborn" was renamed to "seaborn-v0_8" in matplotlib 3.6;
            # try both so the callback works on old and new versions.
            try:
                plt.style.use("seaborn-v0_8")
            except OSError:
                try:
                    plt.style.use("seaborn")
                except OSError:
                    plt.style.use("default")
            # Plot train loss and val loss against epochs passed.
            plt.figure(figsize=(15, 10))
            plt.plot(N, self.losses, label="train_loss")
            plt.plot(N, self.val_losses, label="val_loss")
            plt.title("Training Loss [Epoch {}]".format(epoch), size=35)
            plt.xlabel("Epoch #", size=30)
            plt.ylabel("Loss", size=30)
            plt.legend(prop={'size': 20})
            plt.show()