From 4dea21b8e3182b72a814b95199ff79048d6f1174 Mon Sep 17 00:00:00 2001
From: Krishna Murthy
Date: Wed, 15 Apr 2020 07:42:24 -0400
Subject: [PATCH] Lint, isort

Signed-off-by: Krishna Murthy
---
 cache_dataset.py     | 75 ++++++++++++++++++++++++++++++++++++--
 eval_nerf.py         |  8 ++---
 nerf/__init__.py     |  1 -
 nerf/nerf_helpers.py |  3 +-
 tiny_nerf.py         |  8 ++---
 train_nerf.py        | 19 ++++-------
 6 files changed, 74 insertions(+), 40 deletions(-)

diff --git a/cache_dataset.py b/cache_dataset.py
index 09bf393..5217755 100644
--- a/cache_dataset.py
+++ b/cache_dataset.py
@@ -10,20 +10,50 @@
 import torch
 from tqdm import tqdm
 
-from nerf import load_blender_data
-from nerf import load_llff_data
-from nerf import get_ray_bundle, meshgrid_xy
+from nerf import get_ray_bundle, load_blender_data, load_llff_data, meshgrid_xy
 
 
 def cache_nerf_dataset(args):
-    images, poses, render_poses, hwf, i_split = load_blender_data(
-        args.datapath, half_res=args.halfres, testskip=args.stride
-    )
-    i_train, i_val, i_test = i_split
-    H, W, focal = hwf
-    H, W = int(H), int(W)
-    hwf = [H, W, focal]
+    images, poses, render_poses, hwf = (
+        None,
+        None,
+        None,
+        None,
+    )
+    i_train, i_val, i_test = None, None, None
+
+    if args.type == "blender":
+        images, poses, render_poses, hwf, i_split = load_blender_data(
+            args.datapath, half_res=args.blender_half_res, testskip=args.blender_stride
+        )
+
+        i_train, i_val, i_test = i_split
+        H, W, focal = hwf
+        H, W = int(H), int(W)
+        hwf = [H, W, focal]
+    elif args.type == "llff":
+        images, poses, bds, render_poses, i_test = load_llff_data(
+            args.datapath, factor=args.llff_downsample_factor
+        )
+        hwf = poses[0, :3, -1]
+        poses = poses[:, :3, :4]
+        if not isinstance(i_test, list):
+            i_test = [i_test]
+        if args.llffhold > 0:
+            i_test = np.arange(images.shape[0])[:: args.llffhold]
+        i_val = i_test
+        i_train = np.array(
+            [
+                i
+                for i in np.arange(images.shape[0])
+                if (i not in i_test and i not in i_val)
+            ]
+        )
+        H, W, focal = hwf
+        hwf = [int(H), int(W), focal]
+        images = torch.from_numpy(images)
+        poses = torch.from_numpy(poses)
 
     # Device on which to run.
     if torch.cuda.is_available():
@@ -114,19 +144,38 @@ def cache_nerf_dataset(args):
         help="Path to root dir of dataset that needs caching.",
     )
     parser.add_argument(
-        "--halfres",
+        "--type",
+        type=str.lower,
+        required=True,
+        choices=["blender", "llff"],
+        help="Type of the dataset to be cached.",
+    )
+    parser.add_argument(
+        "--blender-half-res",
         type=bool,
         default=True,
         help="Whether to load the (Blender/synthetic) datasets"
         "at half the resolution.",
     )
     parser.add_argument(
-        "--stride",
+        "--blender-stride",
         type=int,
         default=1,
-        help="Stride length. When set to k (k > 1), it samples"
+        help="Stride length (Blender datasets only). When set to k (k > 1), it samples "
         "every kth sample from the dataset.",
     )
+    parser.add_argument(
+        "--llff-downsample-factor",
+        type=int,
+        default=8,
+        help="Downsample factor for images from the LLFF dataset.",
+    )
+    parser.add_argument(
+        "--llffhold",
+        type=int,
+        default=8,
+        help="Determines the hold-out images for LLFF (TODO: make better).",
+    )
     parser.add_argument(
         "--savedir", type=str, required=True, help="Path to save the cached dataset to."
     )
diff --git a/eval_nerf.py b/eval_nerf.py
index 4ad3158..f53f2f9 100644
--- a/eval_nerf.py
+++ b/eval_nerf.py
@@ -9,12 +9,8 @@
 import yaml
 from tqdm import tqdm
 
-from nerf import models
-from nerf import CfgNode
-from nerf import load_blender_data
-from nerf import load_llff_data
-from nerf import get_ray_bundle, positional_encoding
-from nerf import run_one_iter_of_nerf
+from nerf import (CfgNode, get_ray_bundle, load_blender_data, load_llff_data,
+                  models, positional_encoding, run_one_iter_of_nerf)
 
 
 def cast_to_image(tensor):
diff --git a/nerf/__init__.py b/nerf/__init__.py
index b5e2e39..ef9f0f6 100644
--- a/nerf/__init__.py
+++ b/nerf/__init__.py
@@ -5,4 +5,3 @@
 from .nerf_helpers import *
 from .train_utils import *
 from .volume_rendering_utils import *
-
diff --git a/nerf/nerf_helpers.py b/nerf/nerf_helpers.py
index 181004e..24ca0c2 100644
--- a/nerf/nerf_helpers.py
+++ b/nerf/nerf_helpers.py
@@ -1,7 +1,8 @@
+import math
 from typing import Optional
 
-import math
 import torch
+
 import torchsearchsorted
 
diff --git a/tiny_nerf.py b/tiny_nerf.py
index 41142c1..c088ec6 100644
--- a/tiny_nerf.py
+++ b/tiny_nerf.py
@@ -6,12 +6,8 @@
 import torch
 from tqdm import tqdm, trange
 
-from nerf import (
-    cumprod_exclusive,
-    get_minibatches,
-    get_ray_bundle,
-    positional_encoding,
-)
+from nerf import (cumprod_exclusive, get_minibatches, get_ray_bundle,
+                  positional_encoding)
 
 
 def compute_query_points_from_rays(
diff --git a/train_nerf.py b/train_nerf.py
index c870f51..c721522 100644
--- a/train_nerf.py
+++ b/train_nerf.py
@@ -10,18 +10,9 @@
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm, trange
 
-from nerf import models
-from nerf import CfgNode
-from nerf import load_blender_data
-from nerf import load_llff_data
-from nerf import (
-    get_ray_bundle,
-    img2mse,
-    meshgrid_xy,
-    mse2psnr,
-    positional_encoding,
-)
-from nerf import run_one_iter_of_nerf
+from nerf import (CfgNode, get_ray_bundle, img2mse, load_blender_data,
+                  load_llff_data, meshgrid_xy, models, mse2psnr,
+                  positional_encoding, run_one_iter_of_nerf)
 
 
 def main():
@@ -276,7 +267,9 @@ def encode_direction_fn(x):
 
         # Learning rate updates
         num_decay_steps = cfg.scheduler.lr_decay * 1000
-        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor ** (i / num_decay_steps))
+        lr_new = cfg.optimizer.lr * (
+            cfg.scheduler.lr_decay_factor ** (i / num_decay_steps)
+        )
         for param_group in optimizer.param_groups:
             param_group["lr"] = lr_new
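
Note (not part of the diff above): the LLFF branch added to cache_nerf_dataset()
holds out every --llffhold-th view as test/val and trains on the rest. A minimal
standalone sketch of that split, assuming a hypothetical dataset of 20 images and
the default --llffhold of 8; the image count and index values are illustrative only:

    import numpy as np

    num_images, llffhold = 20, 8  # hypothetical image count; llffhold default taken from the patch
    i_test = np.arange(num_images)[::llffhold]  # every llffhold-th index -> [0, 8, 16]
    i_val = i_test  # the held-out views double as the validation set
    i_train = np.array(
        [i for i in np.arange(num_images) if i not in i_test and i not in i_val]
    )
    # i_train holds the remaining 17 indices; this mirrors the
    # np.arange(images.shape[0])[:: args.llffhold] logic in the cache_dataset.py hunk.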