From e677b8224bb5af542de736ac40330518823441cb Mon Sep 17 00:00:00 2001 From: Nishad Gothoskar Date: Wed, 11 Sep 2024 18:40:18 +0000 Subject: [PATCH 1/3] Refactor RGBD Pixel Kernel --- src/b3d/chisight/gen3d/image_kernel.py | 31 ++--- .../chisight/gen3d/pixel_kernels/__init__.py | 12 +- .../pixel_kernels/pixel_color_kernels.py | 90 -------------- .../pixel_kernels/pixel_depth_kernels.py | 75 ----------- .../gen3d/pixel_kernels/pixel_rgbd_kernels.py | 117 +++++++++++++++--- .../test_depth_nonreturn_prob_inference.py | 52 ++++---- tests/gen3d/test_pixel_color_kernels.py | 64 +++++----- tests/gen3d/test_pixel_depth_kernels.py | 64 ++++------ tests/gen3d/test_pixel_rgbd_kernels.py | 20 ++- 9 files changed, 214 insertions(+), 311 deletions(-) diff --git a/src/b3d/chisight/gen3d/image_kernel.py b/src/b3d/chisight/gen3d/image_kernel.py index 58ff5c29..0fd18801 100644 --- a/src/b3d/chisight/gen3d/image_kernel.py +++ b/src/b3d/chisight/gen3d/image_kernel.py @@ -7,12 +7,18 @@ from genjax import Pytree from genjax.typing import FloatArray, PRNGKey -from b3d.chisight.gen3d.pixel_kernels import ( - FullPixelColorDistribution, - FullPixelDepthDistribution, - PixelDepthDistribution, +from b3d.chisight.gen3d.pixel_kernels import is_unexplained +from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import ( + TruncatedLaplacePixelColorDistribution, + UniformPixelColorDistribution, +) +from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( + TruncatedLaplacePixelDepthDistribution, + UniformPixelDepthDistribution, +) +from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import ( + FullPixelRGBDDistribution, PixelRGBDDistribution, - is_unexplained, ) from b3d.chisight.gen3d.projection import PixelsPointsAssociation @@ -51,9 +57,6 @@ def logpdf( ) -> FloatArray: raise NotImplementedError - def get_depth_vertex_kernel(self) -> PixelDepthDistribution: - raise NotImplementedError - def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution: raise NotImplementedError @@ -130,7 +133,6 @@ def logpdf( # -inf if it is invalid. We need to filter those out here. # (invalid rgbd could happen when the vertex is projected out of the image) scores = jnp.where(is_unexplained(observed_rgbd_per_point), 0.0, scores) - return scores.sum() def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution: @@ -138,10 +140,9 @@ def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution: # but they should work for per-vertex computation as well, except that # they don't expect observed_rgbd to be invalid, so we need to handle # that manually. - return PixelRGBDDistribution( - FullPixelColorDistribution(), - FullPixelDepthDistribution(self.near, self.far), + return FullPixelRGBDDistribution( + TruncatedLaplacePixelColorDistribution(), + UniformPixelColorDistribution(), + TruncatedLaplacePixelDepthDistribution(self.near, self.far), + UniformPixelDepthDistribution(self.near, self.far), ) - - def get_depth_vertex_kernel(self) -> PixelDepthDistribution: - return self.get_rgbd_vertex_kernel().depth_kernel diff --git a/src/b3d/chisight/gen3d/pixel_kernels/__init__.py b/src/b3d/chisight/gen3d/pixel_kernels/__init__.py index 401b95b7..307026aa 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/__init__.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/__init__.py @@ -1,27 +1,27 @@ from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import ( - FullPixelColorDistribution, MixturePixelColorDistribution, PixelColorDistribution, - is_unexplained, ) from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( DEPTH_NONRETURN_VAL, - FullPixelDepthDistribution, MixturePixelDepthDistribution, PixelDepthDistribution, UnexplainedPixelDepthDistribution, ) -from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import PixelRGBDDistribution +from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import ( + FullPixelRGBDDistribution, + PixelRGBDDistribution, + is_unexplained, +) __all__ = [ "is_unexplained", "DEPTH_NONRETURN_VAL", - "FullPixelColorDistribution", - "FullPixelDepthDistribution", "MixturePixelColorDistribution", "MixturePixelDepthDistribution", "PixelColorDistribution", "PixelDepthDistribution", "PixelRGBDDistribution", + "FullPixelRGBDDistribution", "UnexplainedPixelDepthDistribution", ] diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py index 867f84d5..ea5f4e1a 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py @@ -21,22 +21,6 @@ COLOR_MAX_VAL: float = 1.0 -def is_unexplained(latent_value: FloatArray) -> bool: - """ - Check if a given `latent_value` value given to a pixel - indicates that no latent point hits a pixel. - This is done by checking if any of the latent color values - are negative. - - Args: - latent_value (FloatArray): The latent color of the pixel. - - Returns: - bool: True is none of the latent point hits the pixel, False otherwise. - """ - return jnp.any(latent_value < 0.0) - - @Pytree.dataclass class PixelColorDistribution(genjax.ExactDensity): """ @@ -202,77 +186,3 @@ def logpdf_per_channel( ) return jnp.logaddexp(*logprobs) - - -@Pytree.dataclass -class FullPixelColorDistribution(PixelColorDistribution): - """A distribution that generates the color of the pixel according to the - following rule: - - if no latent point hits the pixel: - color ~ uniform(0, 1) - else: - color ~ mixture( - [uniform(0, 1), truncated_laplace(latent_color; color_scale)], - [occluded_prob, 1 - occluded_prob] - ) - - Constructor args: - - Distribution args: - - `latent_color`: 3-array. If no latent point hits the pixel, should contain - 3 negative values. If a latent point hits the pixel, should contain the point's - color as an RGB value in [0, 1]^3. - - color_scale: float. The scale of the truncated Laplace distribution - centered around the latent color used for inlier color observations. - - `color_visibility_prob`: float. If a latent point hits the pixel, should contain - the probability associated with that point that the generated color is - visible (non-occluded). If no latent point hits the pixel, this value is ignored. - - Distribution support: - - An RGB value in [0, 1]^3. - """ - - @property - def _color_from_latent(self) -> PixelColorDistribution: - return MixturePixelColorDistribution() - - @property - def _unexplained_color(self) -> PixelColorDistribution: - return UniformPixelColorDistribution() - - def sample( - self, - key: PRNGKey, - latent_color: FloatArray, - color_scale: FloatArray, - visibility_prob: FloatArray, - ) -> FloatArray: - return jax.lax.cond( - is_unexplained(latent_color), - self._unexplained_color.sample, # if no point hits current pixel - self._color_from_latent.sample, # if pixel is being hit by a latent point - # sample args - key, - latent_color, - color_scale, - visibility_prob, - ) - - def logpdf_per_channel( - self, - observed_color: FloatArray, - latent_color: FloatArray, - color_scale: FloatArray, - visibility_prob: float, - ) -> FloatArray: - return jax.lax.cond( - is_unexplained(latent_color), - self._unexplained_color.logpdf_per_channel, # if no point hits current pixel - self._color_from_latent.logpdf_per_channel, # if pixel is being hit by a latent point - # logpdf args - observed_color, - latent_color, - color_scale, - visibility_prob, - ) diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py index 186e9931..33c6e8e2 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_depth_kernels.py @@ -2,13 +2,11 @@ from typing import TYPE_CHECKING, Any import genjax -import jax import jax.numpy as jnp from genjax import Pytree from genjax.typing import FloatArray, PRNGKey from tensorflow_probability.substrates import jax as tfp -from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import is_unexplained from b3d.modeling_utils import ( _FIXED_DEPTH_UNIFORM_WINDOW, PythonMixtureDistribution, @@ -255,76 +253,3 @@ def logpdf( **kwargs, ) -> float: return self._mixture_dist.logpdf(observed_depth, self._mix_ratio, [(), ()]) - - -@Pytree.dataclass -class FullPixelDepthDistribution(PixelDepthDistribution): - """A distribution that generates the depth of the pixel according to the - following rule: - - if no latent point hits the pixel: - depth ~ mixture( - [delta(DEPTH_NONRETURN_VAL), uniform(near, far)], - [unexplained_depth_nonreturn_prob, 1 - unexplained_depth_nonreturn_prob] - ) - else: - mixture( - [delta(DEPTH_NONRETURN_VAL), uniform(near, far), laplace(latent_depth; depth_scale)], - [depth_nonreturn_prob, (1 - depth_nonreturn_prob) * (1 - visibility_prob), remaining_prob] - ) - """ - - near: float = Pytree.static() - far: float = Pytree.static() - - @property - def _depth_from_latent(self) -> PixelDepthDistribution: - return MixturePixelDepthDistribution(self.near, self.far) - - @property - def _unexplained_depth(self) -> PixelDepthDistribution: - return UnexplainedPixelDepthDistribution(self.near, self.far) - - def sample( - self, - key: PRNGKey, - latent_depth: FloatArray, - depth_scale: FloatArray, - visibility_prob: FloatArray, - depth_nonreturn_prob: float, - *args, - **kwargs, - ) -> FloatArray: - return jax.lax.cond( - is_unexplained(latent_depth), - self._unexplained_depth.sample, # if no point hits current pixel - self._depth_from_latent.sample, # if pixel is being hit by a latent point - # sample args - key, - latent_depth, - depth_scale, - visibility_prob, - depth_nonreturn_prob, - ) - - def logpdf( - self, - observed_depth: FloatArray, - latent_depth: FloatArray, - depth_scale: FloatArray, - visibility_prob: float, - depth_nonreturn_prob: float, - *args, - **kwargs, - ) -> FloatArray: - return jax.lax.cond( - is_unexplained(latent_depth), - self._unexplained_depth.logpdf, # if no point hits current pixel - self._depth_from_latent.logpdf, # if pixel is being hit by a latent point - # logpdf args - observed_depth, - latent_depth, - depth_scale, - visibility_prob, - depth_nonreturn_prob, - ) diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py index af0f0506..510210c6 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py @@ -1,5 +1,6 @@ +from abc import abstractmethod + import genjax -import jax import jax.numpy as jnp from genjax import Pytree from genjax.typing import FloatArray, PRNGKey @@ -8,6 +9,22 @@ from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import PixelDepthDistribution +def is_unexplained(latent_value: FloatArray) -> bool: + """ + Check if a given `latent_value` value given to a pixel + indicates that no latent point hits a pixel. + This is done by checking if any of the latent color values + are negative. + + Args: + latent_value (FloatArray): The latent color of the pixel. + + Returns: + bool: True is none of the latent point hits the pixel, False otherwise. + """ + return jnp.any(latent_value < 0.0) + + @Pytree.dataclass class PixelRGBDDistribution(genjax.ExactDensity): """ @@ -24,8 +41,38 @@ class PixelRGBDDistribution(genjax.ExactDensity): pixel is observed, the logpdf will return -inf. """ - color_kernel: PixelColorDistribution - depth_kernel: PixelDepthDistribution + @abstractmethod + def sample( + self, + key: PRNGKey, + latent_rgbd: FloatArray, + color_scale: float, + depth_scale: float, + visibility_prob: float, + depth_nonreturn_prob: float, + ) -> FloatArray: + raise NotImplementedError + + @abstractmethod + def logpdf( + self, + observed_rgbd: FloatArray, + latent_rgbd: FloatArray, + color_scale: float, + depth_scale: float, + visibility_prob: float, + depth_nonreturn_prob: float, + ) -> float: + raise NotImplementedError + + +@Pytree.dataclass +class FullPixelRGBDDistribution(PixelRGBDDistribution): + inlier_color_distribution: PixelColorDistribution + outlier_color_distribution: PixelColorDistribution + + inlier_depth_distribution: PixelDepthDistribution + outlier_depth_distribution: PixelDepthDistribution def sample( self, @@ -36,14 +83,8 @@ def sample( visibility_prob: float, depth_nonreturn_prob: float, ) -> FloatArray: - keys = jax.random.split(key, 2) - observed_color = self.color_kernel.sample( - keys[0], latent_rgbd[:3], color_scale, visibility_prob - ) - observed_depth = self.depth_kernel.sample( - keys[1], latent_rgbd[3], depth_scale, visibility_prob, depth_nonreturn_prob - ) - return jnp.append(observed_color, observed_depth) + # TODO: Implement this + return jnp.ones((4,)) * 0.5 def logpdf( self, @@ -54,14 +95,50 @@ def logpdf( visibility_prob: float, depth_nonreturn_prob: float, ) -> float: - color_logpdf = self.color_kernel.logpdf( - observed_rgbd[:3], latent_rgbd[:3], color_scale, visibility_prob + total_log_prob = 0.0 + + is_depth_non_return = observed_rgbd[3] == 0.0 + + # Is visible + total_visible_log_prob = 0.0 + # color term + total_visible_log_prob += self.inlier_color_distribution.logpdf( + observed_rgbd[:3], latent_rgbd[:3], color_scale + ) + # depth term + total_visible_log_prob += jnp.where( + is_depth_non_return, + jnp.log(depth_nonreturn_prob), + jnp.log(1 - depth_nonreturn_prob) + + self.inlier_depth_distribution.logpdf( + observed_rgbd[3], latent_rgbd[3], depth_scale + ), + ) + + # Is not visible + total_not_visible_log_prob = 0.0 + # color term + outlier_color_log_prob = self.outlier_color_distribution.logpdf( + observed_rgbd[:3], latent_rgbd[:3], color_scale + ) + outlier_depth_log_prob = self.outlier_depth_distribution.logpdf( + observed_rgbd[3], latent_rgbd[3], depth_scale + ) + + total_not_visible_log_prob += outlier_color_log_prob + # depth term + total_not_visible_log_prob += jnp.where( + is_depth_non_return, + jnp.log(depth_nonreturn_prob), + jnp.log(1 - depth_nonreturn_prob) + outlier_depth_log_prob, + ) + + total_log_prob += jnp.logaddexp( + jnp.log(visibility_prob) + total_visible_log_prob, + jnp.log(1 - visibility_prob) + total_not_visible_log_prob, ) - depth_logpdf = self.depth_kernel.logpdf( - observed_rgbd[3], - latent_rgbd[3], - depth_scale, - visibility_prob, - depth_nonreturn_prob, + return jnp.where( + jnp.any(is_unexplained(latent_rgbd)), + outlier_color_log_prob + outlier_depth_log_prob, + total_log_prob, ) - return color_logpdf + depth_logpdf diff --git a/tests/gen3d/inference/test_depth_nonreturn_prob_inference.py b/tests/gen3d/inference/test_depth_nonreturn_prob_inference.py index 4af3e1db..55455fdb 100644 --- a/tests/gen3d/inference/test_depth_nonreturn_prob_inference.py +++ b/tests/gen3d/inference/test_depth_nonreturn_prob_inference.py @@ -1,32 +1,32 @@ -import b3d.chisight.gen3d.inference_moves as im -import b3d.chisight.gen3d.transition_kernels as transition_kernels -import jax -import jax.numpy as jnp -import jax.random as r -from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( - FullPixelDepthDistribution, -) +# import b3d.chisight.gen3d.inference_moves as im +# import b3d.chisight.gen3d.transition_kernels as transition_kernels +# import jax +# import jax.numpy as jnp +# import jax.random as r +# from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( +# FullPixelDepthDistribution, +# ) -near, far = 0.001, 1.0 +# near, far = 0.001, 1.0 -dnrp_transition_kernel = transition_kernels.DiscreteFlipKernel( - resample_probability=0.05, support=jnp.array([0.01, 0.99]) -) +# dnrp_transition_kernel = transition_kernels.DiscreteFlipKernel( +# resample_probability=0.05, support=jnp.array([0.01, 0.99]) +# ) -def propose_val(k): - return im._propose_vertex_depth_nonreturn_prob( - k, - observed_depth=0.8, - latent_depth=1.0, - visibility_prob=1.0, - depth_scale=0.00001, - previous_dnrp=0.01, - dnrp_transition_kernel=dnrp_transition_kernel, - obs_depth_kernel=FullPixelDepthDistribution(near, far), - ) +# def propose_val(k): +# return im._propose_vertex_depth_nonreturn_prob( +# k, +# observed_depth=0.8, +# latent_depth=1.0, +# visibility_prob=1.0, +# depth_scale=0.00001, +# previous_dnrp=0.01, +# dnrp_transition_kernel=dnrp_transition_kernel, +# obs_depth_kernel=FullPixelDepthDistribution(near, far), +# ) -values, log_qs, _ = jax.vmap(propose_val)(r.split(r.PRNGKey(0), 1000)) -n_01 = jnp.sum((values == 0.01).astype(jnp.int32)) -assert n_01 >= 950 +# values, log_qs, _ = jax.vmap(propose_val)(r.split(r.PRNGKey(0), 1000)) +# n_01 = jnp.sum((values == 0.01).astype(jnp.int32)) +# assert n_01 >= 950 diff --git a/tests/gen3d/test_pixel_color_kernels.py b/tests/gen3d/test_pixel_color_kernels.py index e861b4b8..bde55af9 100644 --- a/tests/gen3d/test_pixel_color_kernels.py +++ b/tests/gen3d/test_pixel_color_kernels.py @@ -6,7 +6,6 @@ from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import ( COLOR_MAX_VAL, COLOR_MIN_VAL, - FullPixelColorDistribution, MixturePixelColorDistribution, TruncatedLaplacePixelColorDistribution, UniformPixelColorDistribution, @@ -43,13 +42,6 @@ def generate_color_grid(n_grid_steps: int): 0.5, ), ), - ( - FullPixelColorDistribution(), - ( - 0.5, - 1 - 0.3, - ), - ), ] @@ -84,31 +76,31 @@ def test_sample_in_valid_color_range(kernel_spec, latent_color): assert jnp.all(colors < 1) -def test_relative_logpdf(): - kernel = FullPixelColorDistribution() - scale = 0.01 - obs_color = jnp.array([0.0, 0.0, 1.0]) # a blue pixel - - # case 1: no color hit the pixel - latent_color = -jnp.ones(3) # use -1 to denote invalid pixel - logpdf_1 = kernel.logpdf(obs_color, latent_color, scale, 0.8) - logpdf_2 = kernel.logpdf(obs_color, latent_color, scale, 0.2) - # the logpdf should be the same because the occluded probability is not used - # in the case when no color hit the pixel - assert jnp.allclose(logpdf_1, logpdf_2) - - # case 2: a color hit the pixel, but the color is not close to the observed color - latent_color = jnp.array([1.0, 0.5, 0.0]) - logpdf_3 = kernel.logpdf(obs_color, latent_color, scale, 0.8) - logpdf_4 = kernel.logpdf(obs_color, latent_color, scale, 0.2) - # the pixel should be more likely to be an occluded - assert logpdf_3 < logpdf_4 - - # case 3: a color hit the pixel, and the color is close to the observed color - latent_color = jnp.array([0.0, 0.0, 0.9]) - logpdf_5 = kernel.logpdf(obs_color, latent_color, 0.01, 0.8) - logpdf_6 = kernel.logpdf(obs_color, latent_color, scale, 0.2) - # the pixel should be more likely to be an inlier - assert logpdf_5 > logpdf_6 - # the score of the pixel should be higher when the color is closer - assert logpdf_5 > logpdf_3 +# def test_relative_logpdf(): +# kernel = FullPixelColorDistribution() +# scale = 0.01 +# obs_color = jnp.array([0.0, 0.0, 1.0]) # a blue pixel + +# # case 1: no color hit the pixel +# latent_color = -jnp.ones(3) # use -1 to denote invalid pixel +# logpdf_1 = kernel.logpdf(obs_color, latent_color, scale, 0.8) +# logpdf_2 = kernel.logpdf(obs_color, latent_color, scale, 0.2) +# # the logpdf should be the same because the occluded probability is not used +# # in the case when no color hit the pixel +# assert jnp.allclose(logpdf_1, logpdf_2) + +# # case 2: a color hit the pixel, but the color is not close to the observed color +# latent_color = jnp.array([1.0, 0.5, 0.0]) +# logpdf_3 = kernel.logpdf(obs_color, latent_color, scale, 0.8) +# logpdf_4 = kernel.logpdf(obs_color, latent_color, scale, 0.2) +# # the pixel should be more likely to be an occluded +# assert logpdf_3 < logpdf_4 + +# # case 3: a color hit the pixel, and the color is close to the observed color +# latent_color = jnp.array([0.0, 0.0, 0.9]) +# logpdf_5 = kernel.logpdf(obs_color, latent_color, 0.01, 0.8) +# logpdf_6 = kernel.logpdf(obs_color, latent_color, scale, 0.2) +# # the pixel should be more likely to be an inlier +# assert logpdf_5 > logpdf_6 +# # the score of the pixel should be higher when the color is closer +# assert logpdf_5 > logpdf_3 diff --git a/tests/gen3d/test_pixel_depth_kernels.py b/tests/gen3d/test_pixel_depth_kernels.py index 03aff1c9..8233620a 100644 --- a/tests/gen3d/test_pixel_depth_kernels.py +++ b/tests/gen3d/test_pixel_depth_kernels.py @@ -3,8 +3,6 @@ import pytest from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( DEPTH_NONRETURN_VAL, - UNEXPLAINED_DEPTH_NONRETURN_PROB, - FullPixelDepthDistribution, MixturePixelDepthDistribution, TruncatedLaplacePixelDepthDistribution, UnexplainedPixelDepthDistribution, @@ -27,14 +25,6 @@ 0.23, # depth_nonreturn_prob ), ), - ( - FullPixelDepthDistribution(near, far), - ( - 0.5, # scale - 1 - 0.3, # visibility_prob - 0.1, # depth_nonreturn_prob - ), - ), ] @@ -76,34 +66,34 @@ def test_sample_in_valid_depth_range(kernel_spec, latent_depth): assert jnp.all((depths < far) | (depths == DEPTH_NONRETURN_VAL)) -def test_relative_logpdf(): - kernel = FullPixelDepthDistribution(near, far) - scale = 0.1 +# def test_relative_logpdf(): +# kernel = FullPixelDepthDistribution(near, far) +# scale = 0.1 - # case 1: depth is missing in observation (nonreturn) - obs_depth = DEPTH_NONRETURN_VAL - latent_depth = DEPTH_NONRETURN_VAL - depth_nonreturn_prob = 0.2 - logpdf_1 = kernel.logpdf(obs_depth, latent_depth, scale, 0.8, depth_nonreturn_prob) - assert logpdf_1 == jnp.log(depth_nonreturn_prob) +# # case 1: depth is missing in observation (nonreturn) +# obs_depth = DEPTH_NONRETURN_VAL +# latent_depth = DEPTH_NONRETURN_VAL +# depth_nonreturn_prob = 0.2 +# logpdf_1 = kernel.logpdf(obs_depth, latent_depth, scale, 0.8, depth_nonreturn_prob) +# assert logpdf_1 == jnp.log(depth_nonreturn_prob) - latent_depth = -1.0 # no depth information from latent - logpdf_2 = kernel.logpdf(obs_depth, latent_depth, scale, 0.8, depth_nonreturn_prob) - # nonreturn obs cannot be generates from latent that is not nonreturn - assert logpdf_2 == jnp.log(UNEXPLAINED_DEPTH_NONRETURN_PROB) +# latent_depth = -1.0 # no depth information from latent +# logpdf_2 = kernel.logpdf(obs_depth, latent_depth, scale, 0.8, depth_nonreturn_prob) +# # nonreturn obs cannot be generates from latent that is not nonreturn +# assert logpdf_2 == jnp.log(UNEXPLAINED_DEPTH_NONRETURN_PROB) - # case 2: valid depth is observed, but latent depth is far from the observed depth - obs_depth = 10.0 - latent_depth = 0.01 - logpdf_3 = kernel.logpdf(obs_depth, latent_depth, scale, 0.1, depth_nonreturn_prob) - logpdf_4 = kernel.logpdf(obs_depth, latent_depth, scale, 0.9, depth_nonreturn_prob) - # the pixel should be more likely to be an occluded - assert logpdf_3 > logpdf_4 +# # case 2: valid depth is observed, but latent depth is far from the observed depth +# obs_depth = 10.0 +# latent_depth = 0.01 +# logpdf_3 = kernel.logpdf(obs_depth, latent_depth, scale, 0.1, depth_nonreturn_prob) +# logpdf_4 = kernel.logpdf(obs_depth, latent_depth, scale, 0.9, depth_nonreturn_prob) +# # the pixel should be more likely to be an occluded +# assert logpdf_3 > logpdf_4 - # case 3: valid depth is observed, but latent depth is close from the observed depth - obs_depth = 6.0 - latent_depth = 6.01 - logpdf_5 = kernel.logpdf(obs_depth, latent_depth, scale, 0.1, depth_nonreturn_prob) - logpdf_6 = kernel.logpdf(obs_depth, latent_depth, scale, 0.9, depth_nonreturn_prob) - # the pixel should be more likely to be an inliner - assert logpdf_5 < logpdf_6 +# # case 3: valid depth is observed, but latent depth is close from the observed depth +# obs_depth = 6.0 +# latent_depth = 6.01 +# logpdf_5 = kernel.logpdf(obs_depth, latent_depth, scale, 0.1, depth_nonreturn_prob) +# logpdf_6 = kernel.logpdf(obs_depth, latent_depth, scale, 0.9, depth_nonreturn_prob) +# # the pixel should be more likely to be an inliner +# assert logpdf_5 < logpdf_6 diff --git a/tests/gen3d/test_pixel_rgbd_kernels.py b/tests/gen3d/test_pixel_rgbd_kernels.py index 704e623a..606edf64 100644 --- a/tests/gen3d/test_pixel_rgbd_kernels.py +++ b/tests/gen3d/test_pixel_rgbd_kernels.py @@ -3,9 +3,15 @@ import pytest from b3d.chisight.gen3d.pixel_kernels import ( DEPTH_NONRETURN_VAL, - FullPixelColorDistribution, - FullPixelDepthDistribution, - PixelRGBDDistribution, + FullPixelRGBDDistribution, +) +from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import ( + TruncatedLaplacePixelColorDistribution, + UniformPixelColorDistribution, +) +from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( + TruncatedLaplacePixelDepthDistribution, + UniformPixelDepthDistribution, ) near = 0.01 @@ -13,9 +19,11 @@ sample_kernels_to_test = [ ( - PixelRGBDDistribution( - FullPixelColorDistribution(), - FullPixelDepthDistribution(near, far), + FullPixelRGBDDistribution( + TruncatedLaplacePixelColorDistribution(), + UniformPixelColorDistribution(), + TruncatedLaplacePixelDepthDistribution(near, far), + UniformPixelDepthDistribution(near, far), ), ( 0.01, # color_scale From 113ffa93b610acc0e025675a9aaee77b584952d0 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 11 Sep 2024 21:29:00 +0000 Subject: [PATCH 2/3] minor docstring and comment changes --- src/b3d/chisight/gen3d/image_kernel.py | 7 ++++--- .../chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/b3d/chisight/gen3d/image_kernel.py b/src/b3d/chisight/gen3d/image_kernel.py index 0fd18801..b3362a27 100644 --- a/src/b3d/chisight/gen3d/image_kernel.py +++ b/src/b3d/chisight/gen3d/image_kernel.py @@ -129,10 +129,11 @@ def logpdf( state["visibility_prob"], state["depth_nonreturn_prob"], ) - # the pixel kernel does not expect invalid observed_rgbd and will return - # -inf if it is invalid. We need to filter those out here. - # (invalid rgbd could happen when the vertex is projected out of the image) + # Points that don't hit the camera plane should not contribute to the score. scores = jnp.where(is_unexplained(observed_rgbd_per_point), 0.0, scores) + + # TODO: add scoring for pixels that are not explained by the latent points + return scores.sum() def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution: diff --git a/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py b/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py index 510210c6..c81322ba 100644 --- a/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py +++ b/src/b3d/chisight/gen3d/pixel_kernels/pixel_rgbd_kernels.py @@ -68,6 +68,16 @@ def logpdf( @Pytree.dataclass class FullPixelRGBDDistribution(PixelRGBDDistribution): + """ + Args: + - latent_rgbd: 4-array: RGBD value. (a value of [-1, -1, -1, -1] indicates no point hits here.) + - color_scale: float + - depth_scale: float + - visibility_prob: float + + The support of the distribution is [0, 1]^3 x ([near, far] + {DEPTH_NONRETURN_VALUE}). + """ + inlier_color_distribution: PixelColorDistribution outlier_color_distribution: PixelColorDistribution From 5609d715af1cf9b1b3b4d5e6e6fab9c6db7387a8 Mon Sep 17 00:00:00 2001 From: George Matheos Date: Wed, 11 Sep 2024 22:04:45 +0000 Subject: [PATCH 3/3] block all point attrs in inference --- notebooks/bayes3d_paper/tester.ipynb | 124 +++++++-- src/b3d/chisight/gen3d/inference_moves.py | 323 +++++++++------------- 2 files changed, 230 insertions(+), 217 deletions(-) diff --git a/notebooks/bayes3d_paper/tester.ipynb b/notebooks/bayes3d_paper/tester.ipynb index 73396def..5e05220e 100644 --- a/notebooks/bayes3d_paper/tester.ipynb +++ b/notebooks/bayes3d_paper/tester.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -51,7 +51,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 49/49 [00:06<00:00, 7.29it/s]\n" + "100%|██████████| 49/49 [00:03<00:00, 13.52it/s]\n", + "/home/georgematheos/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n", + "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n", + " warnings.warn(\n" ] }, { @@ -62,7 +65,7 @@ "" ] }, - "execution_count": 14, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -97,7 +100,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -125,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -134,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -156,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -167,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -216,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -225,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -244,7 +247,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -265,16 +268,16 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Array(158743.16, dtype=float32)" + "Array(158884.22, dtype=float32)" ] }, - "execution_count": 23, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -287,16 +290,16 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Array(158743.16, dtype=float32)" + "Array(158884.22, dtype=float32)" ] }, - "execution_count": 24, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -307,7 +310,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -319,16 +322,16 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Array(43753.14, dtype=float32)" + "Array(43455.8, dtype=float32)" ] }, - "execution_count": 26, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -352,7 +355,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -721,6 +724,75 @@ "metadata[\"p_scores\"][i-3:i+3]" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([[0, 5],\n", + " [0, 6],\n", + " [0, 7],\n", + " [1, 5],\n", + " [1, 6],\n", + " [1, 7],\n", + " [2, 5],\n", + " [2, 6],\n", + " [2, 7],\n", + " [3, 5],\n", + " [3, 6],\n", + " [3, 7]], dtype=int32)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def all_pairs_2(X, Y):\n", + " return jnp.swapaxes(jnp.stack(jnp.meshgrid(X, Y), axis=-1), 0, 1).reshape(-1, 2)\n", + "\n", + "all_pairs_2(jnp.arange(0, 4), jnp.arange(5, 8))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 84, @@ -1800,7 +1872,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/src/b3d/chisight/gen3d/inference_moves.py b/src/b3d/chisight/gen3d/inference_moves.py index e77e6772..ce90429e 100644 --- a/src/b3d/chisight/gen3d/inference_moves.py +++ b/src/b3d/chisight/gen3d/inference_moves.py @@ -55,263 +55,194 @@ def propose_other_latents_given_pose(key, advanced_trace, pose, inference_hyperp proposed latents (and the same pose and observed rgbd as in the given trace). `log_q` is (a fair estimate of) the log proposal density. """ - k1, k2, k3, k4, k5 = split(key, 5) + k1, k2, k3, k4 = split(key, 4) trace = update_field(k1, advanced_trace, "pose", pose) k2a, k2b = split(k2) - depth_nonreturn_probs, log_q_dnrps, dnrp_metadata = propose_depth_nonreturn_probs( - k2a, trace - ) - trace = update_vmapped_field( - k2b, trace, "depth_nonreturn_prob", depth_nonreturn_probs + ( + colors, + visibility_probs, + depth_nonreturn_probs, + log_q_point_attributes, + point_proposal_metadata, + ) = propose_all_pointlevel_attributes(k2a, trace, inference_hyperparams) + trace = update_vmapped_fields( + k2b, + trace, + ["colors", "visibility_prob", "depth_nonreturn_prob"], + [colors, visibility_probs, depth_nonreturn_probs], ) + # TODO: debug these scores -- right now they are causing bad behavior + log_q_point_attributes = 0.0 k3a, k3b = split(k3) - colors, visibility_probs, log_q_cvp = propose_colors_and_visibility_probs( - k3a, trace, inference_hyperparams - ) - trace = update_vmapped_fields( - k3b, trace, ["colors", "visibility_prob"], [colors, visibility_probs] - ) - log_q_cvp = 0.0 + depth_scale, log_q_ds = propose_depth_scale(k3a, trace) + trace = update_field(k3b, trace, "depth_scale", depth_scale) k4a, k4b = split(k4) - depth_scale, log_q_ds = propose_depth_scale(k4a, trace) - trace = update_field(k4b, trace, "depth_scale", depth_scale) - - k5a, k5b = split(k5) - color_scale, log_q_cs = propose_color_scale(k5a, trace) - trace = update_field(k5b, trace, "color_scale", color_scale) + color_scale, log_q_cs = propose_color_scale(k4a, trace) + trace = update_field(k4b, trace, "color_scale", color_scale) - log_q = log_q_dnrps + log_q_cvp + log_q_ds + log_q_cs + log_q = log_q_point_attributes + log_q_ds + log_q_cs return ( trace, log_q, - {"depth_nonreturn_proposal": dnrp_metadata, "dnrps": depth_nonreturn_probs}, + {"point_attribute_proposal_metadata": point_proposal_metadata}, ) -def propose_depth_nonreturn_probs(key, trace): +def propose_all_pointlevel_attributes(key, trace, inference_hyperparams): """ - Propose a new depth nonreturn probability for every vertex, conditioned - upon the other values in `trace`. - Returns (depth_nonreturn_probs, log_q) where `depth_nonreturn_probs` is - a vector of shape (n_vertices,) and `log_q` is (a fair estimate of) - the log proposal density of this list of values. - """ - observed_depths_per_points = PixelsPointsAssociation.from_hyperparams_and_pose( - get_hypers(trace), get_new_state(trace)["pose"] - ).get_point_depths(get_observed_rgbd(trace)) - - depth_nonreturn_probs, per_vertex_log_qs, metadata = jax.vmap( - propose_vertex_depth_nonreturn_prob, in_axes=(0, 0, 0, None, None, None) - )( - split(key, get_n_vertices(trace)), - jnp.arange(get_n_vertices(trace)), - observed_depths_per_points, - get_prev_state(trace), - get_new_state(trace), - get_hypers(trace), - ) + Propose a new color, visibility probability, and depth non-return probability + for every vertex, conditioned upon the other values in `trace`. - return depth_nonreturn_probs, per_vertex_log_qs.sum(), metadata - - -def propose_colors_and_visibility_probs(key, trace, inference_hyperparams): - """ - Propose a new color and visibility probability for every vertex, conditioned - upon the other values in `trace`. - Returns (colors, visibility_probs, log_q) where `colors` has shape - (n_vertices, 3), `visibility_probs` is a vector of shape (n_vertices,) - and `log_q` is (a fair estimate of) the log proposal density of these - values. + Returns (colors, visibility_probs, depth_nonreturn_probs, log_q, metadata), + where colors has shape (n_vertices, 3), visibility_probs and depth_nonreturn_probs + have shape (n_vertices,), log_q (a float) is (an estimate of) + the overall log proposal density, and metadata is a dict. """ - observed_rgbds_per_points = PixelsPointsAssociation.from_hyperparams_and_pose( + observed_rgbds_per_point = PixelsPointsAssociation.from_hyperparams_and_pose( get_hypers(trace), get_new_state(trace)["pose"] ).get_point_rgbds(get_observed_rgbd(trace)) - colors, visibility_probs, per_vertex_log_qs = jax.vmap( - propose_vertex_color_and_visibility_prob, - in_axes=(0, 0, 0, None, None, None, None), + colors, visibility_probs, depth_nonreturn_probs, log_qs, metadata = jax.vmap( + propose_a_points_attributes, in_axes=(0, 0, 0, None, None, None, None) )( split(key, get_n_vertices(trace)), jnp.arange(get_n_vertices(trace)), - observed_rgbds_per_points, + observed_rgbds_per_point, get_prev_state(trace), get_new_state(trace), get_hypers(trace), inference_hyperparams, ) - return colors, visibility_probs, per_vertex_log_qs.sum() + return colors, visibility_probs, depth_nonreturn_probs, log_qs.sum(), metadata -def propose_vertex_depth_nonreturn_prob( - key, vertex_index, observed_depth, previous_state, new_state, hyperparams +def propose_a_points_attributes( + key, + vertex_index, + observed_rgbd_for_point, + prev_state, + new_state, + hyperparams, + inference_hyperparams, ): """ - Propose a new depth nonreturn probability for the single vertex - with index `vertex_index`. - Returns (depth_nonreturn_prob, log_q) where `depth_nonreturn_prob` is - the proposed value and `log_q` is (a fair estimate of) the log proposal density. + Propose a new color, visibility probability, and depth non-return probability + for the vertex with index `vertex_index`. + + Returns (color, visibility_prob, depth_nonreturn_prob, log_q, metadata), + where color is a 3-array, visibility_prob and depth_nonreturn_prob are floats, + log_q (a float) is (a fair estimate of) the log proposal density, + and metadata is a dict. """ - previous_dnrp = previous_state["depth_nonreturn_prob"][vertex_index] - visibility_prob = new_state["visibility_prob"][vertex_index] - latent_depth = new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2] - return _propose_vertex_depth_nonreturn_prob( + return _propose_a_points_attributes( key, - observed_depth, - latent_depth, - visibility_prob, - new_state["depth_scale"], - previous_dnrp, - hyperparams["depth_nonreturn_prob_kernel"], - hyperparams["image_kernel"].get_depth_vertex_kernel(), + observed_rgbd=observed_rgbd_for_point, + latent_depth=new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2], + previous_color=prev_state["colors"][vertex_index], + previous_visibility_prob=prev_state["visibility_prob"][vertex_index], + previous_dnrp=prev_state["depth_nonreturn_prob"][vertex_index], + dnrp_transition_kernel=hyperparams["depth_nonreturn_prob_kernel"], + visibility_transition_kernel=hyperparams["visibility_prob_kernel"], + color_kernel=hyperparams["color_kernel"], + obs_rgbd_kernel=hyperparams["image_kernel"].get_rgbd_vertex_kernel(), + color_scale=new_state["color_scale"], + depth_scale=new_state["depth_scale"], + inference_hyperparams=inference_hyperparams, ) -def _propose_vertex_depth_nonreturn_prob( +def _propose_a_points_attributes( key, - observed_depth, + observed_rgbd, latent_depth, - visibility_prob, - depth_scale, + previous_color, + previous_visibility_prob, previous_dnrp, dnrp_transition_kernel, - obs_depth_kernel, - return_metadata=True, -): - def score_dnrp_value(dnrp_value): - transition_score = dnrp_transition_kernel.logpdf(dnrp_value, previous_dnrp) - likelihood_score = obs_depth_kernel.logpdf( - observed_depth=observed_depth, - latent_depth=latent_depth, - depth_scale=depth_scale, - visibility_prob=visibility_prob, - depth_nonreturn_prob=dnrp_value, - ) - return transition_score + likelihood_score - - support = dnrp_transition_kernel.support - log_pscores = jax.vmap(score_dnrp_value)(support) - log_normalized_scores = normalize_log_scores(log_pscores) - index = jax.random.categorical(key, log_normalized_scores) - # ^ since we are enumerating over every value in the domain, it is unnecessary - # to add a 1/q score when resampling. (Equivalently, we could include - # q = 1/len(support), which does not change the resampling distribuiton at all.) - - if return_metadata: - metadata = { - "support": support, - "log_normalized_scores": log_normalized_scores, - "index": index, - "observed_depth": observed_depth, - "latent_depth": latent_depth, - "prev_dnrp": previous_dnrp, - "transition_score": jax.vmap( - lambda dnrp_value: dnrp_transition_kernel.logpdf( - dnrp_value, previous_dnrp - ) - )(support), - "likelihood_score": jax.vmap( - lambda dnrp_value: obs_depth_kernel.logpdf( - observed_depth, - latent_depth, - visibility_prob, - dnrp_value, - depth_scale, - ) - )(support), - } - else: - metadata = {} - - return support[index], log_normalized_scores[index], metadata - - -def propose_vertex_color_and_visibility_prob( - key, - vertex_index, - observed_rgbd_for_this_vertex, - previous_state, - new_state, - hyperparams, + visibility_transition_kernel, + color_kernel, + obs_rgbd_kernel, + color_scale, + depth_scale, inference_hyperparams, + return_metadata=True, ): - """ - Propose a new color and visibility probability for the single vertex - with index `vertex_index`. - Returns (color, visibility_prob, log_q) where `color` and `visibility_prob` - are the proposed values and `log_q` is (a fair estimate of) the log proposal density. - """ k1, k2 = split(key, 2) - previous_rgb = previous_state["colors"][vertex_index] - previous_visibility_prob = previous_state["visibility_prob"][vertex_index] - latent_depth = new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2] - all_vis_probs = hyperparams["visibility_prob_kernel"].support - - def score_visprob_rgb(visprob, rgb): - """ - Compute P(visprob, rgb, observed_rgbd_for_this_vertex | previous_visprob, previous_rgb). - """ - rgb_transition_score = hyperparams["color_kernel"].logpdf(rgb, previous_rgb) - visprob_transition_score = hyperparams["visibility_prob_kernel"].logpdf( + + def score_attribute_assignment(color, visprob, dnrprob): + visprob_transition_score = visibility_transition_kernel.logpdf( visprob, previous_visibility_prob ) - likelihood_score = ( - hyperparams["image_kernel"] - .get_rgbd_vertex_kernel() - .logpdf( - observed_rgbd=observed_rgbd_for_this_vertex, - latent_rgbd=jnp.append(rgb, latent_depth), - color_scale=new_state["color_scale"], - depth_scale=new_state["depth_scale"], - visibility_prob=visprob, - depth_nonreturn_prob=new_state["depth_nonreturn_prob"][vertex_index], - ) + dnrprob_transition_score = dnrp_transition_kernel.logpdf(dnrprob, previous_dnrp) + color_transition_score = color_kernel.logpdf(color, previous_color) + likelihood_score = obs_rgbd_kernel.logpdf( + observed_rgbd=observed_rgbd, + latent_rgbd=jnp.append(color, latent_depth), + color_scale=color_scale, + depth_scale=depth_scale, + visibility_prob=visprob, + depth_nonreturn_prob=dnrprob, + ) + return ( + visprob_transition_score + + dnrprob_transition_score + + color_transition_score + + likelihood_score ) - return rgb_transition_score + visprob_transition_score + likelihood_score - # Propose a rgb value for each visprob. - # `rgbs` has shape (len(all_vis_probs), 3). - # `log_qs_rgb` has shape (len(all_vis_probs),). + # Say there are V values in visibility_transition_kernel.support + # and D values in dnrp_transition_kernel.support. + + # (D*V, 2) array of all pairs of values in the support of the two kernels. + all_visprob_dnrprob_pairs = all_pairs( + visibility_transition_kernel.support, dnrp_transition_kernel.support + ) + + # Propose a color for each visprob-dnrprob pair. rgbs, log_qs_rgb = jax.vmap( - lambda k, visprob: propose_vertex_color_given_visibility( - k, - visprob, - observed_rgbd_for_this_vertex[:3], - score_visprob_rgb, - previous_rgb, - new_state, - inference_hyperparams, + lambda k, visprob_dnrprob_pair: propose_vertex_color_given_other_attributes( + key=k, + visprob=visprob_dnrprob_pair[0], + dnrprob=visprob_dnrprob_pair[1], + observed_rgb=observed_rgbd[:3], + score_attribute_assignment=score_attribute_assignment, + previous_rgb=previous_color, + color_scale=color_scale, + inference_hyperparams=inference_hyperparams, ) - )(split(k1, len(all_vis_probs)), all_vis_probs) + )(split(k1, len(all_visprob_dnrprob_pairs)), all_visprob_dnrprob_pairs) - # shape: (len(all_vis_probs),) - log_pscores = jax.vmap(score_visprob_rgb, in_axes=(0, 0))(all_vis_probs, rgbs) + log_pscores = jax.vmap( + lambda visprob_dnrprob_pair, rgb: score_attribute_assignment( + rgb, visprob_dnrprob_pair[0], visprob_dnrprob_pair[1] + ), + in_axes=(0, 0), + )(all_visprob_dnrprob_pairs, rgbs) - # We don't need to subtract a q score for the visibility probability, since - # we are enumerating over every value in the domain. (Equivalently, - # we could subtract a log q score of log(1/len(support)) for each value.) log_weights = log_pscores - log_qs_rgb log_normalized_scores = normalize_log_scores(log_weights) index = jax.random.categorical(k2, log_normalized_scores) rgb = rgbs[index] - visibility_prob = all_vis_probs[index] + visibility_prob, dnr_prob = all_visprob_dnrprob_pairs[index] log_q_score = log_normalized_scores[index] + log_qs_rgb[index] - return rgb, visibility_prob, log_q_score + return rgb, visibility_prob, dnr_prob, log_q_score, {} -def propose_vertex_color_given_visibility( +def propose_vertex_color_given_other_attributes( key, visprob, + dnrprob, observed_rgb, - score_visprob_and_rgb, + score_attribute_assignment, previous_rgb, - new_state, + color_scale, inference_hyperparams, ): """ @@ -344,7 +275,6 @@ def propose_vertex_color_given_visibility( propose traces that match that part of the posterior. """ color_shift_scale = inference_hyperparams.effective_color_transition_scale - color_scale = new_state["color_scale"] d = 1 / (1 / color_shift_scale + 1 / color_scale) r_diff = jnp.abs(previous_rgb[0] - observed_rgb[0]) @@ -379,7 +309,9 @@ def propose_vertex_color_given_visibility( log_qs = jnp.array([log_q_rgb_1, log_q_rgb_2, log_q_rgb_3]) scores = ( - jax.vmap(lambda rgb: score_visprob_and_rgb(visprob, rgb))(proposed_rgbs) + jax.vmap(lambda rgb: score_attribute_assignment(rgb, visprob, dnrprob))( + proposed_rgbs + ) - log_qs ) normalized_scores = normalize_log_scores(scores) @@ -517,3 +449,12 @@ def update_vmapped_field(key, trace, fieldname, value): For information, see `update_vmapped_fields`. """ return update_vmapped_fields(key, trace, [fieldname], [value]) + + +def all_pairs(X, Y): + """ + Return an array `ret` of shape (|X| * |Y|, 2) where each row + is a pair of values from X and Y. + That is, `ret[i, :]` is a pair [x, y] for some x in X and y in Y. + """ + return jnp.swapaxes(jnp.stack(jnp.meshgrid(X, Y), axis=-1), 0, 1).reshape(-1, 2)