Skip to content

Commit

Permalink
Add update_points method
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Feb 5, 2025
1 parent 04184db commit 52826be
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 15 deletions.
251 changes: 246 additions & 5 deletions sleap_io/model/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,150 @@
import cv2
import numpy as np
import toml
import warnings
from attrs import define, field

from sleap_io.model.instance import Instance, PredictedInstance
from sleap_io.model.video import Video


def compute_instance_area(points: np.ndarray) -> np.ndarray:
"""Compute the area of the bounding box of a set of keypoints.
Args:
points: A numpy array of coordinates.
Returns:
The area of the bounding box of the points.
"""
if points.ndim == 2:
points = np.expand_dims(points, axis=0)

Check warning on line 28 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L27-L28

Added lines #L27 - L28 were not covered by tests

min_pt = np.nanmin(points, axis=-2)
max_pt = np.nanmax(points, axis=-2)

Check warning on line 31 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L30-L31

Added lines #L30 - L31 were not covered by tests

return np.prod(max_pt - min_pt, axis=-1)

Check warning on line 33 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L33

Added line #L33 was not covered by tests


def compute_oks(
points_gt: np.ndarray,
points_pr: np.ndarray,
scale: float | None = None,
stddev: float = 0.025,
use_cocoeval: bool = True,
) -> np.ndarray:
"""Compute the object keypoints similarity between sets of points.
Notes:
It's important to set the stddev appropriately when accounting for the
difficulty of each keypoint type. For reference, the median value for
all keypoint types in COCO is 0.072. The "easiest" keypoint is the left
eye, with stddev of 0.025, since it is easy to precisely locate the
eyes when labeling. The "hardest" keypoint is the left hip, with stddev
of 0.107, since it's hard to locate the left hip bone without external
anatomical features and since it is often occluded by clothing.
The implementation here is based off of the descriptions in:
Ronch & Perona. "Benchmarking and Error Diagnosis in Multi-Instance Pose
Estimation." ICCV (2017).
Args:
points_gt: Ground truth instances of shape (n_gt, n_nodes, n_ed),
where n_nodes is the number of body parts/keypoint types, and n_ed
is the number of Euclidean dimensions (typically 2 or 3). Keypoints
that are missing/not visible should be represented as NaNs.
points_pr: Predicted instance of shape (n_pr, n_nodes, n_ed).
use_cocoeval: Indicates whether the OKS score is calculated like cocoeval
method or not. True indicating the score is calculated using the
cocoeval method (widely used and the code can be found here at
https://github.com/cocodataset/cocoapi/blob/8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9/PythonAPI/pycocotools/cocoeval.py#L192C5-L233C20)
and False indicating the score is calculated using the method exactly
as given in the paper referenced in the Notes below.
scale: Size scaling factor to use when weighing the scores, typically
the area of the bounding box of the instance (in pixels). This
should be of the length n_gt. If a scalar is provided, the same
number is used for all ground truth instances. If set to None, the
bounding box area of the ground truth instances will be calculated.
stddev: The standard deviation associated with the spread in the
localization accuracy of each node/keypoint type. This should be of
the length n_nodes. "Easier" keypoint types will have lower values
to reflect the smaller spread expected in localizing it.
Returns:
The object keypoints similarity between every pair of ground truth and
predicted instance, a numpy array of of shape (n_gt, n_pr) in the range
of [0, 1.0], with 1.0 denoting a perfect match.
"""
if points_gt.ndim == 2:
points_gt = np.expand_dims(points_gt, axis=0)
if points_pr.ndim == 2:
points_pr = np.expand_dims(points_pr, axis=0)

Check warning on line 88 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L85-L88

Added lines #L85 - L88 were not covered by tests

if scale is None:
scale = compute_instance_area(points_gt)

Check warning on line 91 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L90-L91

Added lines #L90 - L91 were not covered by tests

n_gt, n_nodes, n_ed = points_gt.shape # n_ed = 2 or 3 (euclidean dimensions)
n_pr = points_pr.shape[0]

Check warning on line 94 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L93-L94

Added lines #L93 - L94 were not covered by tests

# If scalar scale was provided, use the same for each ground truth instance.
if np.isscalar(scale):
scale = np.full(n_gt, scale)

Check warning on line 98 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L97-L98

Added lines #L97 - L98 were not covered by tests

# If scalar standard deviation was provided, use the same for each node.
if np.isscalar(stddev):
stddev = np.full(n_nodes, stddev)

Check warning on line 102 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L101-L102

Added lines #L101 - L102 were not covered by tests

