Commit 77badff
Update models
nmhkahn committed Jul 25, 2018
1 parent 1fb9229 commit 77badff
Showing 3 changed files with 16 additions and 20 deletions.
carn/sample.py: 36 changes (16 additions, 20 deletions)
@@ -16,7 +16,6 @@ def parse_args():
     parser.add_argument("--model", type=str)
     parser.add_argument("--ckpt_path", type=str)
     parser.add_argument("--group", type=int, default=1)
-    parser.add_argument("--reduce_upsample", action="store_true", default=False)
     parser.add_argument("--sample_dir", type=str)
     parser.add_argument("--test_data_dir", type=str, default="dataset/Urban100")
     parser.add_argument("--cuda", action="store_true")
@@ -33,7 +32,7 @@ def save_image(tensor, filename):
     im.save(filename)


-def sample(net, dataset, cfg):
+def sample(net, device, dataset, cfg):
     scale = cfg.scale
     for step, (hr, lr, name) in enumerate(dataset):
         if "DIV2K" in dataset.name:
@@ -42,19 +41,19 @@ def sample(net, dataset, cfg):
             h_half, w_half = int(h/2), int(w/2)
             h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave

-            lr_patch = torch.FloatTensor(4, 3, h_chop, w_chop)
+            lr_patch = torch.tensor((4, 3, h_chop, w_chop), dtype=torch.float)
             lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
             lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
             lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
             lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
-            lr_patch = Variable(lr_patch, volatile=True).cuda()
+            lr_patch = lr_patch.to(device)

-            sr = net(lr_patch, cfg.scale).data
+            sr = net(lr_patch, cfg.scale).detach()

             h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale
             w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale

-            result = torch.FloatTensor(3, h, w).cuda()
+            result = torch.tensor((3, h, w), dtype=torch.float).to(device)
             result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half])
             result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop])
             result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half])
@@ -63,9 +62,9 @@ def sample(net, dataset, cfg):
                 t2 = time.time()
             else:
                 t1 = time.time()
-                lr = Variable(lr.unsqueeze(0), volatile=True).cuda()
-                sr = net(lr, cfg.scale).data[0]
-                lr = lr.data[0]
+                lr = lr.unsqueeze(0).to(device)
+                sr = net(lr, cfg.scale).detach().squeeze(0)
+                lr = lr.squeeze(0)
                 t2 = time.time()

             model_name = cfg.ckpt_path.split(".")[0].split("/")[-1]
@@ -80,41 +79,38 @@ def sample(net, dataset, cfg):
                               "x{}".format(cfg.scale),
                               "HR")

-        if not os.path.exists(sr_dir):
-            os.makedirs(sr_dir)
-
-        if not os.path.exists(hr_dir):
-            os.makedirs(hr_dir)
+        os.makedirs(sr_dir, exist_ok=True)
+        os.makedirs(hr_dir, exist_ok=True)

         sr_im_path = os.path.join(sr_dir, "{}".format(name.replace("HR", "SR")))
         hr_im_path = os.path.join(hr_dir, "{}".format(name))

         save_image(sr, sr_im_path)
-        # save_image(hr, hr_im_path)
+        save_image(hr, hr_im_path)
         print("Saved {} ({}x{} -> {}x{}, {:.3f}s)"
             .format(sr_im_path, lr.shape[1], lr.shape[2], sr.shape[1], sr.shape[2], t2-t1))


 def main(cfg):
     module = importlib.import_module("model.{}".format(cfg.model))
     net = module.Net(multi_scale=True,
-                     group=cfg.group,
-                     reduce_upsample=cfg.reduce_upsample)
+                     group=cfg.group)
     print(json.dumps(vars(cfg), indent=4, sort_keys=True))

     state_dict = torch.load(cfg.ckpt_path)
     new_state_dict = OrderedDict()
     for k, v in state_dict.items():
         print(k)
         name = k
         # name = k[7:] # remove "module."
         new_state_dict[name] = v

     net.load_state_dict(new_state_dict)
-    net.cuda()
-
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    net = net.to(device)
+
     dataset = TestDataset(cfg.test_data_dir, cfg.scale)
-    sample(net, dataset, cfg)
+    sample(net, device, dataset, cfg)


 if __name__ == "__main__":
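The Variable / volatile=True calls removed here were deprecated in PyTorch 0.4, which this commit migrates toward with .to(device) and .detach(). Two details are worth flagging. First, calling .detach() on the output only detaches the result; wrapping the forward pass in torch.no_grad() is the 0.4 replacement for volatile=True and also avoids building the autograd graph. Second, torch.tensor((4, 3, h_chop, w_chop), dtype=torch.float) constructs a 1-D tensor holding the four values 4, 3, h_chop and w_chop, whereas the old torch.FloatTensor(4, 3, h_chop, w_chop) allocated an uninitialized tensor of that shape; torch.empty((4, 3, h_chop, w_chop)) keeps the original behaviour, and the same applies to the result buffer. A minimal sketch of that pattern follows; the helper name run_patch is hypothetical and not part of carn/sample.py.

import torch

def run_patch(net, lr, scale, shave, device):
    # Hypothetical helper mirroring the DIV2K branch of sample().
    _, h, w = lr.shape
    h_half, w_half = h // 2, w // 2
    h_chop, w_chop = h_half + shave, w_half + shave

    # torch.empty allocates by *shape*; torch.tensor((4, 3, ...)) would instead
    # build a 1-D tensor holding those four numbers.
    lr_patch = torch.empty(4, 3, h_chop, w_chop)
    lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop])
    lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w])
    lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop])
    lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w])
    lr_patch = lr_patch.to(device)

    # torch.no_grad() is the PyTorch >= 0.4 replacement for volatile=True.
    with torch.no_grad():
        sr = net(lr_patch, scale)
    return sr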
Binary file modified: checkpoint/carn.pth (not shown)
Binary file modified: checkpoint/carn_m.pth (not shown)
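One related caveat: main() now falls back to CPU when CUDA is unavailable, but if the modified carn.pth / carn_m.pth checkpoints store CUDA tensors, torch.load without a map_location will fail on a CPU-only machine. A possible CPU-safe loading sketch, offered as a suggestion rather than code from this commit:

import torch
from collections import OrderedDict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps stored tensors onto the chosen device, so the load works
# whether or not the checkpoint was saved from GPU tensors.
state_dict = torch.load("checkpoint/carn.pth", map_location=device)

# Strip an optional "module." prefix left by nn.DataParallel, mirroring the
# remapping that main() keeps commented out.
new_state_dict = OrderedDict(
    (k[7:] if k.startswith("module.") else k, v) for k, v in state_dict.items()
)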
