diff --git a/cache_dataset.py b/cache_dataset.py
index 539f96c..7e74a0a 100644
--- a/cache_dataset.py
+++ b/cache_dataset.py
@@ -10,6 +10,8 @@
 import torch
 from tqdm import tqdm
 
+from load_blender import load_blender_data
+from load_llff import load_llff_data
 from nerf_helpers import get_ray_bundle, meshgrid_xy
 
diff --git a/config/lego.yml b/config/lego.yml
index ca75041..e0389e4 100644
--- a/config/lego.yml
+++ b/config/lego.yml
@@ -1,13 +1,13 @@
 # Parameters to setup experiment.
 experiment:
   # Unique experiment identifier
-  id: lego
+  id: lego-lowres
   # Experiment logs will be stored at "logdir"/"id"
   logdir: logs
   # Seed for random number generators (for repeatability).
   randomseed: 42  # Cause, why not?
   # Number of training iterations.
-  train_iters: 250000
+  train_iters: 200000
   # Number of training iterations after which to validate.
   validate_every: 100
   # Number of training iterations after which to checkpoint.
@@ -23,12 +23,12 @@ dataset:
   basedir: cache/nerf_synthetic/lego
   # Optionally, provide a path to the pre-cached dataset dir. This
   # overrides the other dataset options.
-  cachedir: cache/legocache/legofull
+  cachedir: cache/legocache200
   # For the Blender datasets (synthetic), optionally return images
   # at half the original resolution of 800 x 800, to save space.
   half_res: True
   # Stride (include one per "testskip" images in the dataset).
-  testskip: 10
+  testskip: 1
   # Do not use NDC (normalized device coordinates). Usually True for
   # synthetic (Blender) datasets.
   no_ndc: True
@@ -116,7 +116,7 @@ nerf:
   train:
     # Number of random rays to retain from each image.
     # These sampled rays are used for training, and the others are discarded.
-    num_random_rays: 8192  # 32 * 32 * 4
+    num_random_rays: 1024  # 32 * 32 * 4
     # Size of each chunk (rays are batched into "chunks" and passed through
     # the network)
     chunksize: 131072  # 131072  # 1024 * 32
diff --git a/train_nerf.py b/train_nerf.py
index 292a5ff..3932381 100644
--- a/train_nerf.py
+++ b/train_nerf.py
@@ -168,17 +168,21 @@ def encode_direction_fn(x):
     with open(os.path.join(logdir, "config.yml"), "w") as f:
         f.write(cfg.dump())  # cfg, f, default_flow_style=False)
 
+    # By default, start at iteration 0 (unless a checkpoint is specified).
+    start_iter = 0
+
     # Load an existing checkpoint, if a path is specified.
     if os.path.exists(configargs.load_checkpoint):
         checkpoint = torch.load(configargs.load_checkpoint)
         model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
         if checkpoint["model_fine_state_dict"]:
             model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
-        # optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        start_iter = checkpoint["iter"]
 
     # # TODO: Prepare raybatch tensor if batching random rays
-    for i in trange(cfg.experiment.train_iters):
+    for i in trange(start_iter, cfg.experiment.train_iters):
         model_coarse.train()
         if model_fine:
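
Note on the resume logic added to train_nerf.py: it only works if checkpoints are written with the same keys the patched block reads back. Below is a minimal, self-contained sketch of that save/load round trip, assuming nothing beyond what the diff shows; the key names ("iter", "model_coarse_state_dict", "model_fine_state_dict", "optimizer_state_dict") are taken from the diff, while the toy modules, optimizer, and file name are placeholders rather than the repository's actual checkpointing code.

# Sketch only: toy modules stand in for the NeRF coarse/fine MLPs.
import torch

model_coarse = torch.nn.Linear(3, 3)
model_fine = torch.nn.Linear(3, 3)
optimizer = torch.optim.Adam(
    list(model_coarse.parameters()) + list(model_fine.parameters()), lr=5e-4
)

# Saving: "iter" and "optimizer_state_dict" are what allow a resumed run to
# continue from the right iteration with the right optimizer state.
torch.save(
    {
        "iter": 1000,
        "model_coarse_state_dict": model_coarse.state_dict(),
        "model_fine_state_dict": model_fine.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "checkpoint01000.ckpt",  # hypothetical file name
)

# Loading: mirrors the patched block in train_nerf.py.
checkpoint = torch.load("checkpoint01000.ckpt")
model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
if checkpoint["model_fine_state_dict"]:
    model_fine.load_state_dict(checkpoint["model_fine_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_iter = checkpoint["iter"]  # training then resumes with trange(start_iter, train_iters)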