-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmodel.py
118 lines (90 loc) · 4.57 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import tensorflow as tf
def model_inputs(real_dim, z_dim):
    """
    Create the placeholders that feed data into the graph

    :param real_dim: Per-example shape of the real data (an iterable of dimensions)
    :param z_dim: Length of one latent vector
    :return: A tuple of (real input placeholder, z input placeholder)
    """
    # The leading None leaves the batch dimension unspecified.
    real_shape = (None,) + tuple(real_dim)
    z_shape = (None, z_dim)

    inputs_real = tf.placeholder(tf.float32, real_shape, name='input_real')
    inputs_z = tf.placeholder(tf.float32, z_shape, name='input_z')
    return inputs_real, inputs_z
def generator(z, output_dim, reuse=False, alpha=0.2, training=True):
    """
    Build the generator network

    :param z: Latent input tensor
    :param output_dim: Number of filters (channels) of the final transposed convolution
    :param reuse: Passed to the 'generator' variable scope to reuse its variables
    :param alpha: Slope applied to negative activations in the leaky ReLU
    :param training: Forwarded to every batch normalization layer
    :return: The tanh of the final transposed convolution output
    """
    def leaky_relu(tensor):
        # Leaky ReLU expressed as an element-wise maximum.
        return tf.maximum(alpha * tensor, tensor)

    with tf.variable_scope('generator', reuse=reuse):
        # Project the latent vector and fold it into a 4x4x512 feature map.
        hidden = tf.layers.dense(z, 4 * 4 * 512)
        hidden = tf.reshape(hidden, (-1, 4, 4, 512))
        hidden = leaky_relu(
            tf.layers.batch_normalization(hidden, training=training))

        # Two stride-2 transposed convolutions: 4x4x512 -> 8x8x256 -> 16x16x128.
        for filters in (256, 128):
            hidden = tf.layers.conv2d_transpose(
                hidden, filters, 5, strides=2, padding='same')
            hidden = tf.layers.batch_normalization(hidden, training=training)
            hidden = leaky_relu(hidden)

        # Output layer: one more stride-2 upsampling to 32x32 with output_dim channels.
        logits = tf.layers.conv2d_transpose(
            hidden, output_dim, 5, strides=2, padding='same')
        return tf.tanh(logits)
def discriminator(x, reuse=False, alpha=0.2):
    """
    Build the discriminator network

    :param x: Input image tensor (the flatten step below assumes 32x32 inputs)
    :param reuse: Passed to the 'discriminator' variable scope to reuse its variables
    :param alpha: Slope applied to negative activations in the leaky ReLU
    :return: A tuple of (sigmoid output, logits)
    """
    def leaky_relu(tensor):
        # Leaky ReLU expressed as an element-wise maximum.
        return tf.maximum(alpha * tensor, tensor)

    with tf.variable_scope('discriminator', reuse=reuse):
        # First convolution has no batch normalization: 32x32x3 -> 16x16x64.
        features = tf.layers.conv2d(x, 64, 5, strides=2, padding='same')
        features = leaky_relu(features)

        # Two more stride-2 convolutions: 16x16x64 -> 8x8x128 -> 4x4x256.
        # Batch normalization always runs with training=True here.
        for filters in (128, 256):
            features = tf.layers.conv2d(
                features, filters, 5, strides=2, padding='same')
            features = tf.layers.batch_normalization(features, training=True)
            features = leaky_relu(features)

        # Flatten the 4x4x256 map and score it with a single output unit.
        flat = tf.reshape(features, (-1, 4 * 4 * 256))
        logits = tf.layers.dense(flat, 1)
        return tf.sigmoid(logits), logits
def model_loss(input_real, input_z, output_dim, alpha=0.2):
    """
    Build the discriminator and generator losses

    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param output_dim: The number of channels in the output image
    :param alpha: Leaky ReLU slope forwarded to the generator and discriminator
    :return: A tuple of (discriminator loss, generator loss)
    """
    def mean_cross_entropy(logits, labels):
        # Mean sigmoid cross-entropy of the logits against the given labels.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

    # The second discriminator call reuses the variables created by the first.
    fake_images = generator(input_z, output_dim, alpha=alpha)
    real_prob, real_logits = discriminator(input_real, alpha=alpha)
    fake_prob, fake_logits = discriminator(fake_images, reuse=True, alpha=alpha)

    # Discriminator: real images are labelled 1, generated images 0.
    d_loss_real = mean_cross_entropy(real_logits, tf.ones_like(real_prob))
    d_loss_fake = mean_cross_entropy(fake_logits, tf.zeros_like(fake_prob))

    # Generator: wants its images to be labelled 1 by the discriminator.
    g_loss = mean_cross_entropy(fake_logits, tf.ones_like(fake_prob))

    return d_loss_real + d_loss_fake, g_loss
def model_opt(d_loss, g_loss, learning_rate, beta1):
    """
    Build the training operations for both networks

    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param learning_rate: Learning rate handed to both Adam optimizers
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :return: A tuple of (discriminator training operation, generator training operation)
    """
    trainable = tf.trainable_variables()

    def scoped(prefix):
        # Trainable variables whose name begins with the given scope prefix.
        return [variable for variable in trainable
                if variable.name.startswith(prefix)]

    d_vars = scoped('discriminator')
    g_vars = scoped('generator')

    # Make each training op depend on the ops in the UPDATE_OPS collection
    # so they run as part of every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        d_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=beta1)
        d_train_opt = d_optimizer.minimize(d_loss, var_list=d_vars)
        g_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=beta1)
        g_train_opt = g_optimizer.minimize(g_loss, var_list=g_vars)

    return d_train_opt, g_train_opt
class GAN:
    """Holds the placeholders, losses and training ops of one GAN graph."""

    def __init__(self, real_size, z_size, learning_rate, alpha=0.2, beta1=0.5):
        """
        Reset the default graph and build the whole model in it

        :param real_size: Shape of one real image; index 2 is used as the channel count
        :param z_size: Length of one latent vector
        :param learning_rate: Learning rate for both optimizers
        :param alpha: Leaky ReLU slope used by both networks
        :param beta1: The exponential decay rate for the 1st moment in the optimizer
        """
        # Everything below is built into a fresh default graph.
        tf.reset_default_graph()

        placeholders = model_inputs(real_size, z_size)
        self.input_real, self.input_z = placeholders

        channels = real_size[2]
        losses = model_loss(self.input_real, self.input_z, channels, alpha=alpha)
        self.d_loss, self.g_loss = losses

        train_ops = model_opt(self.d_loss, self.g_loss, learning_rate, beta1)
        self.d_opt, self.g_opt = train_ops