diff --git a/carn/sample.py b/carn/sample.py index a6a94ac..748177c 100644 --- a/carn/sample.py +++ b/carn/sample.py @@ -16,7 +16,6 @@ def parse_args(): parser.add_argument("--model", type=str) parser.add_argument("--ckpt_path", type=str) parser.add_argument("--group", type=int, default=1) - parser.add_argument("--reduce_upsample", action="store_true", default=False) parser.add_argument("--sample_dir", type=str) parser.add_argument("--test_data_dir", type=str, default="dataset/Urban100") parser.add_argument("--cuda", action="store_true") @@ -33,7 +32,7 @@ def save_image(tensor, filename): im.save(filename) -def sample(net, dataset, cfg): +def sample(net, device, dataset, cfg): scale = cfg.scale for step, (hr, lr, name) in enumerate(dataset): if "DIV2K" in dataset.name: @@ -42,19 +41,19 @@ def sample(net, dataset, cfg): h_half, w_half = int(h/2), int(w/2) h_chop, w_chop = h_half + cfg.shave, w_half + cfg.shave - lr_patch = torch.FloatTensor(4, 3, h_chop, w_chop) + lr_patch = torch.tensor((4, 3, h_chop, w_chop), dtype=torch.float) lr_patch[0].copy_(lr[:, 0:h_chop, 0:w_chop]) lr_patch[1].copy_(lr[:, 0:h_chop, w-w_chop:w]) lr_patch[2].copy_(lr[:, h-h_chop:h, 0:w_chop]) lr_patch[3].copy_(lr[:, h-h_chop:h, w-w_chop:w]) - lr_patch = Variable(lr_patch, volatile=True).cuda() + lr_patch = lr_patch.to(device) - sr = net(lr_patch, cfg.scale).data + sr = net(lr_patch, cfg.scale).detach() h, h_half, h_chop = h*scale, h_half*scale, h_chop*scale w, w_half, w_chop = w*scale, w_half*scale, w_chop*scale - result = torch.FloatTensor(3, h, w).cuda() + result = torch.tensor((3, h, w), dtype=torch.float).to(device) result[:, 0:h_half, 0:w_half].copy_(sr[0, :, 0:h_half, 0:w_half]) result[:, 0:h_half, w_half:w].copy_(sr[1, :, 0:h_half, w_chop-w+w_half:w_chop]) result[:, h_half:h, 0:w_half].copy_(sr[2, :, h_chop-h+h_half:h_chop, 0:w_half]) @@ -63,9 +62,9 @@ def sample(net, dataset, cfg): t2 = time.time() else: t1 = time.time() - lr = Variable(lr.unsqueeze(0), volatile=True).cuda() - sr = net(lr, cfg.scale).data[0] - lr = lr.data[0] + lr = lr.unsqueeze(0).to(device) + sr = net(lr, cfg.scale).detach().squeeze(0) + lr = lr.squeeze(0) t2 = time.time() model_name = cfg.ckpt_path.split(".")[0].split("/")[-1] @@ -80,17 +79,14 @@ def sample(net, dataset, cfg): "x{}".format(cfg.scale), "HR") - if not os.path.exists(sr_dir): - os.makedirs(sr_dir) - - if not os.path.exists(hr_dir): - os.makedirs(hr_dir) + os.makedirs(sr_dir, exist_ok=True) + os.makedirs(hr_dir, exist_ok=True) sr_im_path = os.path.join(sr_dir, "{}".format(name.replace("HR", "SR"))) hr_im_path = os.path.join(hr_dir, "{}".format(name)) save_image(sr, sr_im_path) - # save_image(hr, hr_im_path) + save_image(hr, hr_im_path) print("Saved {} ({}x{} -> {}x{}, {:.3f}s)" .format(sr_im_path, lr.shape[1], lr.shape[2], sr.shape[1], sr.shape[2], t2-t1)) @@ -98,23 +94,23 @@ def sample(net, dataset, cfg): def main(cfg): module = importlib.import_module("model.{}".format(cfg.model)) net = module.Net(multi_scale=True, - group=cfg.group, - reduce_upsample=cfg.reduce_upsample) + group=cfg.group) print(json.dumps(vars(cfg), indent=4, sort_keys=True)) state_dict = torch.load(cfg.ckpt_path) new_state_dict = OrderedDict() for k, v in state_dict.items(): - print(k) name = k # name = k[7:] # remove "module." new_state_dict[name] = v net.load_state_dict(new_state_dict) - net.cuda() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net = net.to(device) dataset = TestDataset(cfg.test_data_dir, cfg.scale) - sample(net, dataset, cfg) + sample(net, device, dataset, cfg) if __name__ == "__main__": diff --git a/checkpoint/carn.pth b/checkpoint/carn.pth index 2ca3dc9..2b9fb1e 100644 Binary files a/checkpoint/carn.pth and b/checkpoint/carn.pth differ diff --git a/checkpoint/carn_m.pth b/checkpoint/carn_m.pth index 2c699d2..8d53d10 100644 Binary files a/checkpoint/carn_m.pth and b/checkpoint/carn_m.pth differ