forked from hiwonjoon/cycle-gan-tf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: model.py
66 lines (59 loc) · 1.99 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from ops import *
def build_enc_dec(source, reuse=False):
    """Build the encoder-decoder generator graph and return its output tensor.

    Encoder: a 7x7 stride-1 Conv2d to 32 channels followed by Lrelu, two 4x4
    stride-2 Conv2d layers (32->64, 64->128) each followed by InstanceNorm and
    Lrelu, then nine ResidualBlock(..., 128) blocks.
    Decoder: two 4x4 stride-2 TransposedConv2d layers (128->64 with spatial
    size image_size // 2, then 64->32 with spatial size image_size), each
    followed by InstanceNorm and Lrelu, then a 7x7 stride-1 Conv2d to 3
    channels and a tanh op named 'b_gen'.

    The blocks are constructed inside the 'encoder' / 'decoder' variable
    scopes (presumably where their variables are created -- confirm in ops)
    and are applied to ``source`` afterwards.

    Args:
        source: 4-D tensor with a static shape, unpacked as
            (batch_size, channels, image_size, _). Presumably NCHW with square
            images: the decoder output is image_size x image_size regardless
            of the last input dim.
        reuse: forwarded only to blocks whose exact type is ``BatchNorm``.
            Neither spec currently contains a ``BatchNorm`` block, so this
            argument has no effect at present.

    Returns:
        The tanh-activated output tensor of the decoder.
    """
    # TODO: transposed conv weights cannot accept dynamic shape?
    # batch_size must therefore be static: TransposedConv2d is handed an
    # explicit output shape below.
    batch_size, channels, image_size, _ = source.get_shape().as_list()

    with tf.variable_scope('encoder'):
        encoder_spec = [
            Conv2d('conv2d_1', channels, 32, 7, 7, 1, 1),
            Lrelu(),
        ]
        for l, (in_, out_) in enumerate([(32, 64), (64, 128)]):
            encoder_spec += [
                Conv2d('conv2d_%d' % (l + 2), in_, out_, 4, 4, 2, 2),
                InstanceNorm('conv2d_in_%d' % (l + 2)),
                Lrelu(),
            ]
        # `range` rather than the Python-2-only `xrange`, so this also runs
        # on Python 3 (identical iteration on Python 2).
        for l in range(9):
            encoder_spec += [
                ResidualBlock('res_%d' % (l + 1), 128),
            ]

    with tf.variable_scope('decoder'):
        decoder_spec = []
        for l, (in_, out_, size_) in enumerate(
                [(128, 64, image_size // 2), (64, 32, image_size)]):
            decoder_spec += [
                TransposedConv2d('tconv_%d' % (l + 1), in_,
                                 [batch_size, out_, size_, size_], 4, 4, 2, 2),
                InstanceNorm('tconv_in_%d' % (l + 1)),
                Lrelu(),
            ]
        decoder_spec += [
            # Reuses the name 'conv2d_1'; it is constructed under the
            # 'decoder' scope rather than 'encoder'.
            Conv2d('conv2d_1', 32, 3, 7, 7, 1, 1),
            lambda t: tf.nn.tanh(t, name='b_gen'),
        ]

    _t = source
    for block in encoder_spec + decoder_spec:
        # Exact-type match (not isinstance), as in the original: only plain
        # BatchNorm blocks receive the `reuse` flag.
        if type(block) is BatchNorm:
            _t = block(_t, reuse=reuse)
        else:
            _t = block(_t)
    return _t
def build_critic(_t):
    """Build the critic graph on top of ``_t`` and return its output tensor.

    Five 4x4 stride-2 Conv2d layers (channels->64->128->256->256->256), each
    followed by Lrelu, then three Linear layers (flattened->512->512->1) with
    Lrelu between them. No activation follows 'linear_3', so the output is an
    unbounded score.

    Args:
        _t: 4-D input tensor whose dims are unpacked as (_, channels, _, _);
            presumably NCHW. Its shape must be static enough that the conv
            stack's output has known channel/height/width dims.

    Returns:
        The output tensor of 'linear_3' (presumably shape (batch, 1) --
        confirm against Linear in ops).
    """
    _, channels, _, _ = _t.get_shape().as_list()

    conv_spec = []
    for l, (in_, out_) in enumerate(
            [(channels, 64), (64, 128), (128, 256), (256, 256), (256, 256)]):
        conv_spec += [
            Conv2d('conv2d_%d' % (l + 1), in_, out_, 4, 4, 2, 2),
            # InstanceNorm('conv2d_in_%d' % (l + 1)),  (disabled in the critic)
            Lrelu(),
        ]
    for block in conv_spec:
        _t = block(_t)

    # Size 'linear_1' from the actual conv output. The former expression
    # `image_size//32*image_size//32*256` evaluates left to right as
    # ((image_size//32)*image_size)//32*256 and assumes a square input, so it
    # can disagree with the real flattened size when image_size is not a
    # multiple of 32 or height != width.
    _, out_c, out_h, out_w = _t.get_shape().as_list()
    fc_spec = [
        Linear('linear_1', out_c * out_h * out_w, 512),
        Lrelu(),
        Linear('linear_2', 512, 512),
        Lrelu(),
        Linear('linear_3', 512, 1),
    ]
    for block in fc_spec:
        _t = block(_t)
    return _t