trains.py
"""
date: 2021/3/9 4:04 下午
written by: neonleexiang
"""
from model import SRCNN
from data_preprocessing import SRCNNLoader
import tensorflow as tf
import cv2 as cv
import numpy as np
import model_evaluation

# parameters
num_epochs = 1
batch_size = 5
learning_rate = 0.001   # note: the changing learning rate described in the paper is not implemented here

# build a SRCNN model
model = SRCNN()
data_loader = SRCNNLoader()     # tensorflow fetches the training data through this data loader

# set optimizer by using Adam optimizer
# according to some readers SGD has better performance
# TODO: a decaying learning rate could perhaps be configured through the optimizer's parameters
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
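
# A possible way to address the TODO above (only a sketch, not the exact schedule from the paper;
# the decay_steps / decay_rate values are illustrative assumptions):
# lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
#     initial_learning_rate=learning_rate, decay_steps=1000, decay_rate=0.96, staircase=True)
# optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)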

# training by batches
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    # tensorflow records the forward pass on a GradientTape so the gradients can be computed and back-propagated
    with tf.GradientTape() as tape:
        y_pred = model(X)
        # only using mean_squared_error
        loss = tf.keras.losses.mean_squared_error(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    # gradient descent
    grads = tape.gradient(loss, model.variables)    # compute the gradients and propagate the error back
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

print('training ended')
tf.saved_model.save(model, 'saved/1')
# model.save('SRCNN.h5')
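# To reload the exported SavedModel later for inference, something like the following sketch
# should work (the path mirrors the save call above):
# loaded = tf.saved_model.load('saved/1')
# y_pred = loaded(X)    # the restored object keeps the traced call of the Keras model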

# evaluate on the test set with PSNR
# accuracy = model_evaluation.Self_Defined_psnr_accuracy()
num_batches = int(data_loader.num_test_data // batch_size)
# print(num_batches)
psnr_result = []
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model.predict(data_loader.test_data[start_index: end_index])
    y_true = data_loader.test_label[start_index: end_index]
    for t, p in zip(y_true, y_pred):
        psnr_result.append(model_evaluation.psnr(t * 255., p * 255.))

print("test accuracy: %f" % np.mean(psnr_result))

# testing ---------------->
# i = 0
# for img in data_loader.test_data:
#     i += 1
#     img = np.expand_dims(img, axis=0)
#     y_pred = model.predict(img)
#     print('> --------------- ')
#     print(y_pred.shape)
#     print('> saving ---------- ')
#     # print(img[0])
#     # print(y_pred[0])
#     # break
#     cv.imwrite('result/data_pred/{}-img.png'.format(str(i)), img[0] * 255)
#     cv.imwrite('result/data_pred/{}-pred.png'.format(str(i)), y_pred[0] * 255)