# Compute displacement between each pair.
displacement = np.reshape(points_gt, (n_gt, 1, n_nodes, n_ed)) - np.reshape(

Check warning on line 105 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L105

Added line #L105 was not covered by tests
points_pr, (1, n_pr, n_nodes, n_ed)
)
assert displacement.shape == (n_gt, n_pr, n_nodes, n_ed)

Check warning on line 108 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L108

Added line #L108 was not covered by tests

# Convert to pairwise Euclidean distances.
distance = (displacement**2).sum(axis=-1) # (n_gt, n_pr, n_nodes)
assert distance.shape == (n_gt, n_pr, n_nodes)

Check warning on line 112 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L111-L112

Added lines #L111 - L112 were not covered by tests

# Compute the normalization factor per keypoint.
if use_cocoeval:

Check warning on line 115 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L115

Added line #L115 was not covered by tests
# If use_cocoeval is True, then compute normalization factor according to cocoeval.
spread_factor = (2 * stddev) ** 2
scale_factor = 2 * (scale + np.spacing(1))

Check warning on line 118 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L117-L118

Added lines #L117 - L118 were not covered by tests
else:
# If use_cocoeval is False, then compute normalization factor according to the paper.
spread_factor = stddev**2
scale_factor = 2 * ((scale + np.spacing(1)) ** 2)
normalization_factor = np.reshape(spread_factor, (1, 1, n_nodes)) * np.reshape(

Check warning on line 123 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L121-L123

Added lines #L121 - L123 were not covered by tests
scale_factor, (n_gt, 1, 1)
)
assert normalization_factor.shape == (n_gt, 1, n_nodes)

Check warning on line 126 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L126

Added line #L126 was not covered by tests

# Since a "miss" is considered as KS < 0.5, we'll set the
# distances for predicted points that are missing to inf.
missing_pr = np.any(np.isnan(points_pr), axis=-1) # (n_pr, n_nodes)
assert missing_pr.shape == (n_pr, n_nodes)
distance[:, missing_pr] = np.inf

Check warning on line 132 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L130-L132

Added lines #L130 - L132 were not covered by tests

# Compute the keypoint similarity as per the top of Eq. 1.
ks = np.exp(-(distance / normalization_factor)) # (n_gt, n_pr, n_nodes)
assert ks.shape == (n_gt, n_pr, n_nodes)

Check warning on line 136 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L135-L136

Added lines #L135 - L136 were not covered by tests

# Set the KS for missing ground truth points to 0.
# This is equivalent to the visibility delta function of the bottom
# of Eq. 1.
missing_gt = np.any(np.isnan(points_gt), axis=-1) # (n_gt, n_nodes)
assert missing_gt.shape == (n_gt, n_nodes)
ks[np.expand_dims(missing_gt, axis=1)] = 0

Check warning on line 143 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L141-L143

Added lines #L141 - L143 were not covered by tests

# Compute the OKS.
n_visible_gt = np.sum(

Check warning on line 146 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L146

Added line #L146 was not covered by tests
(~missing_gt).astype("float64"), axis=-1, keepdims=True
) # (n_gt, 1)
oks = np.sum(ks, axis=-1) / n_visible_gt
assert oks.shape == (n_gt, n_pr)

Check warning on line 150 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L149-L150

Added lines #L149 - L150 were not covered by tests

return oks

Check warning on line 152 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L152

Added line #L152 was not covered by tests


def triangulate_dlt_vectorized(
points: np.ndarray, projection_matrices: np.ndarray
) -> np.ndarray:
Expand Down Expand Up @@ -342,17 +480,17 @@ def cameras(self) -> list[Camera]:
def score(self) -> float | None:
"""Get reprojection score of the `InstanceGroup`.
The score is the average OKS for all ground and projected point pairs in the
`InstanceGroup`.
Returns:
Score of `InstanceGroup`.
"""
return self._score

Check warning on line 489 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L489

Added line #L489 was not covered by tests

@score.setter
def score(self, score: float):
"""Set score of `InstanceGroup`.
This function sets the score for the `InstanceGroup` and then sets the score for
each `Instance` in the group.
def score(self, score: float | None):
"""Set the reprojection score of the `InstanceGroup`.
Args:
score: Score to set for `InstanceGroup`.
Expand All @@ -365,6 +503,49 @@ def score(self, score: float):
if hasattr(instance, "score"):
instance.score = score

Check warning on line 504 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L502-L504

Added lines #L502 - L504 were not covered by tests

@property
def point_scores(self) -> np.ndarray | None:
"""Get point scores of `InstanceGroup`.
Returns:
Point scores of `InstanceGroup` as np.ndarray of shape (N,) where N is the
number of nodes in each `Instance` or None if not set.
"""
return self._point_scores

Check warning on line 514 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L514

Added line #L514 was not covered by tests

