diff --git a/src/networks.py b/src/networks.py index 1444f1091..f0c374e2b 100644 --- a/src/networks.py +++ b/src/networks.py @@ -1,6 +1,19 @@ import torch import torch.nn as nn +def output_align(input, output): + """ + author: @youyuge34 (https://github.com/youyuge34) + In testing, sometimes output is several pixels less than irregular-size input, + here is to fill them + """ + if output.size()[-2:] != input.size()[-2:]: + diff_width = input.size(-1) - output.size(-1) + diff_height = input.size(-2) - output.size(-2) + m = nn.ReplicationPad2d((0, diff_width, 0, diff_height)) + output = m(output) + + return output class BaseNetwork(nn.Module): def __init__(self): @@ -78,11 +91,12 @@ def __init__(self, residual_blocks=8, init_weights=True): self.init_weights() def forward(self, x): + inpt = x x = self.encoder(x) x = self.middle(x) x = self.decoder(x) x = (torch.tanh(x) + 1) / 2 - + x = output_align(inpt, x) return x @@ -129,10 +143,12 @@ def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True) self.init_weights() def forward(self, x): + inpt = x x = self.encoder(x) x = self.middle(x) x = self.decoder(x) x = torch.sigmoid(x) + x = output_align(inpt, x) return x