From 9771ff5f97a0df1b4f334817835ef8d3e28b56f5 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Fri, 13 Oct 2023 16:24:25 -0700 Subject: [PATCH] Add tests for `TriangulateSession` --- tests/gui/test_commands.py | 356 ++++++++++++++++++++++++++++++++++++- 1 file changed, 350 insertions(+), 6 deletions(-) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index 6048e13ef..836fe0df5 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -2,7 +2,8 @@ import sys import time from pathlib import Path, PurePath -from typing import List +from typing import Dict, List +import numpy as np import pytest @@ -17,9 +18,11 @@ RemoveVideo, ReplaceVideo, SaveProjectAs, + TriangulateSession, get_new_version_filename, ) from sleap.instance import Instance, LabeledFrame +from sleap.io.cameras import Camcorder from sleap.io.convert import default_analysis_filename from sleap.io.dataset import Labels from sleap.io.format.adaptor import Adaptor @@ -28,11 +31,11 @@ from sleap.io.video import Video from sleap.util import get_package_file -# These imports cause trouble when running `pytest.main()` from within the file -# Comment out to debug tests file via VSCode's "Debug Python File" -from tests.info.test_h5 import extract_meta_hdf5 -from tests.io.test_formats import read_nix_meta -from tests.io.test_video import assert_video_params +# # These imports cause trouble when running `pytest.main()` from within the file +# # Comment out to debug tests file via VSCode's "Debug Python File" +# from tests.info.test_h5 import extract_meta_hdf5 +# from tests.io.test_formats import read_nix_meta +# from tests.io.test_video import assert_video_params def test_delete_user_dialog(centered_pair_predictions): @@ -952,3 +955,344 @@ def test_AddSession( assert len(labels.sessions) == 2 assert context.state["session"] is session assert labels.sessions[1] is not session + + +def test_triangulate_session_get_all_views_at_frame( + multiview_min_session_labels: Labels, +): + labels = multiview_min_session_labels + session = labels.sessions[0] + lf = labels.labeled_frames[0] + frame_idx = lf.frame_idx + + # Test with no cams_to_include, expect views from all linked cameras + views = TriangulateSession.get_all_views_at_frame(session, frame_idx) + assert len(views) == len(session.linked_cameras) + for cam in session.linked_cameras: + assert views[cam].frame_idx == frame_idx + assert views[cam].video == session[cam] + + # Test with cams_to_include, expect views from only those cameras + cams_to_include = session.linked_cameras[0:2] + views = TriangulateSession.get_all_views_at_frame( + session, frame_idx, cams_to_include=cams_to_include + ) + assert len(views) == len(cams_to_include) + for cam in cams_to_include: + assert views[cam].frame_idx == frame_idx + assert views[cam].video == session[cam] + + +def test_triangulate_session_get_instances_across_views( + multiview_min_session_labels: Labels, +): + + labels = multiview_min_session_labels + session = labels.sessions[0] + + # Test get_instances_across_views + lf: LabeledFrame = labels[0] + track = labels.tracks[0] + instances: Dict[ + Camcorder, Instance + ] = TriangulateSession.get_instances_across_views( + session=session, frame_idx=lf.frame_idx, track=track + ) + assert len(instances) == len(session.videos) + for vid in session.videos: + cam = session[vid] + inst = instances[cam] + assert inst.frame_idx == lf.frame_idx + assert inst.track == track + assert inst.video == vid + + # Try with excluding cam views + lf: LabeledFrame = labels[2] + track = labels.tracks[1] + cams_to_include = session.linked_cameras[:4] + videos_to_include: Dict[ + Camcorder, Video + ] = session.get_videos_from_selected_cameras(cams_to_include=cams_to_include) + assert len(cams_to_include) == 4 + assert len(videos_to_include) == len(cams_to_include) + instances: Dict[ + Camcorder, Instance + ] = TriangulateSession.get_instances_across_views( + session=session, + frame_idx=lf.frame_idx, + track=track, + cams_to_include=cams_to_include, + ) + assert len(instances) == len( + videos_to_include + ) # May not be true if no instances at that frame + for cam, vid in videos_to_include.items(): + inst = instances[cam] + assert inst.frame_idx == lf.frame_idx + assert inst.track == track + assert inst.video == vid + + # Try with only a single view + cams_to_include = [session.linked_cameras[0]] + with pytest.raises(ValueError): + instances = TriangulateSession.get_instances_across_views( + session=session, + frame_idx=lf.frame_idx, + cams_to_include=cams_to_include, + track=track, + require_multiple_views=True, + ) + + # Try with multiple views, but not enough instances + track = labels.tracks[1] + cams_to_include = session.linked_cameras[4:6] + with pytest.raises(ValueError): + instances = TriangulateSession.get_instances_across_views( + session=session, + frame_idx=lf.frame_idx, + cams_to_include=cams_to_include, + track=track, + require_multiple_views=True, + ) + + +def test_triangulate_session_get_and_verify_enough_instances( + multiview_min_session_labels: Labels, + caplog, +): + labels = multiview_min_session_labels + session = labels.sessions[0] + lf = labels.labeled_frames[0] + track = labels.tracks[1] + + # Test with no cams_to_include, expect views from all linked cameras + instances = TriangulateSession.get_and_verify_enough_instances( + session=session, frame_idx=lf.frame_idx, track=track + ) + assert len(instances) == 6 # Some views don't have an instance at this track + for cam in session.linked_cameras: + if cam.name in ["side", "sideL"]: # The views that don't have an instance + continue + assert instances[cam].frame_idx == lf.frame_idx + assert instances[cam].track == track + assert instances[cam].video == session[cam] + + # Test with cams_to_include, expect views from only those cameras + cams_to_include = session.linked_cameras[-2:] + instances = TriangulateSession.get_and_verify_enough_instances( + session=session, + frame_idx=lf.frame_idx, + cams_to_include=cams_to_include, + track=track, + ) + assert len(instances) == len(cams_to_include) + for cam in cams_to_include: + assert instances[cam].frame_idx == lf.frame_idx + assert instances[cam].track == track + assert instances[cam].video == session[cam] + + # Test with not enough instances, expect views from only those cameras + cams_to_include = session.linked_cameras[0:2] + instances = TriangulateSession.get_and_verify_enough_instances( + session=session, frame_idx=lf.frame_idx, cams_to_include=cams_to_include + ) + assert isinstance(instances, bool) + assert not instances + messages = "".join([rec.message for rec in caplog.records]) + assert "One or less instances found for frame" in messages + + +def test_triangulate_session_verify_enough_views( + multiview_min_session_labels: Labels, caplog +): + labels = multiview_min_session_labels + session = labels.sessions[0] + + # Test with enough views + enough_views = TriangulateSession.verify_enough_views( + session=session, show_dialog=False + ) + assert enough_views + messages = "".join([rec.message for rec in caplog.records]) + assert len(messages) == 0 + caplog.clear() + + # Test with not enough views + cams_to_include = [session.linked_cameras[0]] + enough_views = TriangulateSession.verify_enough_views( + session=session, cams_to_include=cams_to_include, show_dialog=False + ) + assert not enough_views + messages = "".join([rec.message for rec in caplog.records]) + assert "One or less cameras available." in messages + + +def test_triangulate_session_verify_views_and_instances( + multiview_min_session_labels: Labels, +): + labels = multiview_min_session_labels + session = labels.sessions[0] + + # Test with enough views and instances + lf = labels.labeled_frames[0] + instance = lf.instances[0] + + context = CommandContext.from_labels(labels) + params = { + "video": session.videos[0], + "session": session, + "frame_idx": lf.frame_idx, + "instance": instance, + "show_dialog": False, + } + enough_views = TriangulateSession.verify_views_and_instances(context, params) + assert enough_views + assert "instances" in params + + # Test with not enough views + cams_to_include = [session.linked_cameras[0]] + params = { + "video": session.videos[0], + "session": session, + "frame_idx": lf.frame_idx, + "instance": instance, + "cams_to_include": cams_to_include, + "show_dialog": False, + } + enough_views = TriangulateSession.verify_views_and_instances(context, params) + assert not enough_views + assert "instances" not in params + + +def test_triangulate_session_calculate_reprojected_points( + multiview_min_session_labels: Labels, +): + """Test `TriangulateSession.calculate_reprojected_points`.""" + + session = multiview_min_session_labels.sessions[0] + lf: LabeledFrame = multiview_min_session_labels[0] + track = multiview_min_session_labels.tracks[0] + instances: Dict[ + Camcorder, Instance + ] = TriangulateSession.get_instances_across_views( + session=session, frame_idx=lf.frame_idx, track=track + ) + instances_and_coords = TriangulateSession.calculate_reprojected_points( + session=session, instances=instances + ) + + # Check that we get the same number of instances as input + assert len(instances) == len(list(instances_and_coords)) + + # Check that each instance has the same number of points + for inst, inst_coords in instances_and_coords: + assert inst_coords.shape[1] == len(inst.skeleton) # (1, 15, 2) + + +def test_triangulate_session_get_instances_matrices( + multiview_min_session_labels: Labels, +): + """Test `TriangulateSession.get_instance_matrices`.""" + labels = multiview_min_session_labels + session = labels.sessions[0] + lf: LabeledFrame = labels[0] + track = labels.tracks[0] + instances: Dict[ + Camcorder, Instance + ] = TriangulateSession.get_instances_across_views( + session=session, frame_idx=lf.frame_idx, track=track + ) + instances_matrices = TriangulateSession.get_instances_matrices( + instances_ordered=instances.values() + ) + + # Verify shape + n_views = len(instances) + n_frames = 1 + n_tracks = 1 + n_nodes = len(labels.skeleton) + assert instances_matrices.shape == (n_views, n_frames, n_tracks, n_nodes, 2) + + +def test_triangulate_session_update_instances(multiview_min_session_labels: Labels): + """Test `RecordingSession.update_instances`.""" + + # Test update_instances + session = multiview_min_session_labels.sessions[0] + lf: LabeledFrame = multiview_min_session_labels[0] + track = multiview_min_session_labels.tracks[0] + instances: Dict[ + Camcorder, Instance + ] = TriangulateSession.get_instances_across_views( + session=session, frame_idx=lf.frame_idx, track=track + ) + instances_and_coordinates = TriangulateSession.calculate_reprojected_points( + session=session, instances=instances + ) + for inst, inst_coords in instances_and_coordinates: + assert inst_coords.shape == (1, len(inst.skeleton), 2) # Tracks, Nodes, 2 + # Assert coord are different from original + assert not np.array_equal(inst_coords, inst.points_array) + + # Just run for code coverage testing, do not test output here (race condition) + # (see "functional core, imperative shell" pattern) + TriangulateSession.update_instances(session=session, instances=instances) + + +def test_triangulate_session_do_action(multiview_min_session_labels: Labels): + """Test `TriangulateSession.do_action`.""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + + # Test with enough views and instances + lf = labels.labeled_frames[0] + instance = lf.instances[0] + + context = CommandContext.from_labels(labels) + params = { + "video": session.videos[0], + "session": session, + "frame_idx": lf.frame_idx, + "instance": instance, + "ask_again": True, + } + TriangulateSession.do_action(context, params) + + # Test with not enough views + cams_to_include = [session.linked_cameras[0]] + params = { + "video": session.videos[0], + "session": session, + "frame_idx": lf.frame_idx, + "instance": instance, + "cams_to_include": cams_to_include, + "ask_again": True, + } + TriangulateSession.do_action(context, params) + + +def test_triangulate_session(multiview_min_session_labels: Labels): + """Test `TriangulateSession`, if""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + video = session.videos[0] + lf = labels.labeled_frames[0] + instance = lf.instances[0] + context = CommandContext.from_labels(labels) + + # Test with enough views and instances so we don't get any GUI pop-ups + context.triangulateSession( + frame_idx=lf.frame_idx, + video=video, + instance=instance, + session=session, + ) + + # Test with using state to gather params + context.state["session"] = session + context.state["video"] = video + context.state["instance"] = instance + context.state["frame_idx"] = lf.frame_idx + context.triangulateSession()