Skip to content

Commit

Permalink
add test_selfensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Sep 15, 2022
1 parent 88647f2 commit ce5c55a
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions basicsr/models/sr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,54 @@ def test(self):
self.output = self.net_g(self.lq)
self.net_g.train()

def test_selfensemble(self):
# TODO: to be tested
# 8 augmentations
# modified from https://github.com/thstkdgus35/EDSR-PyTorch

def _transform(v, op):
# if self.precision != 'single': v = v.float()
v2np = v.data.cpu().numpy()
if op == 'v':
tfnp = v2np[:, :, :, ::-1].copy()
elif op == 'h':
tfnp = v2np[:, :, ::-1, :].copy()
elif op == 't':
tfnp = v2np.transpose((0, 1, 3, 2)).copy()

ret = torch.Tensor(tfnp).to(self.device)
# if self.precision == 'half': ret = ret.half()

return ret

# prepare augmented data
lq_list = [self.lq]
for tf in 'v', 'h', 't':
lq_list.extend([_transform(t, tf) for t in lq_list])

# inference
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
with torch.no_grad():
out_list = [self.net_g_ema(aug) for aug in lq_list]
else:
self.net_g.eval()
with torch.no_grad():
out_list = [self.net_g_ema(aug) for aug in lq_list]
self.net_g.train()

# merge results
for i in range(len(out_list)):
if i > 3:
out_list[i] = _transform(out_list[i], 't')
if i % 4 > 1:
out_list[i] = _transform(out_list[i], 'h')
if (i % 4) % 2 == 1:
out_list[i] = _transform(out_list[i], 'v')
output = torch.cat(out_list, dim=0)

self.output = output.mean(dim=0, keepdim=True)

def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
if self.opt['rank'] == 0:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
Expand Down

0 comments on commit ce5c55a

Please sign in to comment.