Skip to content
This repository has been archived by the owner on Jul 30, 2024. It is now read-only.

Commit

Permalink
Add disparity image generation
Browse files Browse the repository at this point in the history
Signed-off-by: Krishna Murthy <[email protected]>
  • Loading branch information
Krishna Murthy committed Apr 17, 2020
1 parent 417b92c commit 317b14d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
19 changes: 18 additions & 1 deletion eval_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def cast_to_image(tensor, dataset_type):
# return np.moveaxis(img, [-1], [0])


def cast_to_disparity_image(tensor):
img = (tensor - tensor.min()) / (tensor.max() - tensor.min())
img = img.clamp(0, 1) * 255
return img.detach().cpu().numpy().astype(np.uint8)


def main():

parser = argparse.ArgumentParser()
Expand All @@ -38,6 +44,9 @@ def main():
parser.add_argument(
"--savedir", type=str, help="Save images to this directory, if specified."
)
parser.add_argument(
"--save-disparity-image", action="store_true", help="Save disparity images too."
)
configargs = parser.parse_args()

# Read config file.
Expand Down Expand Up @@ -134,16 +143,19 @@ def main():

# Create directory to save images to.
os.makedirs(configargs.savedir, exist_ok=True)
if configargs.save_disparity_image:
os.makedirs(os.path.join(configargs.savedir, "disparity"), exist_ok=True)

# Evaluation loop
times_per_image = []
for i, pose in enumerate(tqdm(render_poses)):
start = time.time()
rgb = None, None
disp = None, None
with torch.no_grad():
pose = pose[:3, :4]
ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose)
rgb_coarse, _, _, rgb_fine, _, _ = run_one_iter_of_nerf(
rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _ = run_one_iter_of_nerf(
hwf[0],
hwf[1],
hwf[2],
Expand All @@ -157,10 +169,15 @@ def main():
encode_direction_fn=encode_direction_fn,
)
rgb = rgb_fine if rgb_fine is not None else rgb_coarse
if configargs.save_disparity_image:
disp = disp_fine if disp_fine is not None else disp_coarse
times_per_image.append(time.time() - start)
if configargs.savedir:
savefile = os.path.join(configargs.savedir, f"{i:04d}.png")
imageio.imwrite(savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower()))
if configargs.save_disparity_image:
savefile = os.path.join(configargs.savedir, "disparity", f"{i:04d}.png")
imageio.imwrite(savefile, cast_to_disparity_image(disp))
tqdm.write(f"Avg time per image: {sum(times_per_image) / (i + 1)}")


Expand Down
10 changes: 9 additions & 1 deletion train_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ def main():
rgb_fine[..., :3], target_ray_values[..., :3]
)
# loss = torch.nn.functional.mse_loss(rgb_pred[..., :3], target_s[..., :3])
loss = 0.
# if fine_loss is not None:
# loss = fine_loss
# else:
# loss = coarse_loss
loss = coarse_loss + (fine_loss if fine_loss is not None else 0.0)
loss.backward()
psnr = mse2psnr(loss.item())
Expand Down Expand Up @@ -332,9 +337,12 @@ def main():
)
target_ray_values = img_target
coarse_loss = img2mse(rgb_coarse[..., :3], target_ray_values[..., :3])
fine_loss = 0.0
loss, fine_loss = 0., 0.
if rgb_fine is not None:
fine_loss = img2mse(rgb_fine[..., :3], target_ray_values[..., :3])
loss = fine_loss
else:
loss = coarse_loss
loss = coarse_loss + fine_loss
psnr = mse2psnr(loss.item())
writer.add_scalar("validation/loss", loss.item(), i)
Expand Down

0 comments on commit 317b14d

Please sign in to comment.