@point_scores.setter
def point_scores(self, point_scores: np.ndarray | None):
"""Set point scores of `InstanceGroup`.
Args:
point_scores: Point scores to set for `InstanceGroup`.
"""
if point_scores is None:
self._point_scores = None
return

Check warning on line 525 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L523-L525

Added lines #L523 - L525 were not covered by tests

# Ensure point scores have correct shape.
point_scores = point_scores.flatten()
try:
n_points = len(self._template_instance)
if point_scores.shape != (n_points,):
raise ValueError
except Exception as e:
raise ValueError(

Check warning on line 534 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L528-L534

Added lines #L528 - L534 were not covered by tests
"Expected point scores to be np.ndarray of shape (N,) or (1, N) where N"
" is the number of nodes in each `Instance`, but received shape "
f"{point_scores.shape}.\n\n{e}"
)

# Set point scores for all instances in group.
self._point_scores = point_scores
for instance in self.instances:
if hasattr(instance, "score"):
instance.point_scores = point_scores

Check warning on line 544 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L541-L544

Added lines #L541 - L544 were not covered by tests

# Now set the score for the group.
self.score = np.nanmean(point_scores)

Check warning on line 547 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L547

Added line #L547 was not covered by tests

@property
def triangulation(self) -> np.ndarray | None:
"""Get triangulated 3D points of `InstanceGroup`.
Expand All @@ -380,6 +561,10 @@ def triangulation(self, points: np.ndarray):
Args:
points: Triangulated 3D points to set for `InstanceGroup`.
Raises:
ValueError: If `points` are not of shape (N, 3) where N is the number of
nodes in each `Instance`.
"""
# Validate points in
points_shape = points.shape
Expand Down Expand Up @@ -416,6 +601,62 @@ def get_instance(self, camera: Camera) -> Instance | None:
"""
return self._instance_by_camcorder.get(camera, None)

Check warning on line 602 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L602

Added line #L602 was not covered by tests

def update_points(self, points: np.ndarray, cams_to_include: list[Camera]):
"""Update points of `Instance`s in group using triangulated 3D points.
This function updates points for `Instance`s at select `Camera`s in the group
using the projected 3D points. If the `Instance` does not exist for a `Camera`,
the `Camera` is skipped and a warning is given (a new `Instance` is NOT created)
. If the `Instance` exists, the points are updated.
Args:
points: Triangulated 3D points to update or insert in instances of shape
(N, 3) where N is the number of points.
cams_to_include: List of `Camera` objects to update or insert points for.
Raises:
ValueError: If `points` does not have shape (N, 3) where N is the number of
nodes in each `Instance`.
Warning: If no `Instance` found for `Camera` in `InstanceGroup`.
"""
# Validate the points using triangulation setter (possibly raising ValueError).
self.triangulation = points

Check warning on line 624 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L624

Added line #L624 was not covered by tests

# For calculating OKS.
n_views = len(cams_to_include)
n_nodes = len(self._template_instance)
oks_scores = np.full((n_views, n_nodes), np.nan)

Check warning on line 629 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L627-L629

Added lines #L627 - L629 were not covered by tests

# Get the projected 2D points for the cameras.
for cam_idx, cam in enumerate(cams_to_include):

Check warning on line 632 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L632

Added line #L632 was not covered by tests
# Get instance for camera or create a new instance if it doesn't exist.
instance = self.get_instance(cam)
if instance is None:
warnings.warn(

Check warning on line 636 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L634-L636

Added lines #L634 - L636 were not covered by tests
f"No instance found for camera {cam} in group {self}. Cannot update"
" points. Skipping..."
)
continue

Check warning on line 640 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L640

Added line #L640 was not covered by tests

# Get the projected 2D points for the camera.
points_projected = cam.project(points) # N x 2

Check warning on line 643 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L643

Added line #L643 was not covered by tests

# Compute the OKS score for the instance if it is a ground truth instance
if not isinstance(instance, PredictedInstance):
gt_points = instance.numpy() # N x 2
instance_oks = compute_oks(

Check warning on line 648 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L646-L648

Added lines #L646 - L648 were not covered by tests
gt_points,
points_projected,
) # 1 x N
oks_scores[cam_idx] = instance_oks

Check warning on line 652 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L652

Added line #L652 was not covered by tests

# Update points for instance.
instance.points = points_projected

Check warning on line 655 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L655

Added line #L655 was not covered by tests

# Set the OKS scores for the group (and update `PredictedInstance.score`s).
self.point_scores = oks_scores

Check warning on line 658 in sleap_io/model/camera.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/camera.py#L658

Added line #L658 was not covered by tests

def numpy(
self,
cams_to_include: list[Camera],
Expand Down
Loading

0 comments on commit 52826be

Please sign in to comment.