Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interactive experiments. SLAM and HB #82

Merged
merged 16 commits into from
Jul 26, 2024
118 changes: 67 additions & 51 deletions demos/differentiable_renderer/gradient_based_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
from functools import partial

import b3d
import b3d.chisight.dense.differentiable_renderer as rendering
from jax.scipy.spatial.transform import Rotation as Rot
from b3d import Pose, Mesh
import rerun as rr
import functools
import genjax
from tqdm import tqdm
import jax
import jax.numpy as jnp
import optax
import rerun as rr
import trimesh
from b3d import Pose
from b3d.renderer_original import RendererOriginal
from tqdm import tqdm
import b3d.chisight.dense.differentiable_renderer as rendering
import demos.differentiable_renderer.utils as utils
from functools import partial

rr.init("gradients")
rr.connect("127.0.0.1:8812")
Expand All @@ -28,28 +31,6 @@ def map_fn(nested_dict):
return map_fn


# Set up OpenGL renderer
image_width = 200
image_height = 200
fx = 150.0
fy = 150.0
cx = 100.0
cy = 100.0
near = 0.001
far = 16.0
renderer = RendererOriginal(image_width, image_height, fx, fy, cx, cy, near, far)

WINDOW = 5

mesh_path = os.path.join(
b3d.get_root_path(),
"assets/shared_data_bucket/ycb_video_models/models/006_mustard_bottle/textured_simple.obj",
)
mesh = trimesh.load(mesh_path)
object_library = b3d.MeshLibrary.make_empty_library()
object_library.add_trimesh(mesh)


def render_to_dist_params(
renderer,
vertices,
Expand All @@ -75,7 +56,7 @@ def render_to_dist_params(
The remaining weights are those assigned to some triangles in the scene.
The attributes measured on those triangles are contained in `attributes`.
"""
image = renderer.rasterize(vertices[None, ...], faces)
image = renderer.rasterize_many(vertices[None, ...], faces)
triangle_id_image = image[0, ..., -1].astype(jnp.int32)

triangle_intersected_padded = jnp.pad(
Expand Down Expand Up @@ -151,36 +132,75 @@ def render_to_average_rgbd(
hyperparams = rendering.DifferentiableRendererHyperparams(3, 5e-5, 0.25, -1)


def render(params):
def render(params, mesh_params):
image = render_to_average_rgbd(
renderer,
b3d.Pose(params["position"], params["quaternion"]).apply(
object_library.vertices
mesh_params["vertices"]
),
object_library.faces,
object_library.attributes,
mesh_params["faces"],
mesh_params["vertex_attributes"],
background_attribute=jnp.array([0.0, 0.0, 0.0, 0]),
hyperparams=hyperparams,
)
return image


render_jit = jax.jit(render)
WINDOW = 5


ycb_dir = os.path.join(b3d.get_assets_path(), "bop/ycbv")

# image_ids = [image] if image is not None else range(1, num_scenes, FRAME_RATE)
scene_id = 48
print(f"Scene {scene_id}")
num_scenes = b3d.io.data_loader.get_ycbv_num_test_images(ycb_dir, scene_id)
image_ids = range(1, num_scenes + 1, 50)
all_data = b3d.io.get_ycbv_test_images(ycb_dir, scene_id, image_ids)

meshes = [
Mesh.from_obj_file(
os.path.join(ycb_dir, f'models/obj_{f"{id + 1}".rjust(6, "0")}.ply')
).scale(0.001)
for id in all_data[0]["object_types"]
]

height, width = all_data[0]["rgbd"].shape[:2]
fx, fy, cx, cy = all_data[0]["camera_intrinsics"]
scaling_factor = 0.3
renderer = b3d.renderer.renderer_original.RendererOriginal(
width * scaling_factor,
height * scaling_factor,
fx * scaling_factor,
fy * scaling_factor,
cx * scaling_factor,
cy * scaling_factor,
0.01,
2.0,
)

vertices, faces = object_library.vertices, object_library.faces
image = renderer.rasterize(vertices[None, ...], faces)
IDX = 1
mesh = meshes[IDX]

render_jit = jax.jit(render)

mesh_params = {
"vertices": mesh.vertices,
"faces": mesh.faces,
"vertex_attributes": mesh.vertex_attributes,
}
gt_pose = Pose.from_position_and_target(
jnp.array([0.3, 0.3, 0.0]),
jnp.array([0.0, 0.0, 0.0]),
).inv()
gt_image = render_jit({"position": gt_pose.position, "quaternion": gt_pose.quaternion})
gt_image = b3d.resize_image(all_data[0]["rgbd"], renderer.height, renderer.width)


def loss_func_rgbd(params, gt):
image = render(params)
return jnp.mean(jnp.abs(image[..., :3] - gt[..., :3]))
def loss_func_rgbd(params, mesh_params, gt):
image = render(params, mesh_params)
rendered_depth = image[..., 3]
rendered_areas = (rendered_depth / fx) * (rendered_depth / fy)
return jnp.mean(jnp.abs(image[..., :3] - gt[..., :3]) * rendered_areas[..., None])
# + jnp.mean(jnp.abs(image[...,3] - gt[...,3]))


Expand All @@ -190,7 +210,7 @@ def loss_func_rgbd(params, gt):
@partial(jax.jit, static_argnums=(1,))
def step(carry, tx):
(params, gt_image, state) = carry
_loss, (gradients,) = loss_func_rgbd_grad(params, gt_image)
loss, (gradients,) = loss_func_rgbd_grad(params, mesh_params, gt_image)
updates, state = tx.update(gradients, state, params)
params = optax.apply_updates(params, updates)
return ((params, gt_image, state), None)
Expand All @@ -206,32 +226,28 @@ def step(carry, tx):
label_fn,
)

pose = Pose.from_position_and_target(
jnp.array([0.6, 0.3, 0.6]),
jnp.array([0.0, 0.0, 0.0]),
).inv()
pose = all_data[0]["camera_pose"].inv() @ all_data[0]["object_poses"][IDX]

params = {
"position": pose.position,
"quaternion": pose.quaternion,
}

rr.log("image", rr.Image(gt_image[..., :3]), timeless=True)
rr.log("cloud", rr.Points3D(gt_pose.apply(object_library.vertices)), timeless=True)
rr.log("cloud", rr.Points3D(gt_pose.apply(mesh.vertices)), timeless=True)

pbar = tqdm(range(200))
state = tx.init(params)
images = [render_jit(params)]
images = [render_jit(params, mesh_params)]
for t in pbar:
(params, gt_image, state), _ = step((params, gt_image, state), tx)
rr.set_time_sequence("frame", t)
image = render_jit(params)
image = render_jit(params, mesh_params)
pbar.set_description(f"Loss: {loss_func_rgbd(params, mesh_params, gt_image)}")
rr.log("image/reconstruction", rr.Image(image[..., :3]))
rr.log(
"cloud/reconstruction",
rr.Points3D(
b3d.Pose(params["position"], params["quaternion"]).apply(
object_library.vertices
)
b3d.Pose(params["position"], params["quaternion"]).apply(mesh.vertices)
),
)
Loading