diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 84ec6b8b..d7b19868 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -54,7 +54,7 @@ jobs:
   # Tests with pytest
   tests:
-    timeout-minutes: 15
+    timeout-minutes: 25
     strategy:
       fail-fast: false
       matrix:
@@ -92,7 +92,7 @@ jobs:
         if: ${{ startsWith(matrix.os, 'ubuntu') }}
         shell: bash -l {0}
         run: |
-          sudo apt-get update && sudo apt-get install libglapi-mesa libegl-mesa0 libegl1 libopengl0 libgl1-mesa-glx
+          sudo apt-get update && sudo apt-get install libglapi-mesa libegl-mesa0 libegl1 libopengl0 libgl1 libglx-mesa0

       - name: Test with pytest (with coverage)
         shell: bash -l {0}
diff --git a/pyproject.toml b/pyproject.toml
index d9ceecfe..cc4b5d03 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "attrs",
     "h5py>=3.8.0",
     "pynwb",
-    "ndx-pose<0.2.0",
+    "ndx-pose>=0.2.1",
     "pandas",
     "simplejson",
     "imageio",
diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py
index b46dc23a..e317cb45 100644
--- a/sleap_io/io/nwb.py
+++ b/sleap_io/io/nwb.py
@@ -1,7 +1,7 @@
 """Functions to write and read from the neurodata without borders (NWB) format."""

 from copy import deepcopy
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Dict
 from pathlib import Path
 import datetime
 import uuid
@@ -15,14 +15,14 @@ except ImportError:
     ArrayLike = np.ndarray
 from pynwb import NWBFile, NWBHDF5IO, ProcessingModule  # type: ignore[import]
-from ndx_pose import PoseEstimationSeries, PoseEstimation  # type: ignore[import]
+from ndx_pose import PoseEstimationSeries, PoseEstimation, Skeleton, Skeletons  # type: ignore[import]

 from sleap_io import (
     Labels,
     Video,
     LabeledFrame,
     Track,
-    Skeleton,
+    Skeleton as SleapSkeleton,
     Instance,
     PredictedInstance,
 )
@@ -116,22 +116,30 @@ def read_nwb(path: str) -> Labels:
     """
     with NWBHDF5IO(path, mode="r", load_namespaces=True) as io:
         read_nwbfile = io.read()
-        nwb_file = read_nwbfile.processing
+        nwb_file_processing = read_nwbfile.processing

         # Get list of videos
-        video_keys: List[str] = [key for key in nwb_file.keys() if "SLEAP_VIDEO" in key]
+        video_keys: List[str] = [
+            key for key in nwb_file_processing.keys() if "SLEAP_VIDEO" in key
+        ]
         video_tracks = dict()

-        # Get track keys
-        test_processing_module: ProcessingModule = nwb_file[video_keys[0]]
+        # Get track keys from first video's processing module
+        test_processing_module: ProcessingModule = nwb_file_processing[video_keys[0]]
         track_keys: List[str] = list(test_processing_module.fields["data_interfaces"])

-        # Get track
+        # Get first track's skeleton
         test_pose_estimation: PoseEstimation = test_processing_module[track_keys[0]]
-        node_names = test_pose_estimation.nodes[:]
-        edge_inds = test_pose_estimation.edges[:]
+        skeleton = test_pose_estimation.skeleton
+        skeleton_nodes = skeleton.nodes[:]
+        skeleton_edges = skeleton.edges[:]

-        for processing_module in nwb_file.values():
+        # Keep only the video modules, filtering out the behavior module
+        # that holds the skeletons
+        pose_estimation_container_modules = [
+            nwb_file_processing[key] for key in video_keys
+        ]
+
+        for processing_module in pose_estimation_container_modules:
             # Get track keys
             _track_keys: List[str] = list(processing_module.fields["data_interfaces"])
             is_tracked: bool = re.sub("[0-9]+", "", _track_keys[0]) == "track"
@@ -139,8 +147,9 @@
             # Figure out the max number of frames and the canonical timestamps
             timestamps = np.empty(())
             for track_key in _track_keys:
-                for node_name in node_names:
-                    pose_estimation_series = processing_module[track_key][node_name]
+                pose_estimation = processing_module[track_key]
+                for node_name in skeleton.nodes:
+                    pose_estimation_series = pose_estimation[node_name]
                     timestamps = np.union1d(
                         timestamps, get_timestamps(pose_estimation_series)
                     )
@@ -149,13 +158,13 @@
             # Recreate Labels numpy (same as output of Labels.numpy())
             n_tracks = len(_track_keys)
             n_frames = len(timestamps)
-            n_nodes = len(node_names)
+            n_nodes = len(skeleton.nodes)
             tracks_numpy = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, np.float32)
             confidence = np.full((n_frames, n_tracks, n_nodes), np.nan, np.float32)
+
             for track_idx, track_key in enumerate(_track_keys):
                 pose_estimation = processing_module[track_key]
-
-                for node_idx, node_name in enumerate(node_names):
+                for node_idx, node_name in enumerate(skeleton.nodes):
                     pose_estimation_series = pose_estimation[node_name]
                     frame_inds = np.searchsorted(
                         timestamps, get_timestamps(pose_estimation_series)
                     )
@@ -173,10 +182,10 @@
                 is_tracked,
             )

-    # Create skeleton
-    skeleton = Skeleton(
-        nodes=node_names,
-        edges=edge_inds,
+    # Create SLEAP skeleton from NWB skeleton
+    sleap_skeleton = SleapSkeleton(
+        nodes=skeleton_nodes,
+        edges=skeleton_edges.tolist(),
     )

     # Add instances to labeled frames
@@ -185,6 +194,7 @@
         video = Video(filename=video_fn)
         n_frames, n_tracks, n_nodes, _ = tracks_numpy.shape
         tracks = [Track(name=f"track{track_idx}") for track_idx in range(n_tracks)]
+
         for frame_idx, (frame_pts, frame_confs) in enumerate(
             zip(tracks_numpy, confidence)
         ):
@@ -199,7 +209,7 @@
                         points=inst_pts,  # (n_nodes, 2)
                         point_scores=inst_confs,  # (n_nodes,)
                         instance_score=inst_confs.mean(),  # ()
-                        skeleton=skeleton,
+                        skeleton=sleap_skeleton,
                         track=track if is_tracked else None,
                     )
                 )
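With ndx-pose >= 0.2, node names and edge indices are no longer attributes of the `PoseEstimation` container itself but live on a dedicated `Skeleton` object that the container references, which is the access pattern the read path above now follows. A minimal read-side sketch (the file name and container names here are hypothetical, following this module's naming conventions):

```python
from pynwb import NWBHDF5IO

# Hypothetical file previously written by sleap-io's write_nwb().
with NWBHDF5IO("predictions.nwb", mode="r", load_namespaces=True) as io:
    nwbfile = io.read()

    # One processing module per video, one PoseEstimation container per track.
    pose_estimation = nwbfile.processing["SLEAP_VIDEO_000_video"]["track=0"]

    # ndx-pose >= 0.2: skeleton data hangs off a dedicated Skeleton object.
    skeleton = pose_estimation.skeleton
    node_names = skeleton.nodes[:]  # list of node names
    edge_inds = skeleton.edges[:]   # (n_edges, 2) array of node indices
```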
@@ -207,11 +217,69 @@
         lfs.append(
             LabeledFrame(video=video, frame_idx=frame_idx, instances=insts)
         )
+
     labels = Labels(lfs)
     labels.provenance["filename"] = path
     return labels


+def create_skeleton_container(
+    labels: Labels,
+    nwbfile: NWBFile,
+) -> Dict[str, Skeleton]:
+    """Create NWB skeleton containers from SLEAP skeletons.
+
+    Args:
+        labels: SLEAP Labels object containing skeleton definitions.
+        nwbfile: NWB file to add skeletons to.
+
+    Returns:
+        Dictionary mapping skeleton names to NWB Skeleton objects.
+    """
+    skeleton_map = {}
+    nwb_skeletons = []
+
+    # Get or create behavior processing module
+    behavior_pm = nwbfile.processing.get("behavior")
+    if behavior_pm is None:
+        behavior_pm = nwbfile.create_processing_module(
+            name="behavior", description="processed behavioral data"
+        )
+
+    # Check if Skeletons container already exists
+    existing_skeletons = None
+    if "Skeletons" in behavior_pm.data_interfaces:
+        existing_skeletons = behavior_pm.data_interfaces["Skeletons"]
+        # Add existing skeletons to our map
+        for skeleton_name in existing_skeletons.skeletons:
+            nwb_skeleton = existing_skeletons.skeletons[skeleton_name]
+            skeleton_map[skeleton_name] = nwb_skeleton
+
+    # Create new skeletons for ones that don't exist yet
+    for sleap_skeleton in labels.skeletons:
+        if sleap_skeleton.name not in skeleton_map:
+            nwb_skeleton = Skeleton(
+                name=sleap_skeleton.name,
+                nodes=sleap_skeleton.node_names,
+                edges=np.array(sleap_skeleton.edge_inds, dtype="uint8"),
+            )
+            nwb_skeletons.append(nwb_skeleton)
+            skeleton_map[sleap_skeleton.name] = nwb_skeleton
+
+    # If we have new skeletons to add
+    if nwb_skeletons:
+        if existing_skeletons is None:
+            # Create new Skeletons container if none exists
+            skeletons_container = Skeletons(skeletons=nwb_skeletons)
+            behavior_pm.add(skeletons_container)
+        else:
+            # Add new skeletons to existing container
+            for skeleton in nwb_skeletons:
+                existing_skeletons.add_skeleton(skeleton)
+
+    return skeleton_map
+
+
 def write_nwb(
     labels: Labels,
     nwbfile_path: str,
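The net effect of `create_skeleton_container` is that skeletons are written once into a shared `behavior` processing module and referenced by every `PoseEstimation` container, instead of being duplicated per track. A standalone sketch of the layout it produces (the skeleton name and node names are made up for illustration):

```python
import datetime
import uuid

import numpy as np
from ndx_pose import Skeleton, Skeletons
from pynwb import NWBFile

nwbfile = NWBFile(
    session_description="Processed SLEAP pose data",
    identifier=str(uuid.uuid1()),
    session_start_time=datetime.datetime.now(datetime.timezone.utc),
)

# One ndx-pose Skeleton per SLEAP skeleton, keyed by name.
skeleton = Skeleton(
    name="fly",  # hypothetical skeleton name
    nodes=["head", "thorax", "abdomen"],
    edges=np.array([[0, 1], [1, 2]], dtype="uint8"),
)

# All skeletons live in a single Skeletons container under "behavior".
behavior_pm = nwbfile.create_processing_module(
    name="behavior", description="processed behavioral data"
)
behavior_pm.add(Skeletons(skeletons=[skeleton]))
```

Because the function first re-registers any skeletons already present in the file, calling it again (e.g., via repeated `append_nwb_data` calls) does not create duplicates.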
""" pose_estimation_metadata = pose_estimation_metadata or dict() + if skeleton_map is None: + skeleton_map = create_skeleton_container(labels=labels, nwbfile=nwbfile) # Extract default metadata provenance = labels.provenance @@ -316,6 +395,16 @@ def append_nwb_data( processing_module_name, nwbfile ) + device_name = f"camera_{video_index}" + if device_name in nwbfile.devices: + device = nwbfile.devices[device_name] + else: + device = nwbfile.create_device( + name=device_name, + description=f"Camera for {video_path.name}", + manufacturer="Unknown", + ) + # Propagate video metadata default_metadata["original_videos"] = [f"{video.filename}"] # type: ignore default_metadata["labeled_videos"] = [f"{video.filename}"] # type: ignore @@ -337,6 +426,8 @@ def append_nwb_data( track_name, video, default_metadata, + skeleton_map, + devices=[device], ) nwb_processing_module.add(pose_estimation_container) @@ -395,6 +486,8 @@ def build_pose_estimation_container_for_track( track_name: str, video: Video, pose_estimation_metadata: dict, + skeleton_map: Dict[str, Skeleton], + devices: Optional[List] = None, ) -> PoseEstimation: """Create a PoseEstimation container for a track. @@ -404,7 +497,10 @@ def build_pose_estimation_container_for_track( labels (Labels): A general labels object track_name (str): The name of the track in labels.tracks video (Video): The video to which data belongs to - + pose_estimation_metadata: (dict) Metadata for pose estimation. See `append_nwb_data` + skeleton_map: Mapping of skeleton names to NWB Skeleton objects + skeleton_map: Mapping of skeleton names to NWB Skeleton objects + devices: Optional list of recording devices Returns: PoseEstimation: A PoseEstimation multicontainer where the time series of all the node trajectories in the track are stored. One time series per @@ -422,13 +518,15 @@ def build_pose_estimation_container_for_track( # Assuming only one skeleton per track skeleton_name = all_track_skeletons[0] - skeleton = next( + sleap_skeleton = next( skeleton for skeleton in labels.skeletons if skeleton.name == skeleton_name ) + nwb_skeleton = skeleton_map[skeleton_name] + # Get track data track_data_df = labels_data_df[ video.filename, - skeleton.name, + sleap_skeleton.name, track_name, ] @@ -448,12 +546,12 @@ def build_pose_estimation_container_for_track( # Arrange and mix metadata pose_estimation_container_kwargs = dict( name=f"track={track_name}", - description=f"Estimated positions of {skeleton.name} in video {video_path.name}", + description=f"Estimated positions of {sleap_skeleton.name} in video {video_path.name}", pose_estimation_series=pose_estimation_series_list, - nodes=skeleton.node_names, - edges=np.array(skeleton.edge_inds).astype("uint64"), + skeleton=nwb_skeleton, source_software="SLEAP", - # dimensions=np.array([[video.backend.height, video.backend.width]]), + # dimensions=np.array([[video.height, video.width]], dtype="uint16"), + devices=devices or [], ) pose_estimation_container_kwargs.update(**pose_estimation_metadata_copy) @@ -468,12 +566,12 @@ def build_track_pose_estimation_list( """Build a list of PoseEstimationSeries from tracks. Args: - track_data_df (pd.DataFrame): A pandas DataFrame object containing the - trajectories for all the nodes associated with a specific track. + track_data_df: A pandas DataFrame containing the trajectories + for all the nodes associated with a specific track. + timestamps: Array of timestamps for the data points Returns: - List[PoseEstimationSeries]: The list of all the PoseEstimationSeries. - One for each node. 
diff --git a/sleap_io/model/skeleton.py b/sleap_io/model/skeleton.py
index 620e1bc4..fea89c08 100644
--- a/sleap_io/model/skeleton.py
+++ b/sleap_io/model/skeleton.py
@@ -129,6 +129,7 @@ def __attrs_post_init__(self):
         """Ensure nodes are `Node`s, edges are `Edge`s, and `Node` map is updated."""
         self._convert_nodes()
         self._convert_edges()
+        self._convert_symmetries()
         self.rebuild_cache()

     def _convert_nodes(self):
@@ -174,6 +175,44 @@ def _convert_edges(self):

                 self.edges[i] = Edge(src, dst)

+    def _convert_symmetries(self):
+        """Convert list of symmetric node names or integers to `Symmetry` objects."""
+        if isinstance(self.symmetries, np.ndarray):
+            self.symmetries = self.symmetries.tolist()
+
+        node_names = self.node_names
+        for i, symmetry in enumerate(self.symmetries):
+            if type(symmetry) == Symmetry:
+                continue
+            node1, node2 = symmetry
+            if type(node1) == str:
+                try:
+                    node1 = node_names.index(node1)
+                except ValueError:
+                    raise ValueError(
+                        f"Node '{node1}' specified in the symmetry list is not in the "
+                        "nodes."
+                    )
+            if type(node1) == int or (
+                np.isscalar(node1) and np.issubdtype(node1.dtype, np.integer)
+            ):
+                node1 = self.nodes[node1]
+
+            if type(node2) == str:
+                try:
+                    node2 = node_names.index(node2)
+                except ValueError:
+                    raise ValueError(
+                        f"Node '{node2}' specified in the symmetry list is not in the "
+                        "nodes."
+                    )
+            if type(node2) == int or (
+                np.isscalar(node2) and np.issubdtype(node2.dtype, np.integer)
+            ):
+                node2 = self.nodes[node2]
+
+            self.symmetries[i] = Symmetry({node1, node2})
+
     def rebuild_cache(self, nodes: list[Node] | None = None):
         """Rebuild the node name/index to `Node` map caches.
@@ -425,6 +464,17 @@
         if symmetry not in self.symmetries:
             self.symmetries.append(symmetry)

+    def add_symmetries(
+        self, symmetries: list[Symmetry | tuple[NodeOrIndex, NodeOrIndex]]
+    ):
+        """Add multiple `Symmetry` relationships to the skeleton.
+
+        Args:
+            symmetries: A list of `Symmetry` objects or 2-tuples of symmetric nodes.
+        """
+        for symmetry in symmetries:
+            self.add_symmetry(*symmetry)
+
     def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]):
         """Rename nodes in the skeleton.
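Together, `_convert_symmetries` and `add_symmetries` mean symmetric node pairs can be declared directly at construction time (as names, indices, or `Symmetry` objects) or added in bulk afterwards. A short usage sketch (node names are made up; the behavior mirrors the tests below):

```python
from sleap_io import Skeleton

# Symmetries passed to the constructor are converted to Symmetry objects
# by __attrs_post_init__ via _convert_symmetries().
skel = Skeleton(["wingL", "wingR"], symmetries=[("wingL", "wingR")])
assert skel.symmetry_inds == [(0, 1)]

# add_symmetries() adds several pairs at once.
skel.add_nodes(["legL", "legR"])
skel.add_symmetries([("legL", "legR")])
assert skel.symmetry_inds == [(0, 1), (2, 3)]

# Unknown node names raise a ValueError, e.g.:
# Skeleton(["A", "B"], symmetries=[("a", "B")])  # raises
```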
diff --git a/tests/io/test_nwb.py b/tests/io/test_nwb.py
index 02d14ef7..1bef71ac 100644
--- a/tests/io/test_nwb.py
+++ b/tests/io/test_nwb.py
@@ -4,6 +4,7 @@
 import numpy as np
 from pynwb import NWBFile, NWBHDF5IO
+from pynwb.file import Subject

 from sleap_io import load_slp
 from sleap_io.io.nwb import write_nwb, append_nwb_data, get_timestamps
@@ -28,17 +29,31 @@ def test_typical_case_append(nwbfile, slp_typical):
     labels = load_slp(slp_typical)
     nwbfile = append_nwb_data(labels, nwbfile)

-    # Test matching number of processing modules
+    # Test that behavior module exists
+    assert "behavior" in nwbfile.processing
+    behavior_pm = nwbfile.processing["behavior"]
+
+    # Test that Skeletons container exists and has correct skeleton
+    assert "Skeletons" in behavior_pm.data_interfaces
+    skeletons_container = behavior_pm.data_interfaces["Skeletons"]
+    assert len(skeletons_container.skeletons) == len(labels.skeletons)
+
+    # Test matching number of video processing modules
     number_of_videos = len(labels.videos)
-    assert len(nwbfile.processing) == number_of_videos
+    video_modules = [mod for mod in nwbfile.processing.keys() if "SLEAP_VIDEO" in mod]
+    assert len(video_modules) == number_of_videos

-    # Test processing module naming
+    # Test processing module naming and content
     video_index = 0
     video = labels.videos[video_index]
     video_path = Path(video.filename)
     processing_module_name = f"SLEAP_VIDEO_{video_index:03}_{video_path.stem}"
     assert processing_module_name in nwbfile.processing

+    # Test device creation
+    device_name = f"camera_{video_index}"
+    assert device_name in nwbfile.devices
+
     processing_module = nwbfile.processing[processing_module_name]
     all_containers = processing_module.data_interfaces
     # Test name of PoseEstimation containers
@@ -48,11 +63,15 @@
     # Test that the skeleton nodes are stored as nodes in containers
     pose_estimation_container = all_containers[container_name]
-    expected_node_names = [node.name for node in labels.skeletons[0]]
-    assert expected_node_names == pose_estimation_container.nodes
+    expected_skeleton_name = labels.skeletons[0].name
+    assert pose_estimation_container.skeleton.name == expected_skeleton_name
+
+    # Test that skeleton nodes match
+    expected_node_names = labels.skeletons[0].node_names
+    assert expected_node_names == pose_estimation_container.skeleton.nodes

     # Test that each PoseEstimationSeries is named as a node
-    for node_name in pose_estimation_container.nodes:
+    for node_name in pose_estimation_container.skeleton.nodes:
         assert node_name in pose_estimation_container.pose_estimation_series
@@ -122,8 +141,7 @@ def test_default_metadata_overwriting(nwbfile, slp_predictions_with_provenance):
     # Test that the value of scorer was overwritten
     for pose_estimation_container in processing_module.data_interfaces.values():
         assert pose_estimation_container.scorer == "overwritten_value"
-        all_nodes = pose_estimation_container.nodes
-        for node in all_nodes:
+        for node in pose_estimation_container.skeleton.nodes:
             pose_estimation_series = pose_estimation_container[node]
             if pose_estimation_series.rate:
                 assert pose_estimation_series.rate == expected_sampling_rate
@@ -134,9 +152,14 @@ def test_complex_case_append(nwbfile, centered_pair):
     labels.clean(tracks=True)
     nwbfile = append_nwb_data(labels, nwbfile)

-    # Test matching number of processing modules
+    # Test Skeletons container
+    assert "behavior" in nwbfile.processing
+    behavior_pm = nwbfile.processing["behavior"]
+    assert "Skeletons" in behavior_pm.data_interfaces
+
+    # Test matching number of processing modules plus the Skeletons module
     number_of_videos = len(labels.videos)
-    assert len(nwbfile.processing) == number_of_videos
+    assert len(nwbfile.processing) == number_of_videos + 1

     # Test processing module naming
     video_index = 0
@@ -156,15 +179,17 @@
         expected_track_name = f"track={track.name}"
         assert expected_track_name in extracted_container_names

-    # Test one PoseEstimation container
     container_name = "track=1"
     pose_estimation_container = all_containers[container_name]
-    # Test that the skeleton nodes are store as nodes in containers
-    expected_node_names = [node.name for node in labels.skeletons[0]]
-    assert expected_node_names == pose_estimation_container.nodes
-    # Test that each PoseEstimationSeries is named as a node
-    for node_name in pose_estimation_container.nodes:
+    # Test skeleton reference and nodes
+    expected_skeleton_name = labels.skeletons[0].name
+    assert pose_estimation_container.skeleton.name == expected_skeleton_name
+    expected_node_names = labels.skeletons[0].node_names
+    assert expected_node_names == pose_estimation_container.skeleton.nodes
+
+    # Test pose estimation series
+    for node_name in pose_estimation_container.skeleton.nodes:
         assert node_name in pose_estimation_container.pose_estimation_series
@@ -233,9 +258,16 @@
     with NWBHDF5IO(str(nwbfile_path), "r") as io:
         nwbfile = io.read()

-        # Test matching number of processing modules
+        # Test Skeletons container exists
+        assert "behavior" in nwbfile.processing
+        assert "Skeletons" in nwbfile.processing["behavior"].data_interfaces
+
+        # Test video modules
         number_of_videos = len(labels.videos)
-        assert len(nwbfile.processing) == number_of_videos
+        video_modules = [
+            mod for mod in nwbfile.processing.keys() if "SLEAP_VIDEO" in mod
+        ]
+        assert len(video_modules) == number_of_videos


 def test_get_timestamps(nwbfile, centered_pair):
@@ -243,7 +275,6 @@
     labels.clean(tracks=True)
     nwbfile = append_nwb_data(labels, nwbfile)
     processing = nwbfile.processing["SLEAP_VIDEO_000_centered_pair_low_quality"]
-    assert True

     # explicit timestamps
     series = processing["track=1"]["head"]
diff --git a/tests/model/test_skeleton.py b/tests/model/test_skeleton.py
index 5fccf4e1..568950ef 100644
--- a/tests/model/test_skeleton.py
+++ b/tests/model/test_skeleton.py
@@ -54,6 +54,15 @@
     with pytest.raises(ValueError):
         Skeleton(["A", "B"], edges=[("A", "C")])

+    skel = Skeleton(["A", "B"], symmetries=[("A", "B")])
+    assert skel.symmetry_inds == [(0, 1)]
+
+    with pytest.raises(ValueError):
+        Skeleton(["A", "B"], symmetries=[("a", "B")])
+
+    with pytest.raises(ValueError):
+        Skeleton(["A", "B"], symmetries=[("A", "b")])
+

 def test_skeleton_node_map():
     """Test `Skeleton` node map returns correct nodes."""
@@ -165,6 +174,11 @@
     skel.add_symmetry("E", "F")
     assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5)]

+    # Add symmetries
+    skel.add_nodes(["GL", "GR", "HL", "HR"])
+    skel.add_symmetries([("GL", "GR"), ("HL", "HR")])
+    assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
+

 def test_rename_nodes():
     """Test renaming nodes in the skeleton."""