Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference Unit Tests - Part 1 #163

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 87 additions & 211 deletions notebooks/bayes3d_paper/run_ycbv_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,18 @@ def run_tracking(scene=None, object=None, debug=False):
import os

import b3d
import b3d.chisight.gen3d.image_kernel as image_kernel
import b3d.chisight.gen3d.transition_kernels as transition_kernels
import genjax
import jax
import jax.numpy as jnp
from b3d import Mesh, Pose
from b3d.chisight.gen3d.model import (
dynamic_object_generative_model,
make_colors_choicemap,
make_depth_nonreturn_prob_choicemap,
make_visibility_prob_choicemap,
)
from genjax import Pytree
from tqdm import tqdm

Expand All @@ -32,6 +40,27 @@ def run_tracking(scene=None, object=None, debug=False):
elif isinstance(scene, list):
scenes = scene

hyperparams = {
"pose_kernel": transition_kernels.UniformPoseDriftKernel(max_shift=0.1),
"color_kernel": transition_kernels.LaplaceNotTruncatedColorDriftKernel(
scale=0.15
),
"visibility_prob_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.01, 0.99])
),
"depth_nonreturn_prob_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.01, 0.99])
),
"depth_scale_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.0025, 0.01, 0.02])
),
"color_scale_kernel": transition_kernels.DiscreteFlipKernel(
resample_probability=0.05, possible_values=jnp.array([0.05, 0.1, 0.15])
),
"image_likelihood": image_kernel.SimpleNoRenderImageLikelihood(),
}
info_from_trace = hyperparams["image_likelihood"].info_from_trace

for scene_id in scenes:
print(f"Scene {scene_id}")
num_scenes = b3d.io.data_loader.get_ycbv_num_test_images(ycb_dir, scene_id)
Expand Down Expand Up @@ -61,98 +90,9 @@ def run_tracking(scene=None, object=None, debug=False):
4.0,
)

def grid_outlier_prob(trace, values):
return jax.vmap(
lambda x: info_from_trace(
b3d.update_choices(trace, Pytree.const(("outlier_probability",)), x)
)["scores"]
)(values)

@jax.jit
def update_pose_and_color(trace, address, pose):
trace = b3d.update_choices(trace, address, pose)
info = info_from_trace(trace)
current_outlier_probabilities = trace.get_choices()["outlier_probability"]
model_rgbd, observed_rgbd = (
info["model_rgbd"],
info["corresponding_observed_rgbd"],
)
deltas = (observed_rgbd - model_rgbd)[..., :3]
deltas_clipped = jnp.clip(deltas, -0.1, 0.1)

mesh = trace.get_args()[0]["meshes"][0]
is_inlier = current_outlier_probabilities == outlier_probability_sweep[0]
mesh.vertex_attributes = (
mesh.vertex_attributes + deltas_clipped * is_inlier[..., None]
)

trace, _ = importance_jit(
jax.random.PRNGKey(2),
trace.get_choices(),
(
{
"num_objects": Pytree.const(1),
"meshes": [mesh],
"likelihood_args": likelihood_args,
},
),
)
return trace.get_score()

def _gvmf_and_select_best_move(
trace, key, variance, concentration, address, number
):
test_poses = Pose.concatenate_poses(
[
jax.vmap(
Pose.sample_gaussian_vmf_pose, in_axes=(0, None, None, None)
)(
jax.random.split(key, number),
trace.get_choices()[address.const],
variance,
concentration,
),
trace.get_choices()[address.const][None, ...],
]
)
scores = jax.vmap(update_pose_and_color, in_axes=(None, None, 0))(
trace, address, test_poses
)
trace = b3d.update_choices(
trace,
address,
test_poses[scores.argmax()],
)
key = jax.random.split(key, 2)[-1]
return trace, key

gvmf_and_select_best_move = jax.jit(
_gvmf_and_select_best_move, static_argnames=["number"]
)

from b3d.chisight.dense.likelihoods.simplified_rendering_laplace_likelihood import (
simplified_rendering_laplace_likelihood,
)

model, viz_trace, info_from_trace = (
b3d.chisight.dense.dense_model.make_dense_multiobject_model(
None, simplified_rendering_laplace_likelihood
)
)
importance_jit = jax.jit(model.importance)

# initial_camera_pose = all_data[0]["camera_pose"]
initial_object_poses = all_data[0]["object_poses"]

