From 31ac7b64eca5d449159a7bb53dd18918be0a977e Mon Sep 17 00:00:00 2001 From: Jvshen Date: Thu, 9 Jan 2025 14:41:55 -0800 Subject: [PATCH] Add attribute for 3d triangulation --- sleap/gui/commands.py | 2 +- sleap/io/cameras.py | 79 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 15 deletions(-) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 075ccf060..03f183a49 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -3785,7 +3785,7 @@ def do_action(cls, context: CommandContext, params: dict): # Update or create/insert ("upsert") instance points frame_group.upsert_points( - points=points_reprojected, + points_3d=points_reprojected, instance_groups=instance_groups, exclude_complete=True, ) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index c2a1bb831..3191ef1a8 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -16,6 +16,7 @@ from sleap.instance import Instance, LabeledFrame, PredictedInstance from sleap.io.video import Video from sleap.util import compute_oks, deep_iterable_converter +from sleap_anipose.triangulation import reproject logger = logging.getLogger(__name__) @@ -452,6 +453,7 @@ class InstanceGroup: _dummy_instance: Optional[Instance] = field(default=None) camera_cluster: Optional[CameraCluster] = field(default=None) _score: Optional[float] = field(default=None) + _triangulation: Optional[np.ndarray] = field(default=None) def __attrs_post_init__(self): """Initialize `InstanceGroup` object.""" @@ -460,6 +462,14 @@ def __attrs_post_init__(self): for cam, instance in self._instance_by_camcorder.items(): self._camcorder_by_instance[instance] = cam + @property + def triangulation(self) -> Optional[np.ndarray]: + return self._triangulation + + @triangulation.setter + def triangulation(self, triangulation: np.ndarray): + self._triangulation = triangulation + def _create_dummy_instance(self, instance: Optional[Instance] = None): """Create a dummy instance to fill in for missing instances. @@ -837,7 +847,8 @@ def get_cam(self, instance: Instance) -> Optional[Camcorder]: def update_points( self, - points: np.ndarray, + points_3d: np.ndarray, + instance_groups: List['InstanceGroup'], cams_to_include: Optional[List[Camcorder]] = None, exclude_complete: bool = True, ): @@ -853,13 +864,62 @@ def update_points( exclude_complete: If True, then do not update points that are marked as complete. Default is True. """ + # Ensure we are working with a float array + points_3d = points_3d.astype(float) + + # Check if points are 3D + is_3d = points_3d.shape[-1] == 3 + if not is_3d: + raise ValueError("Expected 3D points with shape (M, T, N, 3).") + + # Check that the correct shape was passed in + n_views, n_instances, n_nodes, n_coords = points_3d.shape + assert n_views == len( + self.cams_to_include + ), f"Expected {len(self.cams_to_include)} views, got {n_views}." + assert n_instances == len( + instance_groups + ), f"Expected {len(instance_groups)} instances, got {n_instances}." + assert n_coords == 3, f"Expected 3 coordinates, got {n_coords}." + + # Reproject 3D points into 2D points for each camera view + pts_reprojected = reproject( + points_3d, + calib=self.session.camera_cluster, + excluded_views=self.excluded_views, + ) # M=include x F=1 x T x N x 2 + + # Squeeze back to the original shape + points_reprojected = np.squeeze(pts_reprojected, axis=1) # M=include x TxNx2 + + # Get projection bounds (based on video height/width) + bounds = self.session.projection_bounds + bounds_expanded_x = bounds[:, None, None, 0] + bounds_expanded_y = bounds[:, None, None, 1] + + # Create masks for out-of-bounds x and y coordinates + out_of_bounds_x = (points_reprojected[..., 0] < 0) | (points_reprojected[..., 0] > bounds_expanded_x) + out_of_bounds_y = (points_reprojected[..., 1] < 0) | (points_reprojected[..., 1] > bounds_expanded_y) + + # Replace out-of-bounds x and y coordinates with nan + points_reprojected[out_of_bounds_x, 0] = np.nan + points_reprojected[out_of_bounds_y, 1] = np.nan + + # Update points for each `InstanceGroup` + for ig_idx, instance_group in enumerate(instance_groups): + # Ensure that `InstanceGroup`s is in this `FrameGroup` + self._raise_if_instance_group_not_in_frame_group( + instance_group=instance_group + ) + # Update points for the instance group + instance_group.points = points_reprojected[ig_idx] # If no `Camcorder`s specified, then update `Instance`s for all `CameraCluster` if cams_to_include is None: cams_to_include = self.camera_cluster.cameras # Check that correct shape was passed in - n_views, n_nodes, _ = points.shape + n_views, n_nodes, _ = points_3d.shape assert n_views == len(cams_to_include), ( f"Number of views in `points` ({n_views}) does not match the number of " f"Camcorders in `cams_to_include` ({len(cams_to_include)})." @@ -883,13 +943,13 @@ def update_points( if not isinstance(instance, PredictedInstance): instance_oks = compute_oks( gt_points[cam_idx, :, :], - points[cam_idx, :, :], + points_3d[cam_idx, :, :], ) oks_scores[cam_idx] = instance_oks # Update the points for the instance instance.update_points( - points=points[cam_idx, :, :], exclude_complete=exclude_complete + points=points_3d[cam_idx, :, :], exclude_complete=exclude_complete ) # Update the score for the InstanceGroup to be the average OKS score @@ -2289,16 +2349,6 @@ def upsert_points( complete. Default is True. """ - # Check that the correct shape was passed in - n_views, n_instances, n_nodes, n_coords = points.shape - assert n_views == len( - self.cams_to_include - ), f"Expected {len(self.cams_to_include)} views, got {n_views}." - assert n_instances == len( - instance_groups - ), f"Expected {len(instance_groups)} instances, got {n_instances}." - assert n_coords == 2, f"Expected 2 coordinates, got {n_coords}." - # Ensure we are working with a float array points = points.astype(float) @@ -2331,6 +2381,7 @@ def upsert_points( points=instance_points, cams_to_include=self.cams_to_include, exclude_complete=exclude_complete, + bounds=bounds, ) def _raise_if_instance_not_in_instance_group(self, instance: Instance):