Skip to content
This repository has been archived by the owner on Jul 30, 2024. It is now read-only.

Commit

Permalink
Add log sampling
Browse files Browse the repository at this point in the history
Signed-off-by: Krishna Murthy <[email protected]>
  • Loading branch information
Krishna Murthy committed Apr 15, 2020
1 parent 4dea21b commit 99ff62f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 35 deletions.
24 changes: 18 additions & 6 deletions config/lego.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Parameters to setup experiment.
experiment:
# Unique experiment identifier
id: lego-lowres
id: lego-lowres3
# Experiment logs will be stored at "logdir"/"id"
logdir: logs
# Seed for random number generators (for repeatability).
Expand Down Expand Up @@ -44,32 +44,38 @@ models:
# Name of the torch.nn.Module class that implements the model.
type: FlexibleNeRFModel
# Number of layers in the model.
num_layers: 4
num_layers: 8
# Number of hidden units in each layer of the MLP (multi-layer
# perceptron).
hidden_size: 64
hidden_size: 128
# Add a skip connection once in a while. Note: This parameter
# won't take affect unless num_layers > skip_connect_every.
skip_connect_every: 3
# Whether to include the position (xyz) itself in its positional
# encoding.
include_input_xyz: True
# Whether or not to perform log sampling in the positional encoding
# of the coordinates.
log_sampling_xyz: True
# Number of encoding functions to use in the positional encoding
# of the coordinates.
num_encoding_fn_xyz: 6
num_encoding_fn_xyz: 10
# Additionally use viewing directions as input.
use_viewdirs: True
# Whether to include the direction itself in its positional encoding.
include_input_dir: True
# Number of encoding functions to use in the positional encoding
# of the direction.
num_encoding_fn_dir: 4
# Whether or not to perform log sampling in the positional encoding
# of the direction.
log_sampling_dir: True
# Fine model.
fine:
# Name of the torch.nn.Module class that implements the model.
type: FlexibleNeRFModel
# Number of layers in the model.
num_layers: 4
num_layers: 8
# Number of hidden units in each layer of the MLP (multi-layer
# perceptron).
hidden_size: 128
Expand All @@ -78,17 +84,23 @@ models:
skip_connect_every: 3
# Number of encoding functions to use in the positional encoding
# of the coordinates.
num_encoding_fn_xyz: 6
num_encoding_fn_xyz: 10
# Whether to include the position (xyz) itself in its positional
# encoding.
include_input_xyz: True
# Whether or not to perform log sampling in the positional encoding
# of the coordinates.
log_sampling_xyz: True
# Additionally use viewing directions as input.
use_viewdirs: True
# Whether to include the direction itself in its positional encoding.
include_input_dir: True
# Number of encoding functions to use in the positional encoding of
# the direction.
num_encoding_fn_dir: 4
# Whether or not to perform log sampling in the positional encoding
# of the direction.
log_sampling_dir: True

# Optimizer params.
optimizer:
Expand Down
31 changes: 26 additions & 5 deletions nerf/nerf_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_ray_bundle(


def positional_encoding(
tensor, num_encoding_functions=6, include_input=True
tensor, num_encoding_functions=6, include_input=True, log_sampling=True
) -> torch.Tensor:
r"""Apply positional encoding to the input.
Expand All @@ -128,14 +128,35 @@ def positional_encoding(
# TESTED
# Trivially, the input tensor is added to the positional encoding.
encoding = [tensor] if include_input else []
# Now, encode the input using a set of high-frequency functions and append the
# resulting values to the encoding.
for i in range(num_encoding_functions):
frequency_bands = None
if log_sampling:
frequency_bands = 2. ** torch.linspace(
0., num_encoding_functions - 1, num_encoding_functions,
dtype=tensor.dtype, device=tensor.device,
)
else:
frequency_bands = torch.linspace(
2. ** 0., 2. ** (num_encoding_functions - 1), num_encoding_functions,
dtype=tensor.dtype, device=tensor.device
)

for freq in frequency_bands:
for func in [torch.sin, torch.cos]:
encoding.append(func(2.0 ** i * tensor))
encoding.append(func(tensor * freq))

return torch.cat(encoding, dim=-1)


def get_embedding_function(
num_encoding_functions=6, include_input=True, log_sampling=True
):
r"""Returns a lambda function that internally calls positional_encoding.
"""
return lambda x: positional_encoding(
x, num_encoding_functions, include_input, log_sampling
)


def ndc_rays(H, W, focal, near, rays_o, rays_d):
# UNTESTED, but fairly sure.

Expand Down
35 changes: 11 additions & 24 deletions train_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from nerf import (CfgNode, get_ray_bundle, img2mse, load_blender_data,
load_llff_data, meshgrid_xy, models, mse2psnr,
positional_encoding, run_one_iter_of_nerf)
get_embedding_function, run_one_iter_of_nerf)


def main():
Expand Down Expand Up @@ -97,30 +97,17 @@ def main():
else:
device = "cpu"

# # Encoding function for position (xyz).
# encode_position_fn = lambda x: positional_encoding(
# x, num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
# include_input=cfg.models.coarse.include_input_xyz
# )
# # Encoding function for direction.
# encode_direction_fn = lambda x: positional_encoding(
# x, num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
# include_input=cfg.models.coarse.include_input_dir
# )

def encode_position_fn(x):
return positional_encoding(
x,
num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
include_input=cfg.models.coarse.include_input_xyz,
)
encode_position_fn = get_embedding_function(
num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
include_input=cfg.models.coarse.include_input_xyz,
log_sampling=cfg.models.coarse.log_sampling_xyz,
)

def encode_direction_fn(x):
return positional_encoding(
x,
num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
include_input=cfg.models.coarse.include_input_dir,
)
encode_direction_fn = get_embedding_function(
num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
include_input=cfg.models.coarse.include_input_dir,
log_sampling=cfg.models.coarse.log_sampling_dir,
)

# Initialize a coarse-resolution model.
model_coarse = getattr(models, cfg.models.coarse.type)(
Expand Down

0 comments on commit 99ff62f

Please sign in to comment.