likelihood_args = {
"fx": fx,
"fy": fy,
"cx": cx,
"cy": cy,
"image_height": Pytree.const(image_height),
"image_width": Pytree.const(image_width),
}

object_indices = (
[object] if object is not None else range(len(initial_object_poses))
)
Expand All @@ -178,140 +118,76 @@ def _gvmf_and_select_best_move(
* (xyz_observed[..., 2] > 0)
* (jnp.linalg.norm(xyz_rendered - xyz_observed, axis=-1) < 0.01)
)
mesh = Mesh(
vertices=template_pose.inv().apply(xyz_rendered[mask]),
faces=jnp.zeros((0, 3), dtype=jnp.int32),
vertex_attributes=all_data[T]["rgbd"][..., :3][mask],
)

outlier_probability_sweep = jnp.array([0.05, 1.0])

choicemap = genjax.ChoiceMap.d(
{
"rgbd": all_data[T]["rgbd"],
"camera_pose": Pose.identity(),
"object_pose_0": template_pose,
"outlier_probability": jnp.ones(len(mesh.vertices))
* outlier_probability_sweep[0],
"color_noise_variance": 0.05,
"depth_noise_variance": 0.01,
}
)

trace0, _ = importance_jit(
jax.random.PRNGKey(2),
choicemap,
(
model_vertices = template_pose.inv().apply(xyz_rendered[mask])
model_colors = all_data[T]["rgbd"][..., :3][mask]

subset = jax.random.permutation(jax.random.PRNGKey(0), len(model_vertices))[
: min(10000, len(model_vertices))
]
model_vertices = model_vertices[subset]
model_colors = model_colors[subset]

num_vertices = model_vertices.shape[0]
previous_state = {
"pose": template_pose,
"colors": model_colors,
"visibility_prob": jnp.ones(num_vertices)
* hyperparams["visibility_prob_kernel"].possible_values[-1],
"depth_nonreturn_prob": jnp.ones(num_vertices)
* hyperparams["depth_nonreturn_prob_kernel"].possible_values[0],
"depth_scale": hyperparams["depth_scale_kernel"].possible_values[0],
"color_scale": hyperparams["color_scale_kernel"].possible_values[0],
}

hyperparams["vertices"] = model_vertices
hyperparams["fx"] = fx
hyperparams["fy"] = fy
hyperparams["cx"] = cx
hyperparams["cy"] = cy
hyperparams["image_height"] = Pytree.const(image_height)
hyperparams["image_width"] = Pytree.const(image_width)
choicemap = (
genjax.ChoiceMap.d(
{
"num_objects": Pytree.const(1),
"meshes": [mesh],
"likelihood_args": likelihood_args,
},
),
"pose": previous_state["pose"],
"color_scale": previous_state["color_scale"],
"depth_scale": previous_state["depth_scale"],
"rgbd": all_data[T]["rgbd"],
}
)
^ make_visibility_prob_choicemap(previous_state["visibility_prob"])
^ make_colors_choicemap(previous_state["colors"])
^ make_depth_nonreturn_prob_choicemap(
previous_state["depth_nonreturn_prob"]
)
)
key = jax.random.PRNGKey(100)
key = jax.random.PRNGKey(0)

trace = trace0
tracking_results = {}
for T in tqdm(range(len(all_data))):
trace = b3d.update_choices(
trace,
Pytree.const(("rgbd",)),
all_data[T]["rgbd"],
)
trace = dynamic_object_generative_model.importance(
key, choicemap, (hyperparams, previous_state)
)[0]

for _ in range(5):
trace, key = gvmf_and_select_best_move(
trace,
key,
0.01,
1000.0,
Pytree.const(("object_pose_0",)),
10000,
)
trace, key = gvmf_and_select_best_move(
trace,
key,
0.005,
2000.0,
Pytree.const(("object_pose_0",)),
10000,
)
# viz_trace(trace, T)
from b3d.chisight.gen3d.inference import inference_step

if T % 1 == 0:
trace = b3d.bayes3d.enumerative_proposals.enumerate_and_select_best(
trace,
Pytree.const(("color_noise_variance",)),
jnp.linspace(0.05, 0.1, 10),
)
trace = b3d.bayes3d.enumerative_proposals.enumerate_and_select_best(
trace,
Pytree.const(("depth_noise_variance",)),
jnp.linspace(0.005, 0.01, 10),
)
### Run inference ###
for T in tqdm(range(len(all_data))):
key = b3d.split_key(key)
trace = inference_step(trace, key, all_data[T]["rgbd"])
tracking_results[T] = trace

