diff --git a/eval_nerf.py b/eval_nerf.py index ce1738f..cffc398 100644 --- a/eval_nerf.py +++ b/eval_nerf.py @@ -23,6 +23,12 @@ def cast_to_image(tensor, dataset_type): # return np.moveaxis(img, [-1], [0]) +def cast_to_disparity_image(tensor): + img = (tensor - tensor.min()) / (tensor.max() - tensor.min()) + img = img.clamp(0, 1) * 255 + return img.detach().cpu().numpy().astype(np.uint8) + + def main(): parser = argparse.ArgumentParser() @@ -38,6 +44,9 @@ def main(): parser.add_argument( "--savedir", type=str, help="Save images to this directory, if specified." ) + parser.add_argument( + "--save-disparity-image", action="store_true", help="Save disparity images too." + ) configargs = parser.parse_args() # Read config file. @@ -134,16 +143,19 @@ def main(): # Create directory to save images to. os.makedirs(configargs.savedir, exist_ok=True) + if configargs.save_disparity_image: + os.makedirs(os.path.join(configargs.savedir, "disparity"), exist_ok=True) # Evaluation loop times_per_image = [] for i, pose in enumerate(tqdm(render_poses)): start = time.time() rgb = None, None + disp = None, None with torch.no_grad(): pose = pose[:3, :4] ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose) - rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf( + rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _ = run_one_iter_of_nerf( hwf[0], hwf[1], hwf[2], @@ -157,10 +169,15 @@ def main(): encode_direction_fn=encode_direction_fn, ) rgb = rgb_fine if rgb_fine is not None else rgb_coarse + if configargs.save_disparity_image: + disp = disp_fine if disp_fine is not None else disp_coarse times_per_image.append(time.time() - start) if configargs.savedir: savefile = os.path.join(configargs.savedir, f"{i:04d}.png") imageio.imwrite(savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower())) + if configargs.save_disparity_image: + savefile = os.path.join(configargs.savedir, "disparity", f"{i:04d}.png") + imageio.imwrite(savefile, cast_to_disparity_image(disp)) tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}") diff --git a/train_nerf.py b/train_nerf.py index 9c899d9..8ebb62d 100644 --- a/train_nerf.py +++ b/train_nerf.py @@ -250,6 +250,11 @@ def main(): rgb_fine[..., :3], target_ray_values[..., :3] ) # loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3]) + loss = 0. + # if fine_loss is not None: + # loss = fine_loss + # else: + # loss = coarse_loss loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0) loss.backward() psnr = mse2psnr(loss.item()) @@ -332,9 +337,12 @@ def main(): ) target_ray_values = img_target coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3]) - fine_loss = 0.0 + loss, fine_loss = 0., 0. if rgb_fine is not None: fine_loss = img2mse(rgb_fine[..., :3], target_ray_values[..., :3]) + loss = fine_loss + else: + loss = coarse_loss loss = coarse_loss + fine_loss psnr = mse2psnr(loss.item()) writer.add_scalar("validation/loss", loss.item(), i)