Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor RGBD Pixel Kernel #161

Merged
merged 4 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions src/b3d/chisight/gen3d/image_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
from genjax import Pytree
from genjax.typing import FloatArray, PRNGKey

from b3d.chisight.gen3d.pixel_kernels import (
FullPixelColorDistribution,
FullPixelDepthDistribution,
PixelDepthDistribution,
from b3d.chisight.gen3d.pixel_kernels import is_unexplained
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import (
TruncatedLaplacePixelColorDistribution,
UniformPixelColorDistribution,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import (
TruncatedLaplacePixelDepthDistribution,
UniformPixelDepthDistribution,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import (
FullPixelRGBDDistribution,
PixelRGBDDistribution,
is_unexplained,
)
from b3d.chisight.gen3d.projection import PixelsPointsAssociation

Expand Down Expand Up @@ -51,9 +57,6 @@ def logpdf(
) -> FloatArray:
raise NotImplementedError

def get_depth_vertex_kernel(self) -> PixelDepthDistribution:
raise NotImplementedError

def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
raise NotImplementedError

Expand Down Expand Up @@ -130,18 +133,16 @@ def logpdf(
# -inf if it is invalid. We need to filter those out here.
# (invalid rgbd could happen when the vertex is projected out of the image)
scores = jnp.where(is_unexplained(observed_rgbd_per_point), 0.0, scores)

return scores.sum()

def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
# Note: The distributions were originally defined for per-pixel computation,
# but they should work for per-vertex computation as well, except that
# they don't expect observed_rgbd to be invalid, so we need to handle
# that manually.
return PixelRGBDDistribution(
FullPixelColorDistribution(),
FullPixelDepthDistribution(self.near, self.far),
return FullPixelRGBDDistribution(
TruncatedLaplacePixelColorDistribution(),
UniformPixelColorDistribution(),
TruncatedLaplacePixelDepthDistribution(self.near, self.far),
UniformPixelDepthDistribution(self.near, self.far),
)

def get_depth_vertex_kernel(self) -> PixelDepthDistribution:
return self.get_rgbd_vertex_kernel().depth_kernel
12 changes: 6 additions & 6 deletions src/b3d/chisight/gen3d/pixel_kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import (
FullPixelColorDistribution,
MixturePixelColorDistribution,
PixelColorDistribution,
is_unexplained,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import (
DEPTH_NONRETURN_VAL,
FullPixelDepthDistribution,
MixturePixelDepthDistribution,
PixelDepthDistribution,
UnexplainedPixelDepthDistribution,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import PixelRGBDDistribution
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import (
FullPixelRGBDDistribution,
PixelRGBDDistribution,
is_unexplained,
)

__all__ = [
"is_unexplained",
"DEPTH_NONRETURN_VAL",
"FullPixelColorDistribution",
"FullPixelDepthDistribution",
"MixturePixelColorDistribution",
"MixturePixelDepthDistribution",
"PixelColorDistribution",
"PixelDepthDistribution",
"PixelRGBDDistribution",
"FullPixelRGBDDistribution",
"UnexplainedPixelDepthDistribution",
]
90 changes: 0 additions & 90 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,6 @@
COLOR_MAX_VAL: float = 1.0


def is_unexplained(latent_value: FloatArray) -> bool:
"""
Check if a given `latent_value` value given to a pixel
indicates that no latent point hits a pixel.
This is done by checking if any of the latent color values
are negative.

Args:
latent_value (FloatArray): The latent color of the pixel.

Returns:
bool: True is none of the latent point hits the pixel, False otherwise.
"""
return jnp.any(latent_value < 0.0)


@Pytree.dataclass
class PixelColorDistribution(genjax.ExactDensity):
"""
Expand Down Expand Up @@ -258,77 +242,3 @@ def logpdf_per_channel(
)

return jnp.logaddexp(*logprobs)


@Pytree.dataclass
class FullPixelColorDistribution(PixelColorDistribution):
"""A distribution that generates the color of the pixel according to the
following rule:

if no latent point hits the pixel:
color ~ uniform(0, 1)
else:
color ~ mixture(
[uniform(0, 1), truncated_laplace(latent_color; color_scale)],
[occluded_prob, 1 - occluded_prob]
)

Constructor args:

Distribution args:
- `latent_color`: 3-array. If no latent point hits the pixel, should contain
3 negative values. If a latent point hits the pixel, should contain the point's
color as an RGB value in [0, 1]^3.
- color_scale: float. The scale of the truncated Laplace distribution
centered around the latent color used for inlier color observations.
- `color_visibility_prob`: float. If a latent point hits the pixel, should contain
the probability associated with that point that the generated color is
visible (non-occluded). If no latent point hits the pixel, this value is ignored.

Distribution support:
- An RGB value in [0, 1]^3.
"""

@property
def _color_from_latent(self) -> PixelColorDistribution:
return MixturePixelColorDistribution()

@property
def _unexplained_color(self) -> PixelColorDistribution:
return UniformPixelColorDistribution()

def sample(
self,
key: PRNGKey,
latent_color: FloatArray,
color_scale: FloatArray,
visibility_prob: FloatArray,
) -> FloatArray:
return jax.lax.cond(
is_unexplained(latent_color),
self._unexplained_color.sample, # if no point hits current pixel
self._color_from_latent.sample, # if pixel is being hit by a latent point
# sample args
key,
latent_color,
color_scale,
visibility_prob,
)

def logpdf_per_channel(
self,
observed_color: FloatArray,
latent_color: FloatArray,
color_scale: FloatArray,
visibility_prob: float,
) -> FloatArray:
return jax.lax.cond(
is_unexplained(latent_color),
self._unexplained_color.logpdf_per_channel, # if no point hits current pixel
self._color_from_latent.logpdf_per_channel, # if pixel is being hit by a latent point
# logpdf args
observed_color,
latent_color,
color_scale,
visibility_prob,
)
75 changes: 0 additions & 75 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
from typing import TYPE_CHECKING, Any

import genjax
import jax
import jax.numpy as jnp
from genjax import Pytree
from genjax.typing import FloatArray, PRNGKey
from tensorflow_probability.substrates import jax as tfp

from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import is_unexplained
from b3d.modeling_utils import (
_FIXED_DEPTH_UNIFORM_WINDOW,
PythonMixtureDistribution,
Expand Down Expand Up @@ -316,76 +314,3 @@ def logpdf(
**kwargs,
) -> float:
return self._mixture_dist.logpdf(observed_depth, self._mix_ratio, [(), ()])


@Pytree.dataclass
class FullPixelDepthDistribution(PixelDepthDistribution):
"""A distribution that generates the depth of the pixel according to the
following rule:

if no latent point hits the pixel:
depth ~ mixture(
[delta(DEPTH_NONRETURN_VAL), uniform(near, far)],
[unexplained_depth_nonreturn_prob, 1 - unexplained_depth_nonreturn_prob]
)
else:
mixture(
[delta(DEPTH_NONRETURN_VAL), uniform(near, far), laplace(latent_depth; depth_scale)],
[depth_nonreturn_prob, (1 - depth_nonreturn_prob) * (1 - visibility_prob), remaining_prob]
)
"""

near: float = Pytree.static()
far: float = Pytree.static()

@property
def _depth_from_latent(self) -> PixelDepthDistribution:
return MixturePixelDepthDistribution(self.near, self.far)

@property
def _unexplained_depth(self) -> PixelDepthDistribution:
return UnexplainedPixelDepthDistribution(self.near, self.far)

def sample(
self,
key: PRNGKey,
latent_depth: FloatArray,
depth_scale: FloatArray,
visibility_prob: FloatArray,
depth_nonreturn_prob: float,
*args,
**kwargs,
) -> FloatArray:
return jax.lax.cond(
is_unexplained(latent_depth),
self._unexplained_depth.sample, # if no point hits current pixel
self._depth_from_latent.sample, # if pixel is being hit by a latent point
# sample args
key,
latent_depth,
depth_scale,
visibility_prob,
depth_nonreturn_prob,
)

def logpdf(
self,
observed_depth: FloatArray,
latent_depth: FloatArray,
depth_scale: FloatArray,
visibility_prob: float,
depth_nonreturn_prob: float,
*args,
**kwargs,
) -> FloatArray:
return jax.lax.cond(
is_unexplained(latent_depth),
self._unexplained_depth.logpdf, # if no point hits current pixel
self._depth_from_latent.logpdf, # if pixel is being hit by a latent point
# logpdf args
observed_depth,
latent_depth,
depth_scale,
visibility_prob,
depth_nonreturn_prob,
)
Loading
Loading