Skip to content

Commit

Permalink
feat(edsr): add EDSR model
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnec committed Jul 26, 2023
1 parent af26a22 commit cda51fc
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 0 deletions.
Empty file added models/__init__.py
Empty file.
126 changes: 126 additions & 0 deletions models/edsr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# https://github.com/sanghyun-son/EDSR-PyTorch

import torch.nn as nn
import math


def default_conv(in_channels, out_channels, kernel_size, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2), bias=bias)


class EDSR(nn.Module):
def __init__(self, args, conv=default_conv):
super(EDSR, self).__init__()

n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU(True)

# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]

# define body module
m_body = [
ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))

# define tail module
m_tail = [
Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)
]

self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)

def forward(self, x):

x = self.head(x)

res = self.body(x)
res += x

x = self.tail(res)

return x

def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))


class ResBlock(nn.Module):
def __init__(
self, conv, n_feats, kernel_size,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if i == 0:
m.append(act)

self.body = nn.Sequential(*m)
self.res_scale = res_scale

def forward(self, x):
res = self.body(x).mul(self.res_scale)
res += x

return res

class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias))
m.append(nn.PixelShuffle(2))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
elif act == 'prelu':
m.append(nn.PReLU(n_feats))

elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias))
m.append(nn.PixelShuffle(3))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
else:
raise NotImplementedError

super(Upsampler, self).__init__(*m)

13 changes: 13 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,19 @@ def train_model(
elif cfg.model == 'mspcn':
# mSPCN model
model = Net(upscale_factor=cfg.upscale_factor, in_channels=in_channels)
elif cfg.model == 'edsr':
# EDSR model
from models.edsr import EDSR
class Args:
pass
args = Args()
args.n_feats = 64
args.n_resblocks = 16
args.n_colors = 2 if cfg.input_type == 'rf' and cfg.rescale_factor == 1 else 1
args.rgb_range = 1
args.scale = (cfg.upscale_factor, cfg.upscale_factor)
args.res_scale = 1
model = EDSR(args)
else:
raise Exception('Model name not recognized')

Expand Down

0 comments on commit cda51fc

Please sign in to comment.