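"""Train and evaluate a GAN for image inpainting.

A generator fills in a hidden vertical strip of each input image. A local
discriminator scores the generated patch and a global discriminator scores the
recomposed full image; their logits are summed. The generator minimizes a
weighted sum of reconstruction, adversarial, and total-variation losses.
"""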
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TensorFlow's C++ info/warning logs
import _thread
import tensorflow as tf
import numpy as np
from utils import save
from utils import load
from utils import read_by_batch
from utils import preprocess_image
from utils import save_images
from utils import compute_psnr_ssim
FLAG = tf.app.flags
FLAGS = FLAG.FLAGS
FLAG.DEFINE_string('dataname', '180256', 'The dataset name.')
FLAG.DEFINE_string('mode', 'train', 'Choose mode from (train, validate, test).')
FLAG.DEFINE_string('block', 'center', 'Block location, choose from (center, random); \
only valid in validate or test mode.')
FLAG.DEFINE_float('learning_rate', 0.0002, 'Initial learning rate.')
FLAG.DEFINE_float('lamb_rec', 0.999, 'Weight for the reconstruction loss.')
FLAG.DEFINE_float('lamb_adv', 0.001, 'Weight for the adversarial loss.')
FLAG.DEFINE_float('lamb_tv', 1e-6, 'Weight for the TV loss.')
FLAG.DEFINE_boolean('use_l1', False,
                    'If True, use L1 distance for rec_loss, else L2 distance.')
FLAG.DEFINE_boolean('weight_clip', False,
                    'If True, clip gradients when updating G and D.')
FLAG.DEFINE_integer('image_size', 256, 'Image size (selects the model variant).')
FLAG.DEFINE_integer('image_size_h', 180, 'Image height.')
FLAG.DEFINE_integer('image_size_w', 256, 'Image width.')
FLAG.DEFINE_integer('image_channel', 1,
                    'Image channels: 1 for grayscale, 3 for RGB.')
FLAG.DEFINE_integer('start_epoch', 0, 'Epoch to resume training from.')
FLAG.DEFINE_integer('epoch', 230, 'Number of epochs the trainer will run.')
FLAG.DEFINE_integer('batch_size', 16, 'Batch size.')
FLAG.DEFINE_integer('ckpt', 1, 'Save a checkpoint every this many epochs.')
FLAG.DEFINE_integer('sample', 1, 'Save sample images every this many epochs.')
FLAG.DEFINE_integer('summary', 1, 'Write summaries every this many steps.')
FLAG.DEFINE_integer('gene_iter', 5,
                    'Number of generator updates per batch.')
FLAG.DEFINE_integer('disc_iter', 1,
                    'Number of discriminator updates per batch.')
FLAG.DEFINE_integer('gpu', 0, 'Index of the GPU to run on.')
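# Example invocations (a sketch; the flag values here are illustrative):
#   python inpaint.py --mode=train --dataname=180256 --batch_size=16
#   python inpaint.py --mode=test --block=center
# Input is expected as a binary file at data/<dataname>_<mode>.bin
# (see DATA_PATH below).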
TRAIN = False
VALIDATE = False
TEST = False
if FLAGS.mode == 'train':
    TRAIN = True
elif FLAGS.mode == 'validate':
    VALIDATE = True
elif FLAGS.mode == 'test':
    TEST = True
else:
    print('mode not recognized, process terminated.')
    exit()
BATCH_SIZE = FLAGS.batch_size
IMAGE_SIZE = FLAGS.image_size
IMAGE_SIZE_H = FLAGS.image_size_h
IMAGE_SIZE_W = FLAGS.image_size_w
IMAGE_CHANNEL = FLAGS.image_channel
# The hidden (masked) region is a full-height vertical strip of width 64.
HIDDEN_SIZE_H = IMAGE_SIZE_H
HIDDEN_SIZE_W = 64
if IMAGE_SIZE == 256:
    from model import generator
    from model import local_discriminator, global_discriminator
else:
    print('image_size not supported, process terminated.')
    exit()
DATANAME = FLAGS.dataname
DATA_PATH = os.path.join('data', DATANAME + '_' + FLAGS.mode + '.bin')
CHECKPOINT_DIR = 'checkpoints'
LOG_DIR = 'logs'
SAMPLE_DIR = os.path.join('samples', FLAGS.mode)
BETA1 = 0.5
BETA2 = 0.9
LAMB_REC = FLAGS.lamb_rec
LAMB_ADV = FLAGS.lamb_adv
LAMB_TV = FLAGS.lamb_tv
WEIGHT_DECAY_RATE = 0.00001
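# Objectives assembled in run_model() below:
#   gene_loss = lamb_adv * L_adv + lamb_rec * L_rec + lamb_tv * L_tv
#               + WEIGHT_DECAY_RATE * mean(L2 norms of generator weights)
#   disc_loss = lamb_adv * cross_entropy(D scores vs. real/fake labels)
#               + WEIGHT_DECAY_RATE * mean(L2 norms of discriminator weights)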
def run_model(sess):
    masked_image_holder = tf.placeholder(tf.float32, [
        BATCH_SIZE, IMAGE_SIZE_H, IMAGE_SIZE_W, IMAGE_CHANNEL], name='masked_image')
    hidden_image_holder = tf.placeholder(tf.float32, [
        BATCH_SIZE, HIDDEN_SIZE_H, HIDDEN_SIZE_W, IMAGE_CHANNEL], name='hidden_image')
    loc_fake_image = generator(masked_image_holder, BATCH_SIZE,
                               IMAGE_CHANNEL, is_train=TRAIN, no_reuse=True)
    # The known remainder of the image: every column to the right of the hidden
    # strip (this recomposition assumes the strip sits at the left edge).
    tmp_image_holder = tf.strided_slice(
        masked_image_holder, [0, 0, HIDDEN_SIZE_W],
        [BATCH_SIZE, IMAGE_SIZE_H, IMAGE_SIZE_W])
    global_real_image = tf.concat([hidden_image_holder, tmp_image_holder], axis=2)
    global_fake_image = tf.concat([loc_fake_image, tmp_image_holder], axis=2)
    adv_real_score = local_discriminator(
        hidden_image_holder, BATCH_SIZE, is_train=TRAIN)
    adv_fake_score = local_discriminator(
        loc_fake_image, BATCH_SIZE, reuse=True, is_train=TRAIN)
    global_real_score = global_discriminator(
        global_real_image, BATCH_SIZE, is_train=TRAIN)
    # reuse=True so this second call shares the first call's variables,
    # mirroring the local_discriminator calls above (assumes the same signature).
    global_fake_score = global_discriminator(
        global_fake_image, BATCH_SIZE, reuse=True, is_train=TRAIN)
    # The local and global critics vote jointly: their logits are summed.
    final_real_score = tf.add(adv_real_score, global_real_score)
    final_fake_score = tf.add(adv_fake_score, global_fake_score)
    adv_all_score = tf.concat([final_real_score, final_fake_score], axis=0)
    # Discriminator targets: 1 for the real half of the batch, 0 for the fake half.
    labels_disc = tf.concat(
        [tf.ones([BATCH_SIZE]), tf.zeros([BATCH_SIZE])], axis=0)
    # Generator targets: all ones (it wants its fakes classified as real).
    labels_gene = tf.ones([BATCH_SIZE])
    # Discriminator accuracy: thresholded sigmoid scores compared with the labels.
    correct = tf.equal(labels_disc, tf.round(tf.nn.sigmoid(adv_all_score)))
    disc_acc = tf.reduce_mean(tf.cast(correct, tf.float32))
    # Sampler: an inference-mode generator sharing weights with the training one.
    sampler = generator(masked_image_holder, BATCH_SIZE,
                        IMAGE_CHANNEL, is_train=False)
    # --- Loss computation ---
    if FLAGS.use_l1:
        # Reconstruction loss L_rec as a mean L1 distance over both the
        # recomposed full image and the patch.
        rec_loss = tf.reduce_mean(
            tf.concat([tf.abs(tf.subtract(global_real_image, global_fake_image)),
                       tf.abs(tf.subtract(hidden_image_holder, loc_fake_image))],
                      axis=2))
    else:
        # Reconstruction loss L_rec as a mean squared (L2) distance.
        rec_loss = tf.reduce_mean(
            tf.concat([tf.squared_difference(hidden_image_holder, loc_fake_image),
                       tf.squared_difference(global_real_image, global_fake_image)],
                      axis=2))
    adv_disc_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels_disc, logits=adv_all_score))
    # Adversarial loss L_adv of the context encoder (generator side).
    adv_gene_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels_gene, logits=adv_fake_score))
    # Total-variation smoothness term over both the full image and the patch.
    tv_loss = tf.reduce_mean(
        tf.add(tf.image.total_variation(global_fake_image),
               tf.image.total_variation(loc_fake_image)))
    gene_loss_ori = LAMB_ADV * adv_gene_loss + \
        LAMB_REC * rec_loss + LAMB_TV * tv_loss
    disc_loss_ori = LAMB_ADV * adv_disc_loss
    # Variables each optimizer should update (the var_list arguments below).
    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'generator' in var.name]
    disc_vars = [var for var in all_vars if 'discriminator' in var.name]
    gene_weights = [var for var in gene_vars if 'weights' in var.name]
    disc_weights = [var for var in disc_vars if 'weights' in var.name]
    # L2 weight decay on each network's convolution weights.
    gene_loss = gene_loss_ori + WEIGHT_DECAY_RATE * \
        tf.reduce_mean(tf.stack([tf.nn.l2_loss(x) for x in gene_weights]))
    disc_loss = disc_loss_ori + WEIGHT_DECAY_RATE * \
        tf.reduce_mean(tf.stack([tf.nn.l2_loss(x) for x in disc_weights]))
    if TRAIN:
        if FLAGS.weight_clip:
            # Adam with gradients clipped to [-10, 10]; the discriminator runs
            # at a tenth of the generator's learning rate.
            gene_optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate, BETA1)
            gene_vars_grads = gene_optimizer.compute_gradients(gene_loss, gene_vars)
            gene_vars_grads = [gv if gv[0] is None else (
                tf.clip_by_value(gv[0], -10., 10.), gv[1]) for gv in gene_vars_grads]
            gene_train_op = gene_optimizer.apply_gradients(gene_vars_grads)
            disc_optimizer = tf.train.AdamOptimizer(
                FLAGS.learning_rate / 10.0, BETA1)
            disc_vars_grads = disc_optimizer.compute_gradients(disc_loss, disc_vars)
            disc_vars_grads = [gv if gv[0] is None else (
                tf.clip_by_value(gv[0], -10., 10.), gv[1]) for gv in disc_vars_grads]
            disc_train_op = disc_optimizer.apply_gradients(disc_vars_grads)
        else:
            gene_train_op = tf.train.AdamOptimizer(
                FLAGS.learning_rate, BETA1, BETA2).minimize(gene_loss, var_list=gene_vars)
            disc_train_op = tf.train.AdamOptimizer(
                FLAGS.learning_rate, BETA1, BETA2).minimize(disc_loss, var_list=disc_vars)
    # Scalar and histogram summaries, merged for TensorBoard
    # (view with: tensorboard --logdir logs).
    tf.summary.scalar('disc_acc', disc_acc)
    tf.summary.scalar('rec_loss', rec_loss)
    tf.summary.scalar('gene_loss', gene_loss_ori)
    tf.summary.scalar('adv_gene_loss', adv_gene_loss)
    tf.summary.scalar('disc_loss', adv_disc_loss)
    tf.summary.histogram('gene_score', adv_fake_score)
    tf.summary.histogram('disc_score', adv_all_score)
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    tf.global_variables_initializer().run()
    counter = 1
    saver = tf.train.Saver(max_to_keep=1)
    could_load, checkpoint_counter = load(sess, saver, CHECKPOINT_DIR)
    if could_load:
        counter = checkpoint_counter
        print(' [*] Load SUCCESS')
    else:
        print(' [!] Load FAILED...')
        if not TRAIN:
            # Validation and testing need trained weights, so bail out here.
            exit()
    sess.graph.finalize()
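    # Per batch: the discriminator takes FLAGS.disc_iter gradient steps, then
    # the generator takes FLAGS.gene_iter steps (1 D-step to 5 G-steps by
    # default), i.e. the usual alternating GAN update.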
    if TRAIN:
        save_configs()
        for epoch in range(FLAGS.start_epoch, FLAGS.epoch):
            index = 0
            losses = np.zeros(4)
            loss_file = open(os.path.join(LOG_DIR, 'loss_log.txt'), 'a+')
            # Reopen the binary dataset at the start of each epoch.
            file_object = open(DATA_PATH, 'rb')
            print('Current Epoch is: ' + str(epoch))
            for image_batch in read_by_batch(
                    file_object, BATCH_SIZE, [IMAGE_SIZE_H, IMAGE_SIZE_W, IMAGE_CHANNEL]):
                # Drop the final incomplete batch.
                if image_batch.shape[0] != BATCH_SIZE:
                    break
                # Scale pixels from [0, 255] to [-1, 1].
                image_batch = (image_batch.astype(np.float32) - 127.5) / 127.5
                if FLAGS.block == 'center':
                    masked_image_batch, hidden_image_batch, masks_idx = preprocess_image(
                        image_batch, BATCH_SIZE, IMAGE_SIZE_H, IMAGE_SIZE_W,
                        HIDDEN_SIZE_H, HIDDEN_SIZE_W, IMAGE_CHANNEL, False)
                else:
                    masked_image_batch, hidden_image_batch, masks_idx = preprocess_image(
                        image_batch, BATCH_SIZE, IMAGE_SIZE_H, IMAGE_SIZE_W,
                        HIDDEN_SIZE_H, HIDDEN_SIZE_W, IMAGE_CHANNEL)
                if epoch % FLAGS.sample == 0 and index == 0:
                    samples, rec_loss_value, adv_gene_loss_value, adv_disc_loss_value = sess.run(
                        [sampler, rec_loss, adv_gene_loss, adv_disc_loss],
                        feed_dict={
                            masked_image_holder: masked_image_batch,
                            hidden_image_holder: hidden_image_batch
                        })
                    # Paste each generated patch back into its masked slot.
                    inpaint_image = np.copy(masked_image_batch)
                    for idx in range(FLAGS.batch_size):
                        idx_start1 = int(masks_idx[idx, 0])
                        idx_end1 = int(masks_idx[idx, 0] + HIDDEN_SIZE_H)
                        idx_start2 = int(masks_idx[idx, 1])
                        idx_end2 = int(masks_idx[idx, 1] + HIDDEN_SIZE_W)
                        inpaint_image[idx, idx_start1:idx_end1,
                                      idx_start2:idx_end2, :] = samples[idx, :, :, :]
                    save_images(image_batch, epoch, index, SAMPLE_DIR)
                    save_images(inpaint_image, epoch, index + 1, SAMPLE_DIR)
                    save_images(masked_image_batch, epoch, index + 2, SAMPLE_DIR)
                    psnr, ssim = compute_psnr_ssim(hidden_image_batch, samples)
                    # Alternative: score the full recomposed frame instead.
                    # psnr, ssim = compute_psnr_ssim(image_batch, inpaint_image)
                    print('[Getting Sample...] rec_loss:%.8f, gene_loss: %.8f, '
                          'disc_loss: %.8f, psnr: %.8f, ssim: %.8f'
                          % (rec_loss_value, adv_gene_loss_value,
                             adv_disc_loss_value, psnr, ssim))
                if counter % FLAGS.summary == 0:
                    summary = sess.run(
                        merged,
                        feed_dict={
                            masked_image_holder: masked_image_batch,
                            hidden_image_holder: hidden_image_batch
                        })
                    writer.add_summary(summary, counter)
                if epoch % FLAGS.ckpt == 0 and index == 0:
                    # Save the checkpoint on a background thread so training
                    # is not blocked on disk I/O.
                    _thread.start_new_thread(
                        save, (sess, saver, CHECKPOINT_DIR, counter,))
                # Update D, then G, and report the batch losses.
                for _ in range(FLAGS.disc_iter):
                    _ = sess.run(
                        disc_train_op,
                        feed_dict={
                            masked_image_holder: masked_image_batch,
                            hidden_image_holder: hidden_image_batch
                        })
                for _ in range(FLAGS.gene_iter):
                    _, rec_loss_value, adv_gene_loss_value, adv_disc_loss_value, disc_acc_value = sess.run(
                        [gene_train_op, rec_loss, adv_gene_loss, adv_disc_loss, disc_acc],
                        feed_dict={
                            masked_image_holder: masked_image_batch,
                            hidden_image_holder: hidden_image_batch
                        })
                print('Epoch:%4d Batch:%4d, rec_loss:%2.8f, gene_loss: %2.8f, '
                      'disc_loss: %2.8f, disc_acc: %2.8f'
                      % (epoch, index, rec_loss_value, adv_gene_loss_value,
                         adv_disc_loss_value, disc_acc_value))
                index += 1
                counter += 1
                losses[0] += rec_loss_value
                losses[1] += adv_gene_loss_value
                losses[2] += adv_disc_loss_value
                losses[3] += disc_acc_value
            # Log the epoch's mean losses.
            losses /= index
            loss_file.write(
                'Epoch:%4d, rec_loss:%2.8f, gene_loss: %2.8f, disc_loss: %2.8f, disc_acc: %2.8f\n'
                % (epoch, losses[0], losses[1], losses[2], losses[3]))
            loss_file.close()
    else:  # VALIDATE or TEST
        index = 0
        values = np.zeros(6)
        file_object = open(DATA_PATH, 'rb')
        for image_batch in read_by_batch(
                file_object, BATCH_SIZE, [IMAGE_SIZE_H, IMAGE_SIZE_W, IMAGE_CHANNEL]):
            if image_batch.shape[0] != BATCH_SIZE:
                print('image_batch.shape[0] != BATCH_SIZE')
                break
            image_batch = (image_batch.astype(np.float32) - 127.5) / 127.5
            if FLAGS.block == 'center':
                masked_image_batch, hidden_image_batch, masks_idx = preprocess_image(
                    image_batch, BATCH_SIZE, IMAGE_SIZE_H, IMAGE_SIZE_W,
                    HIDDEN_SIZE_H, HIDDEN_SIZE_W, IMAGE_CHANNEL, False)
            else:
                masked_image_batch, hidden_image_batch, masks_idx = preprocess_image(
                    image_batch, BATCH_SIZE, IMAGE_SIZE_H, IMAGE_SIZE_W,
                    HIDDEN_SIZE_H, HIDDEN_SIZE_W, IMAGE_CHANNEL)
            samples, rec_loss_value, adv_gene_loss_value, adv_disc_loss_value, disc_acc_value = sess.run(
                [sampler, rec_loss, adv_gene_loss, adv_disc_loss, disc_acc],
                feed_dict={
                    masked_image_holder: masked_image_batch,
                    hidden_image_holder: hidden_image_batch
                })
            inpaint_image = np.copy(masked_image_batch)
            for idx in range(FLAGS.batch_size):
                idx_start1 = int(masks_idx[idx, 0])
                idx_end1 = int(masks_idx[idx, 0] + HIDDEN_SIZE_H)
                idx_start2 = int(masks_idx[idx, 1])
                idx_end2 = int(masks_idx[idx, 1] + HIDDEN_SIZE_W)
                inpaint_image[idx, idx_start1:idx_end1,
                              idx_start2:idx_end2, :] = samples[idx, :, :, :]
            save_images(image_batch, index, 0, SAMPLE_DIR)
            save_images(inpaint_image, index, 1, SAMPLE_DIR)
            save_images(masked_image_batch, index, 2, SAMPLE_DIR)
            psnr, ssim = compute_psnr_ssim(hidden_image_batch, samples)
            # Alternative: score the full recomposed frame instead.
            # psnr, ssim = compute_psnr_ssim(image_batch, inpaint_image)
            print('[Getting Sample...] rec_loss:%2.8f, gene_loss: %2.8f, disc_loss: %2.8f, '
                  'disc_acc: %2.8f, psnr: %2.8f, ssim: %2.8f'
                  % (rec_loss_value, adv_gene_loss_value,
                     adv_disc_loss_value, disc_acc_value, psnr, ssim))
            index += 1
            values[0] += rec_loss_value
            values[1] += adv_gene_loss_value
            values[2] += adv_disc_loss_value
            values[3] += disc_acc_value
            values[4] += psnr
            values[5] += ssim
        values /= index
        print('Mean rec_loss:%2.8f, gene_loss: %2.8f, disc_loss: %2.8f, disc_acc: %2.8f, '
              'psnr: %2.8f, ssim: %2.8f'
              % (values[0], values[1], values[2], values[3], values[4], values[5]))
# PSNR (Peak Signal-to-Noise Ratio) and SSIM (Structural Similarity Index) are
# the two metrics most commonly used to quantify reconstruction quality.
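# For reference, a minimal NumPy sketch of the usual PSNR definition, assuming
# compute_psnr_ssim in utils follows it (pixels here are scaled to [-1, 1], so
# the dynamic range MAX would be 2.0):
#
#     def psnr(a, b, max_val=2.0):
#         mse = np.mean((a - b) ** 2)
#         return 10.0 * np.log10(max_val ** 2 / mse)
#
# SSIM instead compares local luminance, contrast, and structure statistics
# between the two images.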
def save_configs():
    # Append the current run's flags to checkpoints/configs.txt.
    config_file = open(os.path.join(CHECKPOINT_DIR, 'configs.txt'), 'a+')
    config_file.write(
        'dataname:%s\n'
        'block:%s\n'
        'learning_rate:%f\n'
        'lamb_rec:%f\n'
        'lamb_adv:%f\n'
        'lamb_tv:%f\n'
        'use_l1:%s\n'
        'weight_clip:%s\n'
        'image_size_h:%d\n'
        'image_size_w:%d\n'
        'image_channel:%d\n'
        'start_epoch: %d\n'
        'epoch: %d\n'
        'batch_size: %d\n'
        'gene_iter: %d\n'
        'disc_iter: %d\n\n' % (FLAGS.dataname, FLAGS.block, FLAGS.learning_rate,
                               FLAGS.lamb_rec, FLAGS.lamb_adv, FLAGS.lamb_tv,
                               str(FLAGS.use_l1), str(FLAGS.weight_clip),
                               FLAGS.image_size_h, FLAGS.image_size_w,
                               FLAGS.image_channel, FLAGS.start_epoch, FLAGS.epoch,
                               FLAGS.batch_size, FLAGS.gene_iter, FLAGS.disc_iter))
    config_file.close()
def print_args():
    print('dataname is: ' + str(FLAGS.dataname))
    print('learning_rate is: ' + str(FLAGS.learning_rate))
    print('lamb_rec is: ' + str(FLAGS.lamb_rec))
    print('lamb_adv is: ' + str(FLAGS.lamb_adv))
    print('lamb_tv is: ' + str(FLAGS.lamb_tv))
    print('start_epoch is: ' + str(FLAGS.start_epoch))
    print('epoch is: ' + str(FLAGS.epoch))
    print('image_size_h is: ' + str(FLAGS.image_size_h))
    print('image_size_w is: ' + str(FLAGS.image_size_w))
    print('image_channel is: ' + str(FLAGS.image_channel))
    print('use_l1 is: ' + str(FLAGS.use_l1))
    print('weight_clip is: ' + str(FLAGS.weight_clip))
    print('mode is: ' + str(FLAGS.mode))
    print('batch_size is: ' + str(FLAGS.batch_size))
    print('ckpt is: ' + str(FLAGS.ckpt))
    print('sample is: ' + str(FLAGS.sample))
    print('summary is: ' + str(FLAGS.summary))
    print('gene_iter is: ' + str(FLAGS.gene_iter))
    print('disc_iter is: ' + str(FLAGS.disc_iter))
    print('gpu is: ' + str(FLAGS.gpu))
    print('')
def main(_):
    print_args()
    if not os.path.exists(CHECKPOINT_DIR):
        os.makedirs(CHECKPOINT_DIR)
    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)
    if not os.path.exists(SAMPLE_DIR):
        os.makedirs(SAMPLE_DIR)
    run_config = tf.ConfigProto(allow_soft_placement=True)
    run_config.gpu_options.allow_growth = True
    with tf.device('/gpu:' + str(FLAGS.gpu)):
        with tf.Session(config=run_config) as sess:
            run_model(sess)
if __name__ == '__main__':
    tf.app.run()