current_outlier_probabilities = trace.get_choices()[
"outlier_probability"
]
scores = grid_outlier_prob(
trace,
outlier_probability_sweep[..., None]
* jnp.ones_like(current_outlier_probabilities),
)
trace = b3d.update_choices(
if debug:
b3d.chisight.gen3d.model.viz_trace(
trace,
Pytree.const(("outlier_probability",)),
outlier_probability_sweep[jnp.argmax(scores, axis=0)],
)

current_outlier_probabilities = trace.get_choices()[
"outlier_probability"
]
# b3d.rr_log_cloud(
# mesh.vertices,
# colors=colors[1 * (current_outlier_probabilities == outlier_probability_sweep[0])],
# channel="cloud/outlier_probabilities"
# )

info = info_from_trace(trace)
current_outlier_probabilities = trace.get_choices()[
"outlier_probability"
]
model_rgbd, observed_rgbd = (
info["model_rgbd"],
info["corresponding_observed_rgbd"],
T,
ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
ground_truth_pose=all_data[T]["camera_pose"].inv()
@ all_data[T]["object_poses"][OBJECT_INDEX],
)
deltas = observed_rgbd - model_rgbd
deltas_clipped = jnp.clip(deltas, -0.05, 0.05)
new_model_rgbd = model_rgbd + deltas_clipped

mesh = trace.get_args()[0]["meshes"][0]
is_inlier = (
current_outlier_probabilities == outlier_probability_sweep[0]
)
mesh.vertex_attributes = mesh.vertex_attributes.at[is_inlier].set(
new_model_rgbd[is_inlier, :3]
)

trace, _ = importance_jit(
jax.random.PRNGKey(2),
trace.get_choices(),
(
{
"num_objects": Pytree.const(1),
"meshes": [mesh],
"likelihood_args": likelihood_args,
},
),
)
tracking_results[T] = trace
if debug:
viz_trace(trace, T)

inferred_poses = Pose.stack_poses(
[
tracking_results[t].get_choices()["object_pose_0"]
tracking_results[t].get_choices()["pose"]
for t in range(len(all_data))
]
)
Expand Down
8 changes: 4 additions & 4 deletions src/b3d/chisight/gen3d/inference_moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def propose_a_points_attributes(
"""
return _propose_a_points_attributes(
key,
observed_rgbd=observed_rgbd_for_point,
observed_rgbd_for_point=observed_rgbd_for_point,
latent_depth=new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2],
previous_color=prev_state["colors"][vertex_index],
previous_visibility_prob=prev_state["visibility_prob"][vertex_index],
Expand All @@ -158,7 +158,7 @@ def propose_a_points_attributes(

def _propose_a_points_attributes(
key,
observed_rgbd,
observed_rgbd_for_point,
latent_depth,
previous_color,
previous_visibility_prob,
Expand All @@ -181,7 +181,7 @@ def score_attribute_assignment(color, visprob, dnrprob):
dnrprob_transition_score = dnrp_transition_kernel.logpdf(dnrprob, previous_dnrp)
color_transition_score = color_kernel.logpdf(color, previous_color)
likelihood_score = obs_rgbd_kernel.logpdf(
observed_rgbd=observed_rgbd,
observed_rgbd=observed_rgbd_for_point,
latent_rgbd=jnp.append(color, latent_depth),
color_scale=color_scale,
depth_scale=depth_scale,
Expand Down Expand Up @@ -209,7 +209,7 @@ def score_attribute_assignment(color, visprob, dnrprob):
key=k,
visprob=visprob_dnrprob_pair[0],
dnrprob=visprob_dnrprob_pair[1],
observed_rgb=observed_rgbd[:3],
observed_rgb=observed_rgbd_for_point[:3],
score_attribute_assignment=score_attribute_assignment,
previous_rgb=previous_color,
color_scale=color_scale,
Expand Down
3 changes: 3 additions & 0 deletions src/b3d/chisight/gen3d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,6 @@ def viz_trace(trace, t=0, ground_truth_vertices=None, ground_truth_pose=None):
ground_truth_pose.apply(ground_truth_vertices),
"scene/ground_truth_object_mesh",
)

b3d.rr_log_pose(ground_truth_pose, "scene/ground_truth_pose")
b3d.rr_log_pose(trace.get_choices()["pose"], "scene/inferred_pose")
Loading
Loading