-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain.py
76 lines (63 loc) · 3.13 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import sys
import time

import numpy as np
import scipy.misc
import tensorflow as tf

from config import Config as conf
from data import *  # NOTE(review): presumably supplies load_data(); prefer explicit names
from model import CGAN
def prepocess_train(img, cond,):
    """Augment and normalise one (image, condition) training pair.

    Roughly half of the time both arrays are mirrored left-to-right; the
    same decision is applied to ``img`` and ``cond`` so the pair stays
    aligned. Each array is then rescaled with ``x / 127.5 - 1`` and reshaped
    into a single-sample batch.

    Args:
        img: Target image array (assumes 0..255 pixel values, which the
            rescale maps onto [-1, 1] -- TODO confirm against the loader).
        cond: Conditioning image array, processed exactly like ``img``.

    Returns:
        Tuple ``(img, cond)``, each of shape
        ``(1, conf.img_size, conf.img_size, conf.img_channel)``.
    """
    # NOTE: a random resize-then-crop jitter (scipy.misc.imresize to
    # conf.adjust_size followed by a conf.train_size crop) used to be sketched
    # here as commented-out code; it is not applied.

    # One draw decides the flip for both arrays.
    mirror = np.random.random() > 0.5

    processed = []
    for arr in (img, cond):
        if mirror:
            arr = np.fliplr(arr)
        arr = arr / 127.5 - 1.
        processed.append(
            arr.reshape(1, conf.img_size, conf.img_size, conf.img_channel)
        )

    return processed[0], processed[1]
def train():
    """Train the CGAN and write periodic checkpoints.

    Steps, in order:
      1. Load the dataset via ``load_data()`` and build the ``CGAN`` graph.
      2. Create one Adam optimiser per network (discriminator / generator),
         each restricted to that network's variable list.
      3. Ensure the checkpoint and output directories exist.
      4. Open a session, either initialising variables from scratch (when
         ``conf.model_path_train`` is the empty string) or restoring them
         from that path.
      5. For ``conf.max_epoch`` epochs, iterate ``data["train"]()`` and run
         one discriminator step followed by one generator step per sample,
         logging both losses every 50 iterations.
      6. Every ``conf.save_per_epoch`` epochs, save a checkpoint named
         ``model_<epoch>.ckpt``.

    Returns:
        None.
    """
    dataset = load_data()
    model = CGAN()

    d_opt = tf.train.AdamOptimizer(
        learning_rate=conf.learning_rateD, beta1=conf.beta1
    ).minimize(model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(
        learning_rate=conf.learning_rateG, beta1=conf.beta1
    ).minimize(model.g_loss, var_list=model.g_vars)

    # Keep the Saver after the optimisers: Saver() collects the variables
    # that exist at construction time, and minimize() adds its own.
    saver = tf.train.Saver()
    start_time = time.time()

    # Create the on-disk locations up front so later saves cannot fail on a
    # missing directory.
    for directory in (conf.data_path_checkpoint + "/checkpoint", conf.output_path):
        if not os.path.exists(directory):
            os.makedirs(directory)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        if conf.model_path_train != "":
            # Resume from a previously saved model.
            saver.restore(sess, conf.model_path_train)
        else:
            sess.run(tf.global_variables_initializer())

        for epoch in range(conf.max_epoch):
            # `step` is the 1-based iteration count within this epoch.
            for step, (img, cond, _name) in enumerate(dataset["train"](), 1):
                pimg, pcond = prepocess_train(img, cond)
                feed = {model.image: pimg, model.cond: pcond}

                # Discriminator update first, then generator, on the same sample.
                _, d_loss_val = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, g_loss_val = sess.run([g_opt, model.g_loss], feed_dict=feed)

                if step % 50 == 0:
                    elapsed = time.time() - start_time
                    print("Epoch [%s], Iteration [%s]: time: %s, d_loss: %s, g_loss: %s"
                          % (epoch, step, elapsed, d_loss_val, g_loss_val))

            if (epoch + 1) % conf.save_per_epoch == 0:
                ckpt_name = "model_%d.ckpt" % (epoch + 1)
                save_path = saver.save(
                    sess, conf.data_path_checkpoint + "/checkpoint/" + ckpt_name
                )
                print("Model saved in file: %s" % (save_path))
def select_gpu_id(argv):
    """Return the CUDA device id string requested on the command line.

    The script accepts an optional first argument of the form ``gpu=<ids>``
    (for example ``gpu=1`` or ``gpu=0,2``). Everything after the ``gpu=``
    prefix is returned unchanged. When the argument is absent or does not
    start with ``gpu=``, device ``"0"`` is returned.

    Args:
        argv: Argument list laid out like ``sys.argv`` (program name first).

    Returns:
        The device id string to store in ``CUDA_VISIBLE_DEVICES``.
    """
    # The previous check compared the whole argument to the literal 'gpu=',
    # so 'gpu=1' never matched and the script always fell back to device 0.
    if len(argv) > 1 and argv[1].startswith('gpu='):
        return argv[1][4:]
    return str(0)


if __name__ == "__main__":
    # Both settings must be in place before train() creates the TF session.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = select_gpu_id(sys.argv)
    train()