Skip to content

Commit

Permalink
Inference Unit Tests - Part 1 (#163)
Browse files Browse the repository at this point in the history
Co-authored-by: George Matheos <[email protected]>
  • Loading branch information
nishadgothoskar and georgematheos authored Sep 12, 2024
1 parent 9fd1edd commit 6ab90c2
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 226 deletions.
298 changes: 87 additions & 211 deletions notebooks/bayes3d_paper/run_ycbv_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@ def run_tracking(scene=None, object=None, debug=False):
import os

import b3d
import b3d.chisight.gen3d.image_kernel as image_kernel
import b3d.chisight.gen3d.transition_kernels as transition_kernels
import genjax
import jax
import jax.numpy as jnp
from b3d import Mesh, Pose
from b3d.chisight.gen3d.model import (
dynamic_object_generative_model,
make_colors_choicemap,
make_depth_nonreturn_prob_choicemap,
make_visibility_prob_choicemap,
)
from genjax import Pytree
from tqdm import tqdm

Expand All @@ -32,6 +40,27 @@ def run_tracking(scene=None, object=None, debug=False):
elif isinstance(scene, list):
scenes = scene

hyperparams = {
"pose_kernel": transition_kernels.UniformPoseDriftKernel(max_shift=0.1),
"color_kernel": transition_kernels.LaplaceNotTruncatedColorDriftKernel(
scale=0.15
),
"visibility_prob_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.01, 0.99])
),
"depth_nonreturn_prob_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.01, 0.99])
),
"depth_scale_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.0025, 0.01, 0.02])
),
"color_scale_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.05, 0.1, 0.15])
),
"image_likelihood": image_kernel.SimpleNoRenderImageLikelihood(),
}
info_from_trace = hyperparams["image_likelihood"].info_from_trace

for scene_id in scenes:
print(f"Scene {scene_id}")
num_scenes = b3d.io.data_loader.get_ycbv_num_test_images(ycb_dir, scene_id)
Expand Down Expand Up @@ -61,98 +90,9 @@ def run_tracking(scene=None, object=None, debug=False):
4.0,
)

def grid_outlier_prob(trace, values):
    """Score ``trace`` under each candidate outlier-probability assignment.

    Every entry along the leading axis of ``values`` is written to the
    ``"outlier_probability"`` address of ``trace``; the ``"scores"`` entry
    that ``info_from_trace`` reports for the updated trace is collected,
    stacked along that same leading axis.
    """
    outlier_address = Pytree.const(("outlier_probability",))

    def scores_for(candidate):
        updated_trace = b3d.update_choices(trace, outlier_address, candidate)
        return info_from_trace(updated_trace)["scores"]

    return jax.vmap(scores_for)(values)

@jax.jit
def update_pose_and_color(trace, address, pose):
    """Return the model score after setting ``pose`` and nudging inlier colors.

    The pose is written to ``address``, the clipped RGB residual between the
    observation and the model is added to the colors of the points currently
    flagged as inliers, and the trace is re-scored with the adjusted mesh.
    Only the resulting score is returned; the updated trace is discarded.
    """
    posed_trace = b3d.update_choices(trace, address, pose)
    render_info = info_from_trace(posed_trace)
    outlier_probs = posed_trace.get_choices()["outlier_probability"]

    # RGB-only residual (first three channels), bounded to +/-0.1 per channel
    # so a single evaluation can only shift a color by a small amount.
    rgb_residual = (
        render_info["corresponding_observed_rgbd"] - render_info["model_rgbd"]
    )[..., :3]
    bounded_residual = jnp.clip(rgb_residual, -0.1, 0.1)

    mesh = posed_trace.get_args()[0]["meshes"][0]
    # A point counts as an inlier when its outlier probability equals the
    # first value of the sweep; only those points have their color adjusted.
    inlier_mask = outlier_probs == outlier_probability_sweep[0]
    mesh.vertex_attributes = (
        mesh.vertex_attributes + bounded_residual * inlier_mask[..., None]
    )

    model_args = (
        {
            "num_objects": Pytree.const(1),
            "meshes": [mesh],
            "likelihood_args": likelihood_args,
        },
    )
    rescored_trace, _ = importance_jit(
        jax.random.PRNGKey(2), posed_trace.get_choices(), model_args
    )
    return rescored_trace.get_score()

