Skip to content

Commit

Permalink
Refactor RGBD Pixel Kernel (#161)
Browse files Browse the repository at this point in the history
Co-authored-by: George Matheos <[email protected]>
  • Loading branch information
nishadgothoskar and georgematheos authored Sep 11, 2024
1 parent ac49e93 commit 9fd1edd
Show file tree
Hide file tree
Showing 11 changed files with 457 additions and 530 deletions.
124 changes: 98 additions & 26 deletions notebooks/bayes3d_paper/tester.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -51,7 +51,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 49/49 [00:06<00:00, 7.29it/s]\n"
"100%|██████████| 49/49 [00:03<00:00, 13.52it/s]\n",
"/home/georgematheos/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n",
"If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n",
" warnings.warn(\n"
]
},
{
Expand All @@ -62,7 +65,7 @@
"<PIL.Image.Image image mode=RGB size=640x480>"
]
},
"execution_count": 14,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -97,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -125,7 +128,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -134,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -156,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -167,7 +170,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -216,7 +219,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -225,7 +228,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -244,7 +247,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -265,16 +268,16 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(158743.16, dtype=float32)"
"Array(158884.22, dtype=float32)"
]
},
"execution_count": 23,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -287,16 +290,16 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(158743.16, dtype=float32)"
"Array(158884.22, dtype=float32)"
]
},
"execution_count": 24,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -307,7 +310,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -319,16 +322,16 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(43753.14, dtype=float32)"
"Array(43455.8, dtype=float32)"
]
},
"execution_count": 26,
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -352,7 +355,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -721,6 +724,75 @@
"metadata[\"p_scores\"][i-3:i+3]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([[0, 5],\n",
" [0, 6],\n",
" [0, 7],\n",
" [1, 5],\n",
" [1, 6],\n",
" [1, 7],\n",
" [2, 5],\n",
" [2, 6],\n",
" [2, 7],\n",
" [3, 5],\n",
" [3, 6],\n",
" [3, 7]], dtype=int32)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def all_pairs_2(X, Y):\n",
" return jnp.swapaxes(jnp.stack(jnp.meshgrid(X, Y), axis=-1), 0, 1).reshape(-1, 2)\n",
"\n",
"all_pairs_2(jnp.arange(0, 4), jnp.arange(5, 8))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 84,
Expand Down Expand Up @@ -1800,7 +1872,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
36 changes: 19 additions & 17 deletions src/b3d/chisight/gen3d/image_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
from genjax import Pytree
from genjax.typing import FloatArray, PRNGKey

from b3d.chisight.gen3d.pixel_kernels import (
FullPixelColorDistribution,
FullPixelDepthDistribution,
PixelDepthDistribution,
from b3d.chisight.gen3d.pixel_kernels import is_unexplained
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import (
TruncatedLaplacePixelColorDistribution,
UniformPixelColorDistribution,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import (
TruncatedLaplacePixelDepthDistribution,
UniformPixelDepthDistribution,
)
from b3d.chisight.gen3d.pixel_kernels.pixel_rgbd_kernels import (
FullPixelRGBDDistribution,
PixelRGBDDistribution,
is_unexplained,
)
from b3d.chisight.gen3d.projection import PixelsPointsAssociation

Expand Down Expand Up @@ -51,9 +57,6 @@ def logpdf(
) -> FloatArray:
raise NotImplementedError

def get_depth_vertex_kernel(self) -> PixelDepthDistribution:
raise NotImplementedError

def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
raise NotImplementedError

Expand Down Expand Up @@ -126,22 +129,21 @@ def logpdf(
state["visibility_prob"],
state["depth_nonreturn_prob"],
)
# the pixel kernel does not expect invalid observed_rgbd and will return
# -inf if it is invalid. We need to filter those out here.
# (invalid rgbd could happen when the vertex is projected out of the image)
# Points that don't hit the camera plane should not contribute to the score.
scores = jnp.where(is_unexplained(observed_rgbd_per_point), 0.0, scores)

# TODO: add scoring for pixels that are not explained by the latent points

return scores.sum()

def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
# Note: The distributions were originally defined for per-pixel computation,
# but they should work for per-vertex computation as well, except that
# they don't expect observed_rgbd to be invalid, so we need to handle
# that manually.
return PixelRGBDDistribution(
FullPixelColorDistribution(),
FullPixelDepthDistribution(self.near, self.far),
return FullPixelRGBDDistribution(
TruncatedLaplacePixelColorDistribution(),
UniformPixelColorDistribution(),
TruncatedLaplacePixelDepthDistribution(self.near, self.far),
UniformPixelDepthDistribution(self.near, self.far),
)

def get_depth_vertex_kernel(self) -> PixelDepthDistribution:
return self.get_rgbd_vertex_kernel().depth_kernel
Loading

0 comments on commit 9fd1edd

Please sign in to comment.