-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathunet.py
420 lines (347 loc) · 14.5 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
import math
import torch
import torch.nn as nn
import numpy as np
class ConvPass(torch.nn.Module):
    """Stack of 2D convolutions, each followed by an optional activation and
    a ``BatchNorm2d`` layer.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by every convolution.
        kernel_sizes: One kernel size per convolution. Each entry is either a
            tuple such as ``(3, 3)`` or a single integer ``k``, which is
            treated as the square kernel ``(k, k)``.
        padding: ``"valid"``/``"VALID"`` (no padding) or ``"same"``/``"SAME"``
            (``kernel_size // 2`` zero padding per spatial dimension).
        activation: Name of an activation class in ``torch.nn`` (e.g.
            ``"ReLU"``), or ``None`` to skip the activation.

    Raises:
        AttributeError: If ``activation`` is not a name in ``torch.nn``.
        RuntimeError: If ``padding`` is not one of the accepted strings.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes, padding, activation):
        super(ConvPass, self).__init__()

        if activation is not None:
            activation = getattr(torch.nn, activation)

        # Validate once up front so an invalid value is reported even when
        # ``kernel_sizes`` is empty.
        if padding in ("VALID", "valid"):
            same_padding = False
        elif padding in ("SAME", "same"):
            same_padding = True
        else:
            raise RuntimeError("invalid string value for padding")

        # Number of spatial dimensions. It is overwritten below from the
        # actual kernel sizes; only Conv2d layers are built, so default to 2
        # (this keeps the attribute defined for an empty ``kernel_sizes``).
        self.dims = 2

        layers = []
        for kernel_size in kernel_sizes:
            # Accept a bare integer as a square kernel.
            if np.ndim(kernel_size) == 0:
                kernel_size = (int(kernel_size),) * 2
            else:
                kernel_size = tuple(int(k) for k in kernel_size)
            self.dims = len(kernel_size)

            # Plain Python ints (not numpy scalars) for the padding tuple.
            pad = tuple(k // 2 for k in kernel_size) if same_padding else 0

            layers.append(
                torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad)
            )
            # Every convolution after the first maps out_channels -> out_channels.
            in_channels = out_channels

            if activation is not None:
                layers.append(activation())
            layers.append(nn.BatchNorm2d(out_channels))

        self.conv_pass = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.conv_pass(x)
class Downsample(torch.nn.Module):
    """Max-pool a tensor with window and stride equal to ``downsample_factor``.

    ``forward`` refuses inputs whose trailing spatial sizes are not exact
    multiples of the corresponding factor.

    Args:
        downsample_factor: Sequence with one integer factor per spatial
            dimension, e.g. ``(2, 2)``.
    """

    def __init__(self, downsample_factor):
        super(Downsample, self).__init__()
        self.dims = len(downsample_factor)
        self.downsample_factor = downsample_factor
        self.down = torch.nn.MaxPool2d(downsample_factor, stride=downsample_factor)

    def forward(self, x):
        """Pool ``x``.

        Raises:
            RuntimeError: If a spatial size is not divisible by its factor.
                Axes are checked from the last one backwards, so the last
                mismatching axis is the one reported.
        """
        size = x.size()
        for axis in reversed(range(self.dims)):
            # Negative index of this spatial axis in both ``size`` and the factor.
            neg = axis - self.dims
            if size[neg] % self.downsample_factor[neg] != 0:
                raise RuntimeError(
                    "Can not downsample shape %s with factor %s, mismatch "
                    "in spatial dimension %d"
                    % (size, self.downsample_factor, axis)
                )
        return self.down(x)
class Upsample(torch.nn.Module):
    """Upsample a coarse feature map and concatenate it with a skip tensor.

    ``forward(f_left, g_out)`` upsamples ``g_out``, center-crops ``f_left`` to
    the upsampled spatial size, and concatenates both along dimension 1
    (``f_left`` first).

    Args:
        scale_factor: Upsampling factor per spatial dimension, e.g. ``(2, 2)``.
        mode: ``"transposed_conv"`` builds a ``ConvTranspose2d`` with kernel
            size and stride ``scale_factor``; any other value is passed as
            ``mode`` to ``torch.nn.Upsample``.
        in_channels, out_channels: Channel counts of the transposed
            convolution; only used when ``mode == "transposed_conv"``.
        crop_factor, next_conv_kernel_sizes: Arguments for
            ``crop_to_factor``. They must be given together or not at all.
            They are stored on the module, but ``forward`` does not use them.
        padding: Stored as ``self.padding``; ``forward`` does not read it.
    """

    def __init__(
        self,
        scale_factor,
        mode="nearest",
        in_channels=None,
        out_channels=None,
        crop_factor=None,
        padding="VALID",
        next_conv_kernel_sizes=None,
    ):
        super(Upsample, self).__init__()
        no_crop_factor = crop_factor is None
        no_kernel_sizes = next_conv_kernel_sizes is None
        assert no_crop_factor == no_kernel_sizes, "crop_factor and next_conv_kernel_sizes have to be given together"

        self.crop_factor = crop_factor
        self.next_conv_kernel_sizes = next_conv_kernel_sizes
        self.padding = padding
        self.dims = len(scale_factor)

        if mode != "transposed_conv":
            self.up = torch.nn.Upsample(scale_factor=tuple(scale_factor), mode=mode)
        else:
            self.up = torch.nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor
            )

    def crop_to_factor(self, x, factor, kernel_sizes):
        """Crop ``x`` so that later convolutions stay translation equivariant.

        Let ``c`` be the total shrinkage of the following convolutions along
        an axis (the sum of ``kernel - 1``) and ``k`` the factor. The spatial
        size ``s`` is reduced to the largest ``s' = n*k + c <= s``, so that
        ``s' - c`` is a multiple of ``k``. Cropping before the convolutions is
        cheaper than cropping afterwards, because the maps are smaller.

        Returns ``x`` itself when no crop is needed; otherwise a center crop.
        Raises ``AssertionError`` when the map is too small for the crop.
        """
        shape = x.size()
        spatial_shape = shape[-self.dims:]

        # Shrinkage that the following convolutions will apply anyway.
        convolution_crop = tuple(
            sum(ks[axis] - 1 for ks in kernel_sizes) for axis in range(self.dims)
        )

        target = []
        for size, shrink, step in zip(spatial_shape, convolution_crop, factor):
            n = int(math.floor(float(size - shrink) / step))
            target.append(n * step + shrink)
        target_spatial_shape = tuple(target)

        if target_spatial_shape == spatial_shape:
            return x

        large_enough = all(t > c for t, c in zip(target_spatial_shape, convolution_crop))
        assert large_enough, (
            "Feature map with shape %s is too small to ensure "
            "translation equivariance with factor %s and following "
            "convolutions %s" % (shape, factor, kernel_sizes)
        )
        return self.crop(x, target_spatial_shape)

    def crop(self, x, shape):
        """Center-crop the trailing ``self.dims`` axes of ``x`` to ``shape``."""
        full_size = x.size()
        wanted_size = full_size[: -self.dims] + shape
        window = []
        for current, wanted in zip(full_size, wanted_size):
            start = (current - wanted) // 2
            window.append(slice(start, start + wanted))
        return x[tuple(window)]

    def forward(self, f_left, g_out):
        """Upsample ``g_out``, crop ``f_left`` to match, and concatenate."""
        g_up = self.up(g_out)
        # NOTE: crop_to_factor is not applied here; the upsampled map is used
        # as-is, regardless of ``self.padding``.
        target = g_up.size()[-self.dims:]
        return torch.cat([self.crop(f_left, target), g_up], dim=1)
class UNet(torch.nn.Module):
    """2D U-Net assembled from ConvPass, Downsample and Upsample modules."""

    def __init__(
        self,
        in_channels,
        num_fmaps,
        fmap_inc_factors,
        downsample_factors,
        kernel_size_down=None,
        kernel_size_up=None,
        activation="ReLU",
        padding="VALID",
        num_fmaps_out=None,
        constant_upsample=False,
    ):
        """Create a U-Net::

            f_in --> f_left --------------------------->> f_right--> f_out
                        |                                   ^
                        v                                   |
                     g_in --> g_left ------->> g_right --> g_out
                                 |               ^
                                 v               |
                                       ...

        where each ``-->`` is a convolution pass and each ``-->>`` a crop.
        The down arrows are max-pooling. The up arrows are transposed
        convolutions, or nearest-neighbour upsampling when
        ``constant_upsample`` is set.

        The U-Net expects 2D tensors shaped like::

            ``(batch, channels, height, width)``.

        With ``padding="valid"`` (the default) the feature maps shrink after
        each convolution. With ``padding="same"`` their spatial size is kept.

        Args:
            in_channels:
                The number of input channels.
            num_fmaps:
                The number of feature maps in the first layer. This is also
                the number of output feature maps, unless ``num_fmaps_out``
                is given. Stored in the ``channels`` dimension.
            fmap_inc_factors:
                By how much to multiply the number of feature maps between
                levels. If level 0 has ``k`` feature maps, level ``l`` has
                ``k*fmap_inc_factors**l``.
            downsample_factors:
                List of tuples ``(y, x)`` used to down- and up-sample the
                feature maps between levels. The network has
                ``len(downsample_factors) + 1`` levels.
            kernel_size_down (optional):
                One list of kernel sizes per level (top level first) for the
                left side. The length of each inner list is the number of
                convolutions in that level's pass. Kernel sizes are given as
                tuples ``(y, x)``. If not given, each pass consists of two
                3x3 convolutions.
            kernel_size_up (optional):
                Like ``kernel_size_down``, but for the right side. It has one
                list per level except the bottom one. Same default.
            activation:
                Name of the activation class in ``torch.nn`` to use after
                each convolution (e.g. ``ReLU`` for ``torch.nn.ReLU``).
            padding (optional):
                How to pad convolutions. Either 'same' or 'valid' (default).
            num_fmaps_out (optional):
                Number of channels produced by the last convolution pass.
                Defaults to ``num_fmaps``.
            constant_upsample (optional):
                If set to true, perform a constant (nearest-neighbour)
                upsampling instead of a transposed convolution in the
                upsampling layers.
        """
        super(UNet, self).__init__()
        self.num_levels = len(downsample_factors) + 1
        self.in_channels = in_channels
        self.out_channels = num_fmaps_out if num_fmaps_out else num_fmaps
        self.constant_upsample = constant_upsample

        # default arguments: two 3x3 convolutions per convolution pass
        if kernel_size_down is None:
            kernel_size_down = [[(3, 3), (3, 3)]] * self.num_levels
        if kernel_size_up is None:
            kernel_size_up = [[(3, 3), (3, 3)]] * (self.num_levels - 1)
        self.kernel_size_down = kernel_size_down
        self.kernel_size_up = kernel_size_up
        self.downsample_factors = downsample_factors

        # crop factors for translation equivariance, one per upsampling level
        crop_factors = self._compute_crop_factors(downsample_factors)

        # modules (constructed in a fixed order: l_conv, l_down, r_up, r_conv)

        # left convolutional passes
        self.l_conv = nn.ModuleList(
            [
                ConvPass(
                    in_channels
                    if level == 0
                    else num_fmaps * fmap_inc_factors ** (level - 1),
                    num_fmaps * fmap_inc_factors**level,
                    kernel_size_down[level],
                    padding,
                    activation=activation,
                )
                for level in range(self.num_levels)
            ]
        )
        self.dims = self.l_conv[0].dims

        # left downsample layers
        self.l_down = nn.ModuleList(
            [
                Downsample(downsample_factors[level])
                for level in range(self.num_levels - 1)
            ]
        )

        # right up/crop/concatenate layers
        self.r_up = nn.ModuleList(
            [
                Upsample(
                    downsample_factors[level],
                    mode="nearest" if constant_upsample else "transposed_conv",
                    in_channels=num_fmaps * fmap_inc_factors ** (level + 1),
                    out_channels=num_fmaps * fmap_inc_factors ** (level + 1),
                    crop_factor=crop_factors[level],
                    padding=padding,
                    next_conv_kernel_sizes=kernel_size_up[level],
                )
                for level in range(self.num_levels - 1)
            ]
        )

        # right convolutional passes; input is the skip connection of this
        # level concatenated with the upsampled maps of the level below
        self.r_conv = nn.ModuleList(
            [
                ConvPass(
                    num_fmaps * fmap_inc_factors**level
                    + num_fmaps * fmap_inc_factors ** (level + 1),
                    num_fmaps * fmap_inc_factors**level
                    if num_fmaps_out is None or level != 0
                    else num_fmaps_out,
                    kernel_size_up[level],
                    padding,
                    activation=activation,
                )
                for level in range(self.num_levels - 1)
            ]
        )

    @staticmethod
    def _compute_crop_factors(downsample_factors):
        """Return, per upsampling level, the product of the downsample factors
        from that level down to the bottom of the U-Net.

        ``result[level]`` belongs to ``r_up[level]``. Level 0 (top) gets the
        product of all factors; the last level gets only the last factor.
        """
        crop_factors = []
        factor_product = None
        # accumulate starting from the bottom-most (last) factor
        for factor in downsample_factors[::-1]:
            if factor_product is None:
                factor_product = list(factor)
            else:
                factor_product = [f * ff for f, ff in zip(factor, factor_product)]
            crop_factors.append(factor_product)
        return crop_factors[::-1]

    def rec_fov(self, level, fov, sp):
        """Recursively accumulate the field of view of the network.

        Args:
            level: Level to process; ``num_levels - 1`` is the top level and
                ``0`` the bottom one.
            fov: Field of view accumulated so far, per spatial dimension.
                It is not modified in place.
            sp: Current step between neighbouring feature-map pixels, in
                input pixels (the product of the downsample factors above).

        Returns:
            ``(fov, sp)`` after this level's left convolutions, the nested
            levels, and this level's right convolutions.
        """
        # index of level in layer arrays
        i = self.num_levels - level - 1

        # convolve (left side)
        for kernel_size in self.kernel_size_down[i]:
            fov = fov + (np.array(kernel_size) - 1) * sp

        # end of recursion
        if level != 0:
            factor = np.array(self.downsample_factors[i])
            # down: max-pooling widens the view and increases the step
            fov = fov + (factor - 1) * sp
            sp = sp * factor

            # nested levels
            fov, sp = self.rec_fov(level - 1, fov, sp)

            # up: back to this level's step
            sp = sp // factor

            # convolve (right side)
            for kernel_size in self.kernel_size_up[i]:
                fov = fov + (np.array(kernel_size) - 1) * sp

        return fov, sp

    def get_fov(self):
        """Return the field of view of the whole network, per spatial dim."""
        fov, sp = self.rec_fov(self.num_levels - 1, (1, 1), 1)
        return fov

    def rec_forward(self, level, f_in):
        """Run level ``level`` (``num_levels - 1`` is the top) on ``f_in``."""
        # index of level in layer arrays
        i = self.num_levels - level - 1

        # convolve
        f_left = self.l_conv[i](f_in)

        # end of recursion
        if level == 0:
            fs_out = f_left
        else:
            # down
            g_in = self.l_down[i](f_left)

            # nested levels
            gs_out = self.rec_forward(level - 1, g_in)

            # up, concat, and crop
            fs_right = self.r_up[i](f_left, gs_out)

            # convolve
            fs_out = self.r_conv[i](fs_right)

        return fs_out

    def forward(self, x):
        """Apply the full U-Net to ``x``."""
        y = self.rec_forward(self.num_levels - 1, x)
        return y

    def save_weight_histogram(self, tb_writer, epoch):
        """Log weight/bias histograms of all convolutions via ``tb_writer``.

        ``tb_writer`` must provide ``add_histogram(tag, values, step)``
        (presumably a TensorBoard ``SummaryWriter`` -- confirm with caller).
        The ``conv{lidx}`` index in the tags is the position inside the
        ConvPass ``Sequential``, which also counts activation and batch-norm
        layers. Transposed convolutions are only logged when
        ``constant_upsample`` is false.
        """
        for bidx, block in enumerate(self.l_conv):
            for lidx, layer in enumerate(block.conv_pass):
                if not isinstance(layer, nn.Conv2d):
                    continue
                tb_writer.add_histogram(
                    "enc_block{}_conv{}.bias".format(bidx, lidx), layer.bias, epoch
                )
                tb_writer.add_histogram(
                    "enc_block{}_conv{}.weights".format(bidx, lidx), layer.weight, epoch
                )

        if not self.constant_upsample:
            for bidx, block in enumerate(self.r_up):
                tb_writer.add_histogram(
                    "dec_block{}_transp_conv.bias".format(bidx), block.up.bias, epoch
                )
                tb_writer.add_histogram(
                    "dec_block{}_transp_conv.weights".format(bidx),
                    block.up.weight,
                    epoch,
                )

        for bidx, block in enumerate(self.r_conv):
            for lidx, layer in enumerate(block.conv_pass):
                if not isinstance(layer, nn.Conv2d):
                    continue
                tb_writer.add_histogram(
                    "dec_block{}_conv{}.bias".format(bidx, lidx), layer.bias, epoch
                )
                tb_writer.add_histogram(
                    "dec_block{}_conv{}.weights".format(bidx, lidx), layer.weight, epoch
                )