def _gvmf_and_select_best_move(
    trace, key, variance, concentration, address, number
):
    """Propose ``number`` poses and keep whichever one scores best.

    Candidate poses are drawn with ``Pose.sample_gaussian_vmf_pose`` from the
    pose currently stored at ``address`` using ``variance`` and
    ``concentration``. The current pose is appended as a final candidate, so
    the selected score is never lower than the current pose's score. Every
    candidate is scored with ``update_pose_and_color`` and the arg-max is
    written back to ``address``.

    Args:
        trace: Trace whose pose at ``address`` is being refined.
        key: PRNG key; consumed by this call.
        variance: Forwarded to ``Pose.sample_gaussian_vmf_pose``.
        concentration: Forwarded to ``Pose.sample_gaussian_vmf_pose``.
        address: Constant-wrapped address of the pose choice (``address.const``
            is used to read the current pose).
        number: Number of random candidates to draw (static under ``jit``).

    Returns:
        ``(trace, key)``: the trace with the best-scoring pose, and a fresh
        key for the caller to use next.
    """
    # Split once so the key handed back to the caller is independent of the
    # keys used for sampling. Deriving both from the same parent key (as two
    # separate split calls) is key reuse and can correlate successive calls.
    proposal_key, next_key = jax.random.split(key)

    current_pose = trace.get_choices()[address.const]
    proposed_poses = jax.vmap(
        Pose.sample_gaussian_vmf_pose, in_axes=(0, None, None, None)
    )(
        jax.random.split(proposal_key, number),
        current_pose,
        variance,
        concentration,
    )
    test_poses = Pose.concatenate_poses(
        [proposed_poses, current_pose[None, ...]]
    )

    scores = jax.vmap(update_pose_and_color, in_axes=(None, None, 0))(
        trace, address, test_poses
    )
    trace = b3d.update_choices(trace, address, test_poses[scores.argmax()])
    return trace, next_key

gvmf_and_select_best_move = jax.jit(
_gvmf_and_select_best_move, static_argnames=["number"]
)

from b3d.chisight.dense.likelihoods.simplified_rendering_laplace_likelihood import (
simplified_rendering_laplace_likelihood,
)

model, viz_trace, info_from_trace = (
b3d.chisight.dense.dense_model.make_dense_multiobject_model(
None, simplified_rendering_laplace_likelihood
)
)
importance_jit = jax.jit(model.importance)

# initial_camera_pose = all_data[0]["camera_pose"]
initial_object_poses = all_data[0]["object_poses"]

likelihood_args = {
"fx": fx,
"fy": fy,
"cx": cx,
"cy": cy,
"image_height": Pytree.const(image_height),
"image_width": Pytree.const(image_width),
}

object_indices = (
[object] if object is not None else range(len(initial_object_poses))
)
Expand All @@ -178,140 +118,76 @@ def _gvmf_and_select_best_move(
* (xyz_observed[..., 2] > 0)
* (jnp.linalg.norm(xyz_rendered - xyz_observed, axis=-1) < 0.01)
)
mesh = Mesh(
vertices=template_pose.inv().apply(xyz_rendered[mask]),
faces=jnp.zeros((0, 3), dtype=jnp.int32),
vertex_attributes=all_data[T]["rgbd"][..., :3][mask],
)

outlier_probability_sweep = jnp.array([0.05, 1.0])

choicemap = genjax.ChoiceMap.d(
{
"rgbd": all_data[T]["rgbd"],
"camera_pose": Pose.identity(),
"object_pose_0": template_pose,
"outlier_probability": jnp.ones(len(mesh.vertices))
* outlier_probability_sweep[0],
"color_noise_variance": 0.05,
"depth_noise_variance": 0.01,
}
)

