diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index 30c2771d2..84e26b9f3 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -763,6 +763,34 @@ def test_instance_group( assert cam not in instance_group.cameras +def test_instance_group_update_points_from_2d( + multiview_min_session_frame_groups: Labels, +): + + labels = multiview_min_session_frame_groups + session: RecordingSession = labels.sessions[0] + frame_idx = 0 + frame_group = session.frame_groups[frame_idx] + instance_group = frame_group.instance_groups[0] + + # Test `update_points_from_2d` (all in bounds, all updated) + n_cameras = len(frame_group.cams_to_include) + n_instance_groups = 1 + n_nodes = len(frame_group.session.labels.skeleton.nodes) + n_coords = 2 + value = 100 + points = np.full((n_cameras, n_nodes, n_coords), value) + projection_bounds = frame_group.session.projection_bounds + cams_to_include = frame_group.cams_to_include + instance_group.update_points_from_2d( + points_reprojected=points, + projection_bounds=projection_bounds, + cams_to_include=cams_to_include, + exclude_complete=False, + ) + assert np.all(instance_group.numpy(invisible_as_nan=False) == value) + + def test_frame_group( multiview_min_session_labels: Labels, multiview_min_session_frame_groups: Labels ):