# train.py
# Forked from usef-kh/fer.
import sys
import warnings
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from data.fer2013 import get_dataloaders
from utils.checkpoint import save
from utils.hparams import setup_hparams
from utils.loops import train, evaluate
from utils.setup_network import setup_network
# Suppress every Python warning for the rest of the process.
warnings.filterwarnings("ignore")

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def run(net, logger, hps):
    """Train ``net``, checkpoint along the way, then print test-set metrics.

    Args:
        net: Model to train. It is moved to the module-level ``device``.
        logger: Metrics recorder. One value per epoch is appended to each of
            its ``loss_train``, ``acc_train``, ``loss_val`` and ``acc_val``
            attributes, and ``save_plt(hps)`` is called on every checkpoint.
        hps: Hyperparameter mapping. Keys read here: ``"bs"`` (forwarded to
            ``get_dataloaders``), ``"lr"`` (converted to ``float``),
            ``"name"``, ``"start_epoch"``, ``"n_epochs"`` and ``"save_freq"``.

    Returns:
        None. Progress and final test metrics are printed to stdout.
    """
    # Create dataloaders
    trainloader, valloader, testloader = get_dataloaders(bs=hps["bs"])

    net = net.to(device)
    learning_rate = float(hps["lr"])

    # NOTE(review): recent PyTorch releases deprecate torch.cuda.amp.GradScaler
    # in favour of torch.amp.GradScaler("cuda") - confirm against the
    # installed version before changing.
    scaler = GradScaler()

    # optimizer = torch.optim.Adadelta(net.parameters(), lr=learning_rate, weight_decay=0.0001)
    # optimizer = torch.optim.Adagrad(net.parameters(), lr=learning_rate, weight_decay=0.0001)
    # optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.0001, amsgrad=True)
    # optimizer = torch.optim.ASGD(net.parameters(), lr=learning_rate, weight_decay=0.0001)
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=0.9,
        nesterov=True,
        weight_decay=0.0001,
    )

    # mode="max": the scheduler is stepped with validation accuracy below, so
    # a plateau means the metric has stopped increasing.
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases -
    # confirm against the installed version.
    scheduler = ReduceLROnPlateau(
        optimizer, mode="max", factor=0.75, patience=5, verbose=True
    )
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.5, last_epoch=-1, verbose=True)
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(trainloader), epochs=hps['n_epochs'])
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1, verbose=True)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6, last_epoch=-1, verbose=False)

    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0

    print("Training", hps["name"], "on", device)
    for epoch in range(hps["start_epoch"], hps["n_epochs"]):
        epoch_num = epoch + 1  # 1-based epoch used for checkpoints and logs

        acc_tr, loss_tr = train(net, trainloader, criterion, optimizer, scaler)
        logger.loss_train.append(loss_tr)
        logger.acc_train.append(acc_tr)

        acc_v, loss_v = evaluate(net, valloader, criterion)
        logger.loss_val.append(loss_v)
        logger.acc_val.append(acc_v)

        # Update learning rate
        scheduler.step(acc_v)

        is_new_best = acc_v > best_acc
        if is_new_best:
            best_acc = acc_v
        is_periodic = epoch_num % hps["save_freq"] == 0

        # Checkpoint once per epoch even when both triggers fire; the two
        # triggers would otherwise write the same checkpoint and plot twice.
        if is_new_best or is_periodic:
            save(net, logger, hps, epoch_num)
            logger.save_plt(hps)

        print(
            "Epoch %2d" % epoch_num,
            "Train Accuracy: %2.4f %%" % acc_tr,
            "Val Accuracy: %2.4f %%" % acc_v,
            sep="\t\t",
        )

    # Calculate performance on test set. This uses the weights left in `net`
    # after the final epoch, which are not necessarily the best-validation ones.
    acc_test, loss_test = evaluate(net, testloader, criterion)
    print(
        "Test Accuracy: %2.4f %%" % acc_test, "Test Loss: %2.6f" % loss_test, sep="\t\t"
    )
if __name__ == "__main__":
    # Script entry point: build the hyperparameters from the command-line
    # arguments (everything after the script name), construct the logger and
    # network from them, then start training.
    cli_args = sys.argv[1:]
    hps = setup_hparams(cli_args)
    logger, net = setup_network(hps)
    run(net, logger, hps)