trace0, _ = importance_jit(
jax.random.PRNGKey(2),
choicemap,
(
model_vertices = template_pose.inv().apply(xyz_rendered[mask])
model_colors = all_data[T]["rgbd"][..., :3][mask]

subset = jax.random.permutation(jax.random.PRNGKey(0), len(model_vertices))[
: min(10000, len(model_vertices))
]
model_vertices = model_vertices[subset]
model_colors = model_colors[subset]

num_vertices = model_vertices.shape[0]
previous_state = {
"pose": template_pose,
"colors": model_colors,
"visibility_prob": jnp.ones(num_vertices)
* hyperparams["visibility_prob_kernel"].possible_values[-1],
"depth_nonreturn_prob": jnp.ones(num_vertices)
* hyperparams["depth_nonreturn_prob_kernel"].possible_values[0],
"depth_scale": hyperparams["depth_scale_kernel"].possible_values[0],
"color_scale": hyperparams["color_scale_kernel"].possible_values[0],
}

hyperparams["vertices"] = model_vertices
hyperparams["fx"] = fx
hyperparams["fy"] = fy
hyperparams["cx"] = cx
hyperparams["cy"] = cy
hyperparams["image_height"] = Pytree.const(image_height)
hyperparams["image_width"] = Pytree.const(image_width)
choicemap = (
genjax.ChoiceMap.d(
{
"num_objects": Pytree.const(1),
"meshes": [mesh],
"likelihood_args": likelihood_args,
},
),
"pose": previous_state["pose"],
"color_scale": previous_state["color_scale"],
"depth_scale": previous_state["depth_scale"],
"rgbd": all_data[T]["rgbd"],
}
)
^ make_visibility_prob_choicemap(previous_state["visibility_prob"])
^ make_colors_choicemap(previous_state["colors"])
^ make_depth_nonreturn_prob_choicemap(
previous_state["depth_nonreturn_prob"]
)
)
key = jax.random.PRNGKey(100)
key = jax.random.PRNGKey(0)

trace = trace0
tracking_results = {}
for T in tqdm(range(len(all_data))):
trace = b3d.update_choices(
trace,
Pytree.const(("rgbd",)),
all_data[T]["rgbd"],
)
trace = dynamic_object_generative_model.importance(
key, choicemap, (hyperparams, previous_state)
)[0]

for _ in range(5):
trace, key = gvmf_and_select_best_move(
trace,
key,
0.01,
1000.0,
Pytree.const(("object_pose_0",)),
10000,
)
trace, key = gvmf_and_select_best_move(
trace,
key,
0.005,
2000.0,
Pytree.const(("object_pose_0",)),
10000,
)
# viz_trace(trace, T)
from b3d.chisight.gen3d.inference import inference_step

if T % 1 == 0:
trace = b3d.bayes3d.enumerative_proposals.enumerate_and_select_best(
trace,
Pytree.const(("color_noise_variance",)),
jnp.linspace(0.05, 0.1, 10),
)
trace = b3d.bayes3d.enumerative_proposals.enumerate_and_select_best(
trace,
Pytree.const(("depth_noise_variance",)),
jnp.linspace(0.005, 0.01, 10),
)
### Run inference ###
for T in tqdm(range(len(all_data))):
key = b3d.split_key(key)
trace = inference_step(trace, key, all_data[T]["rgbd"])
tracking_results[T] = trace

