trainer.py
import os
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

import dataset
import utils

def Trainer(opt):
    # ----------------------------------------
    #       Network training parameters
    # ----------------------------------------
    # Handle multiple GPUs: scale the batch size and worker count linearly with
    # the GPU count; nn.DataParallel later splits each batch across the GPUs,
    # so the per-GPU batch size stays constant
    gpu_num = torch.cuda.device_count()
    print("There are %d GPUs" % gpu_num)
    opt.batch_size *= gpu_num
    opt.num_workers *= gpu_num
    # Create the folder that trained models are saved to
    save_model_folder = opt.save_path
    utils.check_path(save_model_folder)
    # cudnn benchmark
    cudnn.benchmark = opt.cudnn_benchmark
    # Loss functions
    criterion_L1 = torch.nn.L1Loss().cuda()
    # Initialize SGN
    generator = utils.create_generator(opt)
    # To device
    if opt.multi_gpu:
        generator = nn.DataParallel(generator)
    generator = generator.cuda()
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2), weight_decay = opt.weight_decay)
    # Learning rate decrease
    def adjust_learning_rate(opt, epoch, iteration, optimizer):
        # Decay the initial LR by "lr_decrease_factor" every "lr_decrease_epoch"
        # epochs, or every "lr_decrease_iter" iterations in 'iter' mode
        if opt.lr_decrease_mode == 'epoch':
            lr = opt.lr * (opt.lr_decrease_factor ** (epoch // opt.lr_decrease_epoch))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        elif opt.lr_decrease_mode == 'iter':
            lr = opt.lr * (opt.lr_decrease_factor ** (iteration // opt.lr_decrease_iter))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
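    # For instance (assumed values): with lr = 1e-4, lr_decrease_factor = 0.5,
    # lr_decrease_epoch = 20 and 'epoch' mode, the schedule is 1e-4 for epochs
    # 1-19, 5e-5 for epochs 20-39, 2.5e-5 for epochs 40-59, and so on.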
    # Save the model
    def save_model(opt, epoch, iteration, len_dataset, generator):
        # Define the name of the trained model
        if opt.save_mode == 'epoch':
            model_name = 'G_epoch%d_bs%d.pth' % (epoch, opt.batch_size)
        elif opt.save_mode == 'iter':
            model_name = 'G_iter%d_bs%d.pth' % (iteration, opt.batch_size)
        save_model_path = os.path.join(opt.save_path, model_name)
        # Unwrap nn.DataParallel so the checkpoint also loads without it
        net = generator.module if opt.multi_gpu else generator
        # Save the model at the configured epoch / iteration interval
        if opt.save_mode == 'epoch':
            if (epoch % opt.save_by_epoch == 0) and (iteration % len_dataset == 0):
                torch.save(net.state_dict(), save_model_path)
                print('The trained model is successfully saved at epoch %d' % epoch)
        elif opt.save_mode == 'iter':
            if iteration % opt.save_by_iter == 0:
                torch.save(net.state_dict(), save_model_path)
                print('The trained model is successfully saved at iteration %d' % iteration)
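    # Example (assumed values): with save_mode = 'epoch', save_by_epoch = 10 and
    # batch_size = 1, checkpoints land in opt.save_path as G_epoch10_bs1.pth,
    # G_epoch20_bs1.pth, ..., one per completed save interval.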
    # ----------------------------------------
    #            Network dataset
    # ----------------------------------------
    # Define the dataset
    trainset = dataset.HS_multiscale_DSet(opt)
    print('The overall number of images:', len(trainset))
    # Define the dataloader
    dataloader = DataLoader(trainset, batch_size = opt.batch_size, shuffle = True, num_workers = opt.num_workers, pin_memory = True)
    # ----------------------------------------
    #               Training
    # ----------------------------------------
    # Count start time
    prev_time = time.time()
    # For loop training
    for epoch in range(opt.epochs):
        for i, (img_A, img_B) in enumerate(dataloader):
            # To device
            img_A = img_A.cuda()
            img_B = img_B.cuda()
            # Train Generator
            optimizer_G.zero_grad()
            # Forward propagation
            recon_B = generator(img_A)
            # Losses
            loss = criterion_L1(recon_B, img_B)
            # Overall Loss and optimize
            loss.backward()
            optimizer_G.step()
            # Determine the approximate time left, extrapolating from the
            # duration of the most recent iteration
            iters_done = epoch * len(dataloader) + i
            iters_left = opt.epochs * len(dataloader) - iters_done
            time_left = datetime.timedelta(seconds = iters_left * (time.time() - prev_time))
            prev_time = time.time()
            # Print log
            print("\r[Epoch %d/%d] [Batch %d/%d] [Total Loss: %.4f] Time_left: %s" %
                ((epoch + 1), opt.epochs, i, len(dataloader), loss.item(), time_left))
            # Save model at certain epochs or iterations
            save_model(opt, (epoch + 1), (iters_done + 1), len(dataloader), generator)
            # Learning rate decrease at certain epochs or iterations
            adjust_learning_rate(opt, (epoch + 1), (iters_done + 1), optimizer_G)
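
# ----------------------------------------
#            Usage sketch
# ----------------------------------------
# Illustrative only: Trainer(opt) expects an argparse-style namespace exposing
# every field read above. The defaults below are assumptions, not the
# repository's actual settings, and dataset.HS_multiscale_DSet /
# utils.create_generator may read further fields (e.g. data paths) not shown.
if __name__ == '__main__':
    from argparse import Namespace
    opt = Namespace(
        # Saving and hardware settings (assumed defaults)
        save_path = './models',
        save_mode = 'epoch',            # 'epoch' or 'iter'
        save_by_epoch = 10,
        save_by_iter = 10000,
        multi_gpu = False,
        cudnn_benchmark = True,
        # Training and optimizer settings (assumed defaults)
        epochs = 100,
        batch_size = 1,
        num_workers = 4,
        lr = 1e-4,
        b1 = 0.5,
        b2 = 0.999,
        weight_decay = 0,
        lr_decrease_mode = 'epoch',     # 'epoch' or 'iter'
        lr_decrease_epoch = 20,
        lr_decrease_iter = 50000,
        lr_decrease_factor = 0.5,
    )
    Trainer(opt)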