From 99ff62fcbfa850f53900ec9e2c618d6df4359051 Mon Sep 17 00:00:00 2001
From: Krishna Murthy
Date: Wed, 15 Apr 2020 08:52:03 -0400
Subject: [PATCH] Add log sampling

Signed-off-by: Krishna Murthy
---
 config/lego.yml      | 24 ++++++++++++++++++------
 nerf/nerf_helpers.py | 31 ++++++++++++++++++++++++++-----
 train_nerf.py        | 35 +++++++++++------------------------
 3 files changed, 55 insertions(+), 35 deletions(-)

diff --git a/config/lego.yml b/config/lego.yml
index e0389e4..33f3384 100644
--- a/config/lego.yml
+++ b/config/lego.yml
@@ -1,7 +1,7 @@
 # Parameters to set up the experiment.
 experiment:
   # Unique experiment identifier
-  id: lego-lowres
+  id: lego-lowres3
   # Experiment logs will be stored at "logdir"/"id"
   logdir: logs
   # Seed for random number generators (for repeatability).
@@ -44,19 +44,22 @@ models:
     # Name of the torch.nn.Module class that implements the model.
     type: FlexibleNeRFModel
     # Number of layers in the model.
-    num_layers: 4
+    num_layers: 8
     # Number of hidden units in each layer of the MLP (multi-layer
     # perceptron).
-    hidden_size: 64
+    hidden_size: 128
     # Add a skip connection once in a while. Note: This parameter
     # won't take effect unless num_layers > skip_connect_every.
     skip_connect_every: 3
     # Whether to include the position (xyz) itself in its positional
     # encoding.
     include_input_xyz: True
+    # Whether to perform log sampling in the positional encoding
+    # of the coordinates.
+    log_sampling_xyz: True
     # Number of encoding functions to use in the positional encoding
     # of the coordinates.
-    num_encoding_fn_xyz: 6
+    num_encoding_fn_xyz: 10
     # Additionally use viewing directions as input.
     use_viewdirs: True
     # Whether to include the direction itself in its positional encoding.
@@ -64,12 +67,15 @@ models:
     # Number of encoding functions to use in the positional encoding
     # of the direction.
     num_encoding_fn_dir: 4
+    # Whether to perform log sampling in the positional encoding
+    # of the direction.
+    log_sampling_dir: True
   # Fine model.
   fine:
     # Name of the torch.nn.Module class that implements the model.
     type: FlexibleNeRFModel
     # Number of layers in the model.
-    num_layers: 4
+    num_layers: 8
     # Number of hidden units in each layer of the MLP (multi-layer
     # perceptron).
     hidden_size: 128
@@ -78,10 +84,13 @@ models:
     skip_connect_every: 3
     # Number of encoding functions to use in the positional encoding
     # of the coordinates.
-    num_encoding_fn_xyz: 6
+    num_encoding_fn_xyz: 10
     # Whether to include the position (xyz) itself in its positional
     # encoding.
     include_input_xyz: True
+    # Whether to perform log sampling in the positional encoding
+    # of the coordinates.
+    log_sampling_xyz: True
     # Additionally use viewing directions as input.
     use_viewdirs: True
     # Whether to include the direction itself in its positional encoding.
@@ -89,6 +98,9 @@ models:
     # Number of encoding functions to use in the positional encoding of
     # the direction.
     num_encoding_fn_dir: 4
+    # Whether to perform log sampling in the positional encoding
+    # of the direction.
+    log_sampling_dir: True
 
 # Optimizer params.
 optimizer:
diff --git a/nerf/nerf_helpers.py b/nerf/nerf_helpers.py
index 24ca0c2..e102c47 100644
--- a/nerf/nerf_helpers.py
+++ b/nerf/nerf_helpers.py
@@ -111,7 +111,7 @@ def get_ray_bundle(
 
 
 def positional_encoding(
-    tensor, num_encoding_functions=6, include_input=True
+    tensor, num_encoding_functions=6, include_input=True, log_sampling=True
 ) -> torch.Tensor:
     r"""Apply positional encoding to the input.
 
@@ -128,14 +128,35 @@ def positional_encoding(
     # TESTED
     # Trivially, the input tensor is added to the positional encoding.
     encoding = [tensor] if include_input else []
-    # Now, encode the input using a set of high-frequency functions and append the
-    # resulting values to the encoding.
-    for i in range(num_encoding_functions):
+    frequency_bands = None
+    if log_sampling:
+        frequency_bands = 2. ** torch.linspace(
+            0., num_encoding_functions - 1, num_encoding_functions,
+            dtype=tensor.dtype, device=tensor.device,
+        )
+    else:
+        frequency_bands = torch.linspace(
+            2. ** 0., 2. ** (num_encoding_functions - 1), num_encoding_functions,
+            dtype=tensor.dtype, device=tensor.device
+        )
+
+    for freq in frequency_bands:
         for func in [torch.sin, torch.cos]:
-            encoding.append(func(2.0 ** i * tensor))
+            encoding.append(func(tensor * freq))
+
     return torch.cat(encoding, dim=-1)
 
 
+def get_embedding_function(
+    num_encoding_functions=6, include_input=True, log_sampling=True
+):
+    r"""Returns a lambda function that internally calls positional_encoding.
+    """
+    return lambda x: positional_encoding(
+        x, num_encoding_functions, include_input, log_sampling
+    )
+
+
 def ndc_rays(H, W, focal, near, rays_o, rays_d):
     # UNTESTED, but fairly sure.
 
diff --git a/train_nerf.py b/train_nerf.py
index c721522..b97caac 100644
--- a/train_nerf.py
+++ b/train_nerf.py
@@ -12,7 +12,7 @@
 
 from nerf import (CfgNode, get_ray_bundle, img2mse, load_blender_data,
                   load_llff_data, meshgrid_xy, models, mse2psnr,
-                  positional_encoding, run_one_iter_of_nerf)
+                  get_embedding_function, run_one_iter_of_nerf)
 
 
 def main():
@@ -97,30 +97,17 @@ def main():
     else:
         device = "cpu"
 
-    # # Encoding function for position (xyz).
-    # encode_position_fn = lambda x: positional_encoding(
-    #     x, num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
-    #     include_input=cfg.models.coarse.include_input_xyz
-    # )
-    # # Encoding function for direction.
-    # encode_direction_fn = lambda x: positional_encoding(
-    #     x, num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
-    #     include_input=cfg.models.coarse.include_input_dir
-    # )
-
-    def encode_position_fn(x):
-        return positional_encoding(
-            x,
-            num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
-            include_input=cfg.models.coarse.include_input_xyz,
-        )
+    encode_position_fn = get_embedding_function(
+        num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
+        include_input=cfg.models.coarse.include_input_xyz,
+        log_sampling=cfg.models.coarse.log_sampling_xyz,
+    )
 
-    def encode_direction_fn(x):
-        return positional_encoding(
-            x,
-            num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
-            include_input=cfg.models.coarse.include_input_dir,
-        )
+    encode_direction_fn = get_embedding_function(
+        num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
+        include_input=cfg.models.coarse.include_input_dir,
+        log_sampling=cfg.models.coarse.log_sampling_dir,
+    )
 
     # Initialize a coarse-resolution model.
     model_coarse = getattr(models, cfg.models.coarse.type)(
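
Note: a minimal standalone sketch of what the two frequency schedules above
produce, assuming only a torch install (N is a local stand-in for
num_encoding_functions, not a name from the patch):

    import torch

    N = 6
    # log_sampling=True: the exponents are evenly spaced, so the frequencies
    # grow geometrically: 1, 2, 4, 8, 16, 32.
    log_bands = 2.0 ** torch.linspace(0.0, N - 1, N)
    # log_sampling=False: the frequencies themselves are evenly spaced
    # between 2 ** 0 and 2 ** (N - 1): 1.0, 7.2, 13.4, 19.6, 25.8, 32.0.
    lin_bands = torch.linspace(2.0 ** 0.0, 2.0 ** (N - 1), N)
    print(log_bands)
    print(lin_bands)

Both schedules share the same endpoints; log sampling simply allocates more
of the sin/cos bands to the low frequencies, and the encoding takes a sin
and a cos of the input at every band either way.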
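And a hypothetical usage sketch for the new helper (encode_xyz and pts are
illustrative names; importing from the nerf package mirrors the
train_nerf.py import above):

    import torch
    from nerf import get_embedding_function

    encode_xyz = get_embedding_function(
        num_encoding_functions=10, include_input=True, log_sampling=True
    )
    pts = torch.rand(1024, 3)  # a batch of xyz sample points
    out = encode_xyz(pts)      # the input plus sin and cos at 10 frequencies
    print(out.shape)           # (1024, 3 + 3 * 2 * 10) = (1024, 63)

With the config above, each coordinate contributes itself plus one sin and
one cos term per frequency band, so num_encoding_fn_xyz: 10 yields
3 + 3 * 2 * 10 = 63 position features and num_encoding_fn_dir: 4 yields
3 + 3 * 2 * 4 = 27 direction features.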