This repository has been archived by the owner on Jul 30, 2024. It is now read-only.

Commit 4dea21b
Lint, isort
Signed-off-by: Krishna Murthy <[email protected]>
Krishna Murthy committed Apr 15, 2020
1 parent deae4a8 commit 4dea21b
Showing 6 changed files with 74 additions and 40 deletions.
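
The import rewrites in the diffs below are consistent with isort's (then-)default grid-wrapped output style, paired with a lint pass. The exact commands are not recorded in this commit; a plausible invocation, assumed here purely for illustration, would be:

    isort nerf/ cache_dataset.py eval_nerf.py tiny_nerf.py train_nerf.py
    flake8 .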
75 changes: 62 additions & 13 deletions cache_dataset.py
@@ -10,20 +10,50 @@
 import torch
 from tqdm import tqdm
 
-from nerf import load_blender_data
-from nerf import load_llff_data
-from nerf import get_ray_bundle, meshgrid_xy
+from nerf import get_ray_bundle, load_blender_data, load_llff_data, meshgrid_xy
 
 
 def cache_nerf_dataset(args):
-    images, poses, render_poses, hwf, i_split = load_blender_data(
-        args.datapath, half_res=args.halfres, testskip=args.stride
-    )
-
-    i_train, i_val, i_test = i_split
-    H, W, focal = hwf
-    H, W = int(H), int(W)
-    hwf = [H, W, focal]
+    images, poses, render_poses, hwf = (
+        None,
+        None,
+        None,
+        None,
+    )
+    i_train, i_val, i_test = None, None, None
+
+    if args.type == "blender":
+        images, poses, render_poses, hwf, i_split = load_blender_data(
+            args.datapath, half_res=args.blender_half_res, testskip=args.blender_stride
+        )
+
+        i_train, i_val, i_test = i_split
+        H, W, focal = hwf
+        H, W = int(H), int(W)
+        hwf = [H, W, focal]
+    elif args.type == "llff":
+        images, poses, bds, render_poses, i_test = load_llff_data(
+            args.datapath, factor=args.llff_downsample_factor
+        )
+        hwf = poses[0, :3, -1]
+        poses = poses[:, :3, :4]
+        if not isinstance(i_test, list):
+            i_test = [i_test]
+        if args.llffhold > 0:
+            i_test = np.arange(images.shape[0])[:: args.llffhold]
+        i_val = i_test
+        i_train = np.array(
+            [
+                i
+                for i in np.arange(images.shape[0])
+                if (i not in i_test and i not in i_val)
+            ]
+        )
+        H, W, focal = hwf
+        hwf = [int(H), int(W), focal]
     images = torch.from_numpy(images)
     poses = torch.from_numpy(poses)
 
     # Device on which to run.
     if torch.cuda.is_available():
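
For reference, the llff branch above reserves every args.llffhold-th frame as the test (and validation) split and trains on the rest. A minimal standalone sketch of that selection, using a hypothetical 20-frame capture (the frame count and hold-out value below are illustrative assumptions, not taken from this commit):

    import numpy as np

    num_frames, llffhold = 20, 8  # hypothetical values for illustration
    i_test = np.arange(num_frames)[::llffhold]  # -> array([ 0,  8, 16])
    i_val = i_test
    i_train = np.array(
        [i for i in np.arange(num_frames) if (i not in i_test and i not in i_val)]
    )
    print(i_test, i_train)  # 3 held-out frames, 17 training frames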
@@ -114,19 +144,38 @@ def cache_nerf_dataset(args):
         help="Path to root dir of dataset that needs caching.",
     )
     parser.add_argument(
-        "--halfres",
+        "--type",
+        type=str.lower,
+        required=True,
+        choices=["blender", "llff"],
+        help="Type of the dataset to be cached.",
+    )
+    parser.add_argument(
+        "--blender-half-res",
         type=bool,
         default=True,
         help="Whether to load the (Blender/synthetic) datasets"
         "at half the resolution.",
     )
     parser.add_argument(
-        "--stride",
+        "--blender-stride",
         type=int,
         default=1,
-        help="Stride length. When set to k (k > 1), it samples"
+        help="Stride length (Blender datasets only). When set to k (k > 1), it samples"
         "every kth sample from the dataset.",
     )
     parser.add_argument(
+        "--llff-downsample-factor",
+        type=int,
+        default=8,
+        help="Downsample factor for images from the LLFF dataset.",
+    )
+    parser.add_argument(
+        "--llffhold",
+        type=int,
+        default=8,
+        help="Determines the hold-out images for LLFF (TODO: make better).",
+    )
+    parser.add_argument(
         "--savedir", type=str, required=True, help="Path to save the cached dataset to."
     )
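
With the new --type switch, the script can cache either dataset flavor. Plausible example invocations, using placeholder dataset paths (assumed for illustration) and omitting any required flags that do not appear in this diff:

    python cache_dataset.py --datapath data/nerf_synthetic/lego --type blender --blender-stride 1 --savedir cache/lego
    python cache_dataset.py --datapath data/nerf_llff_data/fern --type llff --llff-downsample-factor 8 --savedir cache/fern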
8 changes: 2 additions & 6 deletions eval_nerf.py
@@ -9,12 +9,8 @@
 import yaml
 from tqdm import tqdm
 
-from nerf import models
-from nerf import CfgNode
-from nerf import load_blender_data
-from nerf import load_llff_data
-from nerf import get_ray_bundle, positional_encoding
-from nerf import run_one_iter_of_nerf
+from nerf import (CfgNode, get_ray_bundle, load_blender_data, load_llff_data,
+                  models, positional_encoding, run_one_iter_of_nerf)
 
 
 def cast_to_image(tensor):
1 change: 0 additions & 1 deletion nerf/__init__.py
@@ -5,4 +5,3 @@
 from .nerf_helpers import *
 from .train_utils import *
 from .volume_rendering_utils import *
-
3 changes: 2 additions & 1 deletion nerf/nerf_helpers.py
@@ -1,7 +1,8 @@
+import math
 from typing import Optional
 
-import math
 import torch
 
 import torchsearchsorted
 
+
8 changes: 2 additions & 6 deletions tiny_nerf.py
@@ -6,12 +6,8 @@
 import torch
 from tqdm import tqdm, trange
 
-from nerf import (
-    cumprod_exclusive,
-    get_minibatches,
-    get_ray_bundle,
-    positional_encoding,
-)
+from nerf import (cumprod_exclusive, get_minibatches, get_ray_bundle,
+                  positional_encoding)
 
 
 def compute_query_points_from_rays(
19 changes: 6 additions & 13 deletions train_nerf.py
@@ -10,18 +10,9 @@
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm, trange
 
-from nerf import models
-from nerf import CfgNode
-from nerf import load_blender_data
-from nerf import load_llff_data
-from nerf import (
-    get_ray_bundle,
-    img2mse,
-    meshgrid_xy,
-    mse2psnr,
-    positional_encoding,
-)
-from nerf import run_one_iter_of_nerf
+from nerf import (CfgNode, get_ray_bundle, img2mse, load_blender_data,
+                  load_llff_data, meshgrid_xy, models, mse2psnr,
+                  positional_encoding, run_one_iter_of_nerf)
 
 
 def main():
@@ -276,7 +267,9 @@ def encode_direction_fn(x):
 
         # Learning rate updates
         num_decay_steps = cfg.scheduler.lr_decay * 1000
-        lr_new = cfg.optimizer.lr * (cfg.scheduler.lr_decay_factor ** (i / num_decay_steps))
+        lr_new = cfg.optimizer.lr * (
+            cfg.scheduler.lr_decay_factor ** (i / num_decay_steps)
+        )
         for param_group in optimizer.param_groups:
             param_group["lr"] = lr_new
 
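The reformatted expression above is an exponential learning-rate decay, lr_new = lr0 * factor ** (i / (lr_decay * 1000)). A small self-contained sketch of the same schedule; the numeric values below are illustrative assumptions, not taken from this repository's configs:

    # Exponential LR decay: lr(i) = lr0 * factor ** (i / num_decay_steps)
    lr0 = 5e-4              # assumed initial learning rate
    lr_decay = 250          # assumed decay horizon, in thousands of iterations
    lr_decay_factor = 0.1   # assumed total decay over that horizon

    num_decay_steps = lr_decay * 1000
    for i in (0, 100000, 250000):
        lr_new = lr0 * (lr_decay_factor ** (i / num_decay_steps))
        print(i, lr_new)    # 0 -> 5.0e-4, 100000 -> ~2.0e-4, 250000 -> 5.0e-5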