current_outlier_probabilities = trace.get_choices()[
"outlier_probability"
]
scores = grid_outlier_prob(
trace,
outlier_probability_sweep[..., None]
* jnp.ones_like(current_outlier_probabilities),
)
trace = b3d.update_choices(
if debug:
b3d.chisight.gen3d.model.viz_trace(
trace,
Pytree.const(("outlier_probability",)),
outlier_probability_sweep[jnp.argmax(scores, axis=0)],
)

current_outlier_probabilities = trace.get_choices()[
"outlier_probability"
]
# b3d.rr_log_cloud(
# mesh.vertices,
# colors=colors[1 * (current_outlier_probabilities == outlier_probability_sweep[0])],
# channel="cloud/outlier_probabilities"
# )

info = info_from_trace(trace)
current_outlier_probabilities = trace.get_choices()[
"outlier_probability"
]
model_rgbd, observed_rgbd = (
info["model_rgbd"],
info["corresponding_observed_rgbd"],
T,
ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
ground_truth_pose=all_data[T]["camera_pose"].inv()
@ all_data[T]["object_poses"][OBJECT_INDEX],
)
deltas = observed_rgbd - model_rgbd
deltas_clipped = jnp.clip(deltas, -0.05, 0.05)
new_model_rgbd = model_rgbd + deltas_clipped

mesh = trace.get_args()[0]["meshes"][0]
is_inlier = (
current_outlier_probabilities == outlier_probability_sweep[0]
)
mesh.vertex_attributes = mesh.vertex_attributes.at[is_inlier].set(
new_model_rgbd[is_inlier, :3]
)

trace, _ = importance_jit(
jax.random.PRNGKey(2),
trace.get_choices(),
(
{
"num_objects": Pytree.const(1),
"meshes": [mesh],
"likelihood_args": likelihood_args,
},
),
)
tracking_results[T] = trace
if debug:
viz_trace(trace, T)

inferred_poses = Pose.stack_poses(
[
tracking_results[t].get_choices()["object_pose_0"]
tracking_results[t].get_choices()["pose"]
for t in range(len(all_data))
]
)
Expand Down
8 changes: 4 additions & 4 deletions src/b3d/chisight/gen3d/inference_moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def propose_a_points_attributes(
"""
return _propose_a_points_attributes(
key,
observed_rgbd=observed_rgbd_for_point,
observed_rgbd_for_point=observed_rgbd_for_point,
latent_depth=new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2],
previous_color=prev_state["colors"][vertex_index],
previous_visibility_prob=prev_state["visibility_prob"][vertex_index],
Expand All @@ -158,7 +158,7 @@ def propose_a_points_attributes(

def _propose_a_points_attributes(
key,
observed_rgbd,
observed_rgbd_for_point,
latent_depth,
previous_color,
previous_visibility_prob,
Expand All @@ -181,7 +181,7 @@ def score_attribute_assignment(color, visprob, dnrprob):
dnrprob_transition_score = dnrp_transition_kernel.logpdf(dnrprob, previous_dnrp)
color_transition_score = color_kernel.logpdf(color, previous_color)
likelihood_score = obs_rgbd_kernel.logpdf(
observed_rgbd=observed_rgbd,
observed_rgbd=observed_rgbd_for_point,
latent_rgbd=jnp.append(color, latent_depth),
color_scale=color_scale,
depth_scale=depth_scale,
Expand Down Expand Up @@ -209,7 +209,7 @@ def score_attribute_assignment(color, visprob, dnrprob):
key=k,
visprob=visprob_dnrprob_pair[0],
dnrprob=visprob_dnrprob_pair[1],
observed_rgb=observed_rgbd[:3],
observed_rgb=observed_rgbd_for_point[:3],
score_attribute_assignment=score_attribute_assignment,
previous_rgb=previous_color,
color_scale=color_scale,
Expand Down
3 changes: 3 additions & 0 deletions src/b3d/chisight/gen3d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,6 @@ def viz_trace(trace, t=0, ground_truth_vertices=None, ground_truth_pose=None):
ground_truth_pose.apply(ground_truth_vertices),
"scene/ground_truth_object_mesh",
)

b3d.rr_log_pose(ground_truth_pose, "scene/ground_truth_pose")
b3d.rr_log_pose(trace.get_choices()["pose"], "scene/inferred_pose")
Loading

0 comments on commit 6ab90c2

Please sign in to comment.