Skip to content

Commit

Permalink
Add attribute for 3d triangulation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jvshen committed Jan 9, 2025
1 parent 9775fc2 commit 31ac7b6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 15 deletions.
2 changes: 1 addition & 1 deletion sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3785,7 +3785,7 @@ def do_action(cls, context: CommandContext, params: dict):

# Update or create/insert ("upsert") instance points
frame_group.upsert_points(
points=points_reprojected,
points_3d=points_reprojected,
instance_groups=instance_groups,
exclude_complete=True,
)
Expand Down
79 changes: 65 additions & 14 deletions sleap/io/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sleap.instance import Instance, LabeledFrame, PredictedInstance
from sleap.io.video import Video
from sleap.util import compute_oks, deep_iterable_converter
from sleap_anipose.triangulation import reproject

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -452,6 +453,7 @@ class InstanceGroup:
_dummy_instance: Optional[Instance] = field(default=None)
camera_cluster: Optional[CameraCluster] = field(default=None)
_score: Optional[float] = field(default=None)
_triangulation: Optional[np.ndarray] = field(default=None)

def __attrs_post_init__(self):
"""Initialize `InstanceGroup` object."""
Expand All @@ -460,6 +462,14 @@ def __attrs_post_init__(self):
for cam, instance in self._instance_by_camcorder.items():
self._camcorder_by_instance[instance] = cam

@property
def triangulation(self) -> Optional[np.ndarray]:
return self._triangulation

@triangulation.setter
def triangulation(self, triangulation: np.ndarray):
self._triangulation = triangulation

def _create_dummy_instance(self, instance: Optional[Instance] = None):
"""Create a dummy instance to fill in for missing instances.
Expand Down Expand Up @@ -837,7 +847,8 @@ def get_cam(self, instance: Instance) -> Optional[Camcorder]:

def update_points(
self,
points: np.ndarray,
points_3d: np.ndarray,
instance_groups: List['InstanceGroup'],
cams_to_include: Optional[List[Camcorder]] = None,
exclude_complete: bool = True,
):
Expand All @@ -853,13 +864,62 @@ def update_points(
exclude_complete: If True, then do not update points that are marked as
complete. Default is True.
"""
# Ensure we are working with a float array
points_3d = points_3d.astype(float)

# Check if points are 3D
is_3d = points_3d.shape[-1] == 3
if not is_3d:
raise ValueError("Expected 3D points with shape (M, T, N, 3).")

# Check that the correct shape was passed in
n_views, n_instances, n_nodes, n_coords = points_3d.shape
assert n_views == len(
self.cams_to_include
), f"Expected {len(self.cams_to_include)} views, got {n_views}."
assert n_instances == len(
instance_groups
), f"Expected {len(instance_groups)} instances, got {n_instances}."
assert n_coords == 3, f"Expected 3 coordinates, got {n_coords}."

# Reproject 3D points into 2D points for each camera view
pts_reprojected = reproject(
points_3d,
calib=self.session.camera_cluster,
excluded_views=self.excluded_views,
) # M=include x F=1 x T x N x 2

# Squeeze back to the original shape
points_reprojected = np.squeeze(pts_reprojected, axis=1) # M=include x TxNx2

# Get projection bounds (based on video height/width)
bounds = self.session.projection_bounds
bounds_expanded_x = bounds[:, None, None, 0]
bounds_expanded_y = bounds[:, None, None, 1]

# Create masks for out-of-bounds x and y coordinates
out_of_bounds_x = (points_reprojected[..., 0] < 0) | (points_reprojected[..., 0] > bounds_expanded_x)
out_of_bounds_y = (points_reprojected[..., 1] < 0) | (points_reprojected[..., 1] > bounds_expanded_y)

# Replace out-of-bounds x and y coordinates with nan
points_reprojected[out_of_bounds_x, 0] = np.nan
points_reprojected[out_of_bounds_y, 1] = np.nan

# Update points for each `InstanceGroup`
for ig_idx, instance_group in enumerate(instance_groups):
# Ensure that `InstanceGroup`s is in this `FrameGroup`
self._raise_if_instance_group_not_in_frame_group(
instance_group=instance_group
)
# Update points for the instance group
instance_group.points = points_reprojected[ig_idx]

# If no `Camcorder`s specified, then update `Instance`s for all `CameraCluster`
if cams_to_include is None:
cams_to_include = self.camera_cluster.cameras

# Check that correct shape was passed in
n_views, n_nodes, _ = points.shape
n_views, n_nodes, _ = points_3d.shape
assert n_views == len(cams_to_include), (
f"Number of views in `points` ({n_views}) does not match the number of "
f"Camcorders in `cams_to_include` ({len(cams_to_include)})."
Expand All @@ -883,13 +943,13 @@ def update_points(
if not isinstance(instance, PredictedInstance):
instance_oks = compute_oks(
gt_points[cam_idx, :, :],
points[cam_idx, :, :],
points_3d[cam_idx, :, :],
)
oks_scores[cam_idx] = instance_oks

# Update the points for the instance
instance.update_points(
points=points[cam_idx, :, :], exclude_complete=exclude_complete
points=points_3d[cam_idx, :, :], exclude_complete=exclude_complete
)

# Update the score for the InstanceGroup to be the average OKS score
Expand Down Expand Up @@ -2289,16 +2349,6 @@ def upsert_points(
complete. Default is True.
"""

# Check that the correct shape was passed in
n_views, n_instances, n_nodes, n_coords = points.shape
assert n_views == len(
self.cams_to_include
), f"Expected {len(self.cams_to_include)} views, got {n_views}."
assert n_instances == len(
instance_groups
), f"Expected {len(instance_groups)} instances, got {n_instances}."
assert n_coords == 2, f"Expected 2 coordinates, got {n_coords}."

# Ensure we are working with a float array
points = points.astype(float)

Expand Down Expand Up @@ -2331,6 +2381,7 @@ def upsert_points(
points=instance_points,
cams_to_include=self.cams_to_include,
exclude_complete=exclude_complete,
bounds=bounds,
)

def _raise_if_instance_not_in_instance_group(self, instance: Instance):
Expand Down

0 comments on commit 31ac7b6

Please sign in to comment.