Skip to content

Commit

Permalink
Back to original mode
Browse files Browse the repository at this point in the history
  • Loading branch information
nmhkahn committed Jul 14, 2018
1 parent 05858bc commit 1fb9229
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
6 changes: 3 additions & 3 deletions carn/model/carn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, **kwargs):

scale = kwargs.get("scale")
multi_scale = kwargs.get("multi_scale")
reduce_upsample = kwargs.get("reduce_upsample", False)
group = kwargs.get("group", 1)

self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
Expand All @@ -54,8 +54,8 @@ def __init__(self, **kwargs):
self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0)

self.upsample = ops.UpsampleBlock(64, scale=scale,
multi_scale=multi_scale,
reduce=reduce_upsample)
multi_scale=multi_scale,
group=group)
self.exit = nn.Conv2d(64, 3, 3, 1, 1)

def forward(self, x, scale):
Expand Down
15 changes: 8 additions & 7 deletions carn/model/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,21 +120,22 @@ def forward(self, x, scale):
else:
return self.up(x)


class _UpsampleBlock(nn.Module):
def __init__(self,
n_channels, scale,
group=1, act=nn.ReLU(inplace=True)):
n_channels, scale,
group=1, act=nn.ReLU(inplace=True)):
super(_UpsampleBlock, self).__init__()

modules = []
if scale == 2 or scale == 4 or scale == 8:
for _ in range(int(math.log(scale, 2))):
modules += [nn.Upsample(scale_factor=2)]
modules += [nn.Conv2d(n_channels, n_channels, 3, 1, 1, groups=group), act]
modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), act]
modules += [nn.PixelShuffle(2)]
elif scale == 3:
modules += [nn.Upsample(scale_factor=3)]
modules += [nn.Conv2d(n_channels, n_channels, 3, 1, 1, groups=group), act]
modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), act]
modules += [nn.PixelShuffle(3)]

self.body = nn.Sequential(*modules)
init_weights(self.modules)

Expand Down

0 comments on commit 1fb9229

Please sign in to comment.