Skip to content

Commit

Permalink
jitting in demo
Browse files Browse the repository at this point in the history
  • Loading branch information
nishadgothoskar committed Jul 25, 2024
1 parent 4bb14fc commit 8aaed91
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 26 deletions.
58 changes: 32 additions & 26 deletions scripts/acquire_object_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import argparse

import b3d
import jax
import jax.numpy as jnp
from b3d import Pose
from tqdm import tqdm

import b3d
from b3d import Mesh, Pose

b3d.rr_init("acquire_object_model")

Check failure on line 10 in scripts/acquire_object_model.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

scripts/acquire_object_model.py:1:1: I001 Import block is un-sorted or un-formatted

# python scripts/acquire_object_model.py assets/shared_data_bucket/input_data/lysol_static.r3d
Expand All @@ -23,19 +24,13 @@ def acquire(input_path, output_path=None):
_, _, fx, fy, cx, cy, near, far = data["camera_intrinsics_depth"]
image_height, image_width = data["depth"].shape[1:3]
num_scenes = data["depth"].shape[0]

indices = jnp.arange(0, num_scenes, 10)

camera_poses_full = data["camera_pose"]
camera_poses = camera_poses_full[indices]

xyz = b3d.xyz_from_depth_vectorized(data["depth"][indices], fx, fy, cx, cy)
xyz_world_frame = camera_poses[:, None, None].apply(xyz)

# for i in range(len(xyz_world_frame)):
# b3d.rr_set_time(i)
# b3d.utils.rr_log_cloud("xyz", xyz_world_frame[i])

# Resize rgbs to be same size as depth.
rgbs = data["rgb"]
rgbs_resized = jnp.clip(
Expand Down Expand Up @@ -91,14 +86,11 @@ def acquire(input_path, output_path=None):
grid_points = grid[model_mask]
colors = grid_colors[model_mask]

meshes = b3d.mesh.transform_mesh(
jax.vmap(b3d.mesh.Mesh.cube_mesh)(
jnp.ones((grid_points.shape[0], 3)) * resolution * 2.0, colors
),
b3d.Pose.from_translation(grid_points)[:, None],
_object_mesh = Mesh.voxel_mesh_from_xyz_colors_dimensions(
grid_points,
jnp.ones((grid_points.shape[0], 3)) * resolution * 2.0,
colors,
)
_object_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)

object_pose = Pose.from_translation(jnp.median(_object_mesh.vertices, axis=0))
object_mesh = _object_mesh.transform(object_pose.inv())
object_mesh.rr_visualize("mesh")
Expand Down Expand Up @@ -130,13 +122,11 @@ def acquire(input_path, output_path=None):
# colors = colors[subset]
# distances_from_camera = distances_from_camera[subset]

meshes = b3d.mesh.transform_mesh(
jax.vmap(b3d.mesh.Mesh.cube_mesh)(
jnp.ones((background_xyzs.shape[0], 3)) * distances_from_camera, colors
),
b3d.Pose.from_translation(background_xyzs)[:, None],
background_mesh = Mesh.voxel_mesh_from_xyz_colors_dimensions(
background_xyzs,
jnp.ones((background_xyzs.shape[0], 3)) * distances_from_camera,
colors,
)
background_mesh = b3d.mesh.Mesh.squeeze_mesh(meshes)
background_mesh.rr_visualize("background_mesh")

object_poses = [
Expand All @@ -151,16 +141,32 @@ def acquire(input_path, output_path=None):
object_poses,
)

viz_images = []
for t in tqdm(range(len(camera_poses_full))):
b3d.utils.rr_set_time(t)
rgbd = renderer.render_rgbd_from_mesh(
renderer = b3d.RendererOriginal(
image_width, image_height, fx, fy, cx, cy, near, far
)

def render_image(t):
return renderer.render_rgbd_from_mesh(
scene_mesh.transform(camera_poses_full[t].inv())
)
viz_images.append(b3d.viz_rgb(rgbd))

ss = jnp.concatenate(
[
jnp.arange(0, len(camera_poses_full), 30),
jnp.array([len(camera_poses_full) - 1]),
]
)
ss = jnp.vstack([ss[:-1], ss[1:]]).T
render_images = jax.jit(jax.vmap(render_image))
images = jnp.concatenate([render_images(jnp.arange(s[0], s[1])) for s in ss])

viz_images = []
for t in tqdm(range(len(images))):
viz_images.append(b3d.viz_rgb(images[t]))

b3d.make_video_from_pil_images(viz_images, output_path, fps=30.0)
print(f"Saved video to {output_path}")

return output_path


Expand Down
12 changes: 12 additions & 0 deletions src/b3d/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def rr_visualize_mesh(channel, mesh):
)


@jax.jit
def voxel_mesh_from_xyz_colors_dimensions(xyz, resolutions, colors):
meshes = b3d.mesh.transform_mesh(
jax.vmap(b3d.mesh.Mesh.cube_mesh)(resolutions, colors),
b3d.Pose.from_translation(xyz)[:, None],
)
return b3d.mesh.Mesh.squeeze_mesh(meshes)


@register_pytree_node_class
class Mesh:
def __init__(self, vertices, faces, vertex_attributes):
Expand Down Expand Up @@ -148,6 +157,9 @@ def __getitem__(self, index):
transform_and_merge_meshes = staticmethod(transform_and_merge_meshes)
transform_mesh = staticmethod(transform_mesh)
squeeze_mesh = staticmethod(squeeze_mesh)
voxel_mesh_from_xyz_colors_dimensions = staticmethod(
voxel_mesh_from_xyz_colors_dimensions
)

def rr_visualize(self, channel):
rr_visualize_mesh(channel, self)
Expand Down

0 comments on commit 8aaed91

Please sign in to comment.