From bd12d041543ad4f6f88d32d7bcd000cb530a19bb Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Thu, 18 Apr 2024 09:54:41 -0500 Subject: [PATCH 01/44] Create nwb_export.py --- movement/io/nwb_export.py | 143 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 movement/io/nwb_export.py diff --git a/movement/io/nwb_export.py b/movement/io/nwb_export.py new file mode 100644 index 00000000..faac5717 --- /dev/null +++ b/movement/io/nwb_export.py @@ -0,0 +1,143 @@ +import xarray as xr + +try: + import ndx_pose + import pynwb +except ImportError: + ndx_pose = None + pynwb = None + + +def _create_pose_and_skeleton_objects( + ds: xr.Dataset, + pose_estimation_series_kwargs: dict = None, + pose_estimation_kwargs: dict = None, + skeleton_kwargs: dict = None, +) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: + """Creates PoseEstimation and Skeletons objects from a movement xarray dataset. + + Parameters + ---------- + ds : xr.Dataset + Movement dataset containing the data to be converted to NWB. + pose_estimation_series_kwargs : dict, optional + PoseEstimationSeries keyword arguments. See ndx_pose, by default None + pose_estimation_kwargs : dict, optional + PoseEstimation keyword arguments. See ndx_pose, by default None + skeleton_kwargs : dict, optional + Skeleton keyword arguments. See ndx_pose, by default None + + Returns + ------- + pose_estimation : list[ndx_pose.PoseEstimation] + List of PoseEstimation objects + skeletons : ndx_pose.Skeletons + Skeletons object containing all skeletons + """ + if pose_estimation_series_kwargs is None: + pose_estimation_series_kwargs = dict( + reference_frame="(0,0,0) corresponds to ...", + confidence_definition=None, + conversion=1.0, + resolution=-1.0, + offset=0.0, + starting_time=None, + comments="no comments", + description="no description", + control=None, + control_description=None, + ) + + if skeleton_kwargs is None: + skeleton_kwargs = dict(edges=None) + + if pose_estimation_kwargs is None: + pose_estimation_kwargs = dict( + original_videos=None, + labeled_videos=None, + dimensions=None, + devices=None, + scorer=None, + source_software_version=None, + ) + + skeleton_list = [] + pose_estimation = [] + + for subject in ds.individuals.to_numpy(): + pose_estimation_series = [] + + for keypoint in ds.keypoints.to_numpy(): + pose_estimation_series.append( + ndx_pose.PoseEstimationSeries( + name=keypoint, + data=ds.sel( + keypoints=keypoint, individuals=subject + ).position.to_numpy(), + confidence=ds.sel( + keypoints=keypoint, individuals=subject + ).confidence.to_numpy(), + unit="pixels", + timestamps=ds.sel( + keypoints=keypoint, individuals=subject + ).time.to_numpy(), + **pose_estimation_series_kwargs, + ) + ) + + skeleton_list.append( + ndx_pose.Skeleton( + name=f"{subject}_skeleton", + nodes=ds.sel(individuals=subject).keypoints.to_numpy().tolist(), + **skeleton_kwargs, + ) + ) + + bodyparts_str = ", ".join( + ds.sel(individuals=subject).keypoints.to_numpy().tolist() + ) + description = f"Estimated positions of {bodyparts_str} of {subject} using {ds.source_software}." + + pose_estimation.append( + ndx_pose.PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description=description, + source_software=ds.source_software, + skeleton=skeleton_list[-1], + **pose_estimation_kwargs, + ) + ) + + skeletons = ndx_pose.Skeletons(skeletons=skeleton_list) + + return pose_estimation, skeletons + + +def convert_movement_to_nwb( + nwbfile: pynwb.NWBFile, + ds: xr.Dataset, + pose_estimation_series_kwargs: dict = None, + pose_estimation_kwargs: dict = None, + skeletons_kwargs: dict = None, +): + pose_estimation, skeletons = _create_pose_and_skeleton_objects( + ds, pose_estimation_series_kwargs, pose_estimation_kwargs, skeletons_kwargs + ) + try: + behavior_pm = nwbfile.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + except ValueError: + print("Behavior processing module already exists. Skipping...") + behavior_pm = nwbfile.processing["behavior"] + + try: + behavior_pm.add(skeletons) + except ValueError: + print("Skeletons already exists. Skipping...") + try: + behavior_pm.add(pose_estimation) + except ValueError: + print("PoseEstimation already exists. Skipping...") From c5319b9b1f763b3b8377a6440632b05e13172cb8 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Thu, 18 Apr 2024 10:21:54 -0500 Subject: [PATCH 02/44] NWB requires one file per individual --- movement/io/nwb_export.py | 130 ++++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 62 deletions(-) diff --git a/movement/io/nwb_export.py b/movement/io/nwb_export.py index faac5717..0d5c51d9 100644 --- a/movement/io/nwb_export.py +++ b/movement/io/nwb_export.py @@ -1,4 +1,5 @@ import xarray as xr +from typing import Union try: import ndx_pose @@ -10,6 +11,7 @@ def _create_pose_and_skeleton_objects( ds: xr.Dataset, + subject: str, pose_estimation_series_kwargs: dict = None, pose_estimation_kwargs: dict = None, skeleton_kwargs: dict = None, @@ -20,6 +22,8 @@ def _create_pose_and_skeleton_objects( ---------- ds : xr.Dataset Movement dataset containing the data to be converted to NWB. + subject : str + Name of the subject to be converted. pose_estimation_series_kwargs : dict, optional PoseEstimationSeries keyword arguments. See ndx_pose, by default None pose_estimation_kwargs : dict, optional @@ -61,53 +65,41 @@ def _create_pose_and_skeleton_objects( source_software_version=None, ) - skeleton_list = [] - pose_estimation = [] - - for subject in ds.individuals.to_numpy(): - pose_estimation_series = [] - - for keypoint in ds.keypoints.to_numpy(): - pose_estimation_series.append( - ndx_pose.PoseEstimationSeries( - name=keypoint, - data=ds.sel( - keypoints=keypoint, individuals=subject - ).position.to_numpy(), - confidence=ds.sel( - keypoints=keypoint, individuals=subject - ).confidence.to_numpy(), - unit="pixels", - timestamps=ds.sel( - keypoints=keypoint, individuals=subject - ).time.to_numpy(), - **pose_estimation_series_kwargs, - ) - ) - - skeleton_list.append( - ndx_pose.Skeleton( - name=f"{subject}_skeleton", - nodes=ds.sel(individuals=subject).keypoints.to_numpy().tolist(), - **skeleton_kwargs, + pose_estimation_series = [] + + for keypoint in ds.keypoints.to_numpy(): + pose_estimation_series.append( + ndx_pose.PoseEstimationSeries( + name=keypoint, + data=ds.sel(keypoints=keypoint).position.to_numpy(), + confidence=ds.sel(keypoints=keypoint).confidence.to_numpy(), + unit="pixels", + timestamps=ds.sel(keypoints=keypoint).time.to_numpy(), + **pose_estimation_series_kwargs, ) ) - bodyparts_str = ", ".join( - ds.sel(individuals=subject).keypoints.to_numpy().tolist() + skeleton_list = [ + ndx_pose.Skeleton( + name=f"{subject}_skeleton", + nodes=ds.keypoints.to_numpy().tolist(), + **skeleton_kwargs, ) - description = f"Estimated positions of {bodyparts_str} of {subject} using {ds.source_software}." - - pose_estimation.append( - ndx_pose.PoseEstimation( - name="PoseEstimation", - pose_estimation_series=pose_estimation_series, - description=description, - source_software=ds.source_software, - skeleton=skeleton_list[-1], - **pose_estimation_kwargs, - ) + ] + + bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist()) + description = f"Estimated positions of {bodyparts_str} of {subject} using {ds.source_software}." + + pose_estimation = [ + ndx_pose.PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description=description, + source_software=ds.source_software, + skeleton=skeleton_list[-1], + **pose_estimation_kwargs, ) + ] skeletons = ndx_pose.Skeletons(skeletons=skeleton_list) @@ -115,29 +107,43 @@ def _create_pose_and_skeleton_objects( def convert_movement_to_nwb( - nwbfile: pynwb.NWBFile, + nwbfiles: Union[list[pynwb.NWBFile], pynwb.NWBFile], ds: xr.Dataset, pose_estimation_series_kwargs: dict = None, pose_estimation_kwargs: dict = None, skeletons_kwargs: dict = None, ): - pose_estimation, skeletons = _create_pose_and_skeleton_objects( - ds, pose_estimation_series_kwargs, pose_estimation_kwargs, skeletons_kwargs - ) - try: - behavior_pm = nwbfile.create_processing_module( - name="behavior", - description="processed behavioral data", + if isinstance(nwbfiles, pynwb.NWBFile): + nwbfiles = [nwbfiles] + + if len(nwbfiles) != len(ds.individuals): + raise ValueError( + "Number of NWBFiles must be equal to the number of individuals in the dataset. " + "NWB requires one file per individual." ) - except ValueError: - print("Behavior processing module already exists. Skipping...") - behavior_pm = nwbfile.processing["behavior"] - - try: - behavior_pm.add(skeletons) - except ValueError: - print("Skeletons already exists. Skipping...") - try: - behavior_pm.add(pose_estimation) - except ValueError: - print("PoseEstimation already exists. Skipping...") + + for nwbfile, subject in zip(nwbfiles, ds.individuals.to_numpy()): + pose_estimation, skeletons = _create_pose_and_skeleton_objects( + ds.sel(individuals=subject), + subject, + pose_estimation_series_kwargs, + pose_estimation_kwargs, + skeletons_kwargs, + ) + try: + behavior_pm = nwbfile.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + except ValueError: + print("Behavior processing module already exists. Skipping...") + behavior_pm = nwbfile.processing["behavior"] + + try: + behavior_pm.add(skeletons) + except ValueError: + print("Skeletons already exists. Skipping...") + try: + behavior_pm.add(pose_estimation) + except ValueError: + print("PoseEstimation already exists. Skipping...") From d82fe3028b73b523ef7d71e9bf6dfdd614748c07 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 09:47:01 -0500 Subject: [PATCH 03/44] Add script --- examples/nwb_conversion.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 examples/nwb_conversion.py diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py new file mode 100644 index 00000000..4499e951 --- /dev/null +++ b/examples/nwb_conversion.py @@ -0,0 +1,27 @@ +from movement import sample_data +from pynwb import NWBFile +import datetime +from movement.io.nwb_export import convert_movement_to_nwb + +# Load the sample data +ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") + +# The dataset has two individuals, we will create two NWBFiles for each individual + +nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=datetime.datetime.now(datetime.timezone.utc), +) +nwbfile_individual2 = NWBFile( + session_description="session_description", + identifier="individual2", + session_start_time=datetime.datetime.now(datetime.timezone.utc), +) + +nwbfiles = [nwbfile_individual1, nwbfile_individual2] + +# Convert the dataset to NWB +# This will create PoseEstimation and Skeleton objects for each individual +# and add them to the NWBFile +convert_movement_to_nwb(nwbfiles, ds) From d889105af615ecc00dee595878bb5e2c1d60ab22 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 09:49:13 -0500 Subject: [PATCH 04/44] Remove import error handling --- movement/io/nwb_export.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/movement/io/nwb_export.py b/movement/io/nwb_export.py index 0d5c51d9..a9d3d6a2 100644 --- a/movement/io/nwb_export.py +++ b/movement/io/nwb_export.py @@ -1,12 +1,8 @@ -import xarray as xr from typing import Union -try: - import ndx_pose - import pynwb -except ImportError: - ndx_pose = None - pynwb = None +import ndx_pose +import pynwb +import xarray as xr def _create_pose_and_skeleton_objects( From 72aea478be453d434742f965c729b4913fb37547 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 09:49:24 -0500 Subject: [PATCH 05/44] Add nwb optional dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 22d1b600..6c502458 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dev = [ "types-PyYAML", "types-requests", ] +nwb = ["pynwb", "git+https://github.com/rly/ndx-pose.git"] [build-system] requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] From 12bf83f3de6b84ca0a49d2c39b0db5f7bc24fd5b Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 10:03:36 -0500 Subject: [PATCH 06/44] Fix linting based on pre-commit hooks --- examples/nwb_conversion.py | 14 +++++++++----- movement/io/nwb_export.py | 24 ++++++++++++++---------- pyproject.toml | 2 +- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 4499e951..67ee5810 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -1,22 +1,26 @@ -from movement import sample_data -from pynwb import NWBFile import datetime + +from pynwb import NWBFile + +from movement import sample_data from movement.io.nwb_export import convert_movement_to_nwb # Load the sample data ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") -# The dataset has two individuals, we will create two NWBFiles for each individual +# The dataset has two individuals +# we will create two NWBFiles for each individual +session_start_time = datetime.datetime.now(datetime.timezone.utc) nwbfile_individual1 = NWBFile( session_description="session_description", identifier="individual1", - session_start_time=datetime.datetime.now(datetime.timezone.utc), + session_start_time=session_start_time, ) nwbfile_individual2 = NWBFile( session_description="session_description", identifier="individual2", - session_start_time=datetime.datetime.now(datetime.timezone.utc), + session_start_time=session_start_time, ) nwbfiles = [nwbfile_individual1, nwbfile_individual2] diff --git a/movement/io/nwb_export.py b/movement/io/nwb_export.py index a9d3d6a2..77e211ee 100644 --- a/movement/io/nwb_export.py +++ b/movement/io/nwb_export.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Optional, Union import ndx_pose import pynwb @@ -8,11 +8,12 @@ def _create_pose_and_skeleton_objects( ds: xr.Dataset, subject: str, - pose_estimation_series_kwargs: dict = None, - pose_estimation_kwargs: dict = None, - skeleton_kwargs: dict = None, + pose_estimation_series_kwargs: Optional[dict] = None, + pose_estimation_kwargs: Optional[dict] = None, + skeleton_kwargs: Optional[dict] = None, ) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: - """Creates PoseEstimation and Skeletons objects from a movement xarray dataset. + """Creates PoseEstimation and Skeletons objects from a movement xarray + dataset. Parameters ---------- @@ -84,7 +85,10 @@ def _create_pose_and_skeleton_objects( ] bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist()) - description = f"Estimated positions of {bodyparts_str} of {subject} using {ds.source_software}." + description = ( + f"Estimated positions of {bodyparts_str} of" + f"{subject} using {ds.source_software}." + ) pose_estimation = [ ndx_pose.PoseEstimation( @@ -105,16 +109,16 @@ def _create_pose_and_skeleton_objects( def convert_movement_to_nwb( nwbfiles: Union[list[pynwb.NWBFile], pynwb.NWBFile], ds: xr.Dataset, - pose_estimation_series_kwargs: dict = None, - pose_estimation_kwargs: dict = None, - skeletons_kwargs: dict = None, + pose_estimation_series_kwargs: Optional[dict] = None, + pose_estimation_kwargs: Optional[dict] = None, + skeletons_kwargs: Optional[dict] = None, ): if isinstance(nwbfiles, pynwb.NWBFile): nwbfiles = [nwbfiles] if len(nwbfiles) != len(ds.individuals): raise ValueError( - "Number of NWBFiles must be equal to the number of individuals in the dataset. " + "Number of NWBFiles must be equal to the number of individuals. " "NWB requires one file per individual." ) diff --git a/pyproject.toml b/pyproject.toml index 6c502458..8a017910 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ dev = [ "types-PyYAML", "types-requests", ] -nwb = ["pynwb", "git+https://github.com/rly/ndx-pose.git"] +nwb = ["pynwb", "ndx-pose"] [build-system] requires = ["setuptools>=45", "wheel", "setuptools_scm[toml]>=6.2"] From 742bf86b38e75610c1c5702693727dd3e4dc6bde Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 10:21:05 -0500 Subject: [PATCH 07/44] Add example docstring --- examples/nwb_conversion.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 67ee5810..bc7b80cd 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -1,3 +1,10 @@ +""" +Export pose tracks to NWB +============================ + +Export pose tracks to NWB +""" + import datetime from pynwb import NWBFile From a06d48522dccb66bb4aa80c5f052c534d496ff6b Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 13:31:34 -0500 Subject: [PATCH 08/44] Rename to fit module naming pattern --- examples/nwb_conversion.py | 2 +- movement/io/{nwb_export.py => export_nwb.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename movement/io/{nwb_export.py => export_nwb.py} (100%) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index bc7b80cd..396e578b 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -10,7 +10,7 @@ from pynwb import NWBFile from movement import sample_data -from movement.io.nwb_export import convert_movement_to_nwb +from movement.io.export_nwb import convert_movement_to_nwb # Load the sample data ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") diff --git a/movement/io/nwb_export.py b/movement/io/export_nwb.py similarity index 100% rename from movement/io/nwb_export.py rename to movement/io/export_nwb.py From 739c4d869bc2a57ac398098e8f03faa075365f08 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 14:45:20 -0500 Subject: [PATCH 09/44] Add import from nwb --- examples/nwb_conversion.py | 2 +- movement/io/{export_nwb.py => nwb.py} | 63 +++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) rename movement/io/{export_nwb.py => nwb.py} (71%) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 396e578b..e796c222 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -10,7 +10,7 @@ from pynwb import NWBFile from movement import sample_data -from movement.io.export_nwb import convert_movement_to_nwb +from movement.io.nwb import convert_movement_to_nwb # Load the sample data ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") diff --git a/movement/io/export_nwb.py b/movement/io/nwb.py similarity index 71% rename from movement/io/export_nwb.py rename to movement/io/nwb.py index 77e211ee..8afce46a 100644 --- a/movement/io/export_nwb.py +++ b/movement/io/nwb.py @@ -1,6 +1,7 @@ from typing import Optional, Union import ndx_pose +import numpy as np import pynwb import xarray as xr @@ -147,3 +148,65 @@ def convert_movement_to_nwb( behavior_pm.add(pose_estimation) except ValueError: print("PoseEstimation already exists. Skipping...") + + +def _convert_pse( + pes: ndx_pose.PoseEstimationSeries, + keypoint: str, + subject_name: str, + source_software: str, + source_file: Optional[str] = None, +): + attrs = { + "fps": int(np.median(1 / np.diff(pes.timestamps))), + "time_units": pes.timestamps_unit, + "source_software": source_software, + "source_file": source_file, + } + n_space_dims = pes.data.shape[1] + space_dims = ["x", "y", "z"] + + return xr.Dataset( + data_vars={ + "position": ( + ["time", "individuals", "keypoints", "space"], + pes.data[:, np.newaxis, np.newaxis, :], + ), + "confidence": ( + ["time", "individuals", "keypoints"], + pes.confidence[:, np.newaxis, np.newaxis], + ), + }, + coords={ + "time": pes.timestamps, + "individuals": [subject_name], + "keypoints": [keypoint], + "space": space_dims[:n_space_dims], + }, + attrs=attrs, + ) + + +def convert_nwb_to_movement(nwb_filepaths: list[str]) -> xr.Dataset: + datasets = [] + for path in nwb_filepaths: + with pynwb.NWBHDF5IO(path, mode="r") as io: + nwbfile = io.read() + pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"] + source_software = pose_estimation.fields["source_software"] + pose_estimation_series = pose_estimation.fields[ + "pose_estimation_series" + ] + + for keypoint, pes in pose_estimation_series.items(): + datasets.append( + _convert_pse( + pes, + keypoint, + subject_name=nwbfile.identifier, + source_software=source_software, + source_file=None, + ) + ) + + return xr.merge(datasets) From ce28f904e9adf0a61b832ba7bd408eab61f9efba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Apr 2024 13:29:05 +0000 Subject: [PATCH 10/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/nwb_conversion.py | 3 +-- movement/io/nwb.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index e796c222..29d9f4c5 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -1,5 +1,4 @@ -""" -Export pose tracks to NWB +"""Export pose tracks to NWB ============================ Export pose tracks to NWB diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8afce46a..722bfb0a 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -35,6 +35,7 @@ def _create_pose_and_skeleton_objects( List of PoseEstimation objects skeletons : ndx_pose.Skeletons Skeletons object containing all skeletons + """ if pose_estimation_series_kwargs is None: pose_estimation_series_kwargs = dict( From 2491cf6510e10897f72694931a984de672682201 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 Jun 2024 16:01:33 +0000 Subject: [PATCH 11/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- movement/io/nwb.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 722bfb0a..8960227d 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import ndx_pose import numpy as np import pynwb @@ -9,9 +7,9 @@ def _create_pose_and_skeleton_objects( ds: xr.Dataset, subject: str, - pose_estimation_series_kwargs: Optional[dict] = None, - pose_estimation_kwargs: Optional[dict] = None, - skeleton_kwargs: Optional[dict] = None, + pose_estimation_series_kwargs: dict | None = None, + pose_estimation_kwargs: dict | None = None, + skeleton_kwargs: dict | None = None, ) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: """Creates PoseEstimation and Skeletons objects from a movement xarray dataset. @@ -109,11 +107,11 @@ def _create_pose_and_skeleton_objects( def convert_movement_to_nwb( - nwbfiles: Union[list[pynwb.NWBFile], pynwb.NWBFile], + nwbfiles: list[pynwb.NWBFile] | pynwb.NWBFile, ds: xr.Dataset, - pose_estimation_series_kwargs: Optional[dict] = None, - pose_estimation_kwargs: Optional[dict] = None, - skeletons_kwargs: Optional[dict] = None, + pose_estimation_series_kwargs: dict | None = None, + pose_estimation_kwargs: dict | None = None, + skeletons_kwargs: dict | None = None, ): if isinstance(nwbfiles, pynwb.NWBFile): nwbfiles = [nwbfiles] @@ -124,7 +122,9 @@ def convert_movement_to_nwb( "NWB requires one file per individual." ) - for nwbfile, subject in zip(nwbfiles, ds.individuals.to_numpy()): + for nwbfile, subject in zip( + nwbfiles, ds.individuals.to_numpy(), strict=False + ): pose_estimation, skeletons = _create_pose_and_skeleton_objects( ds.sel(individuals=subject), subject, @@ -156,7 +156,7 @@ def _convert_pse( keypoint: str, subject_name: str, source_software: str, - source_file: Optional[str] = None, + source_file: str | None = None, ): attrs = { "fps": int(np.median(1 / np.diff(pes.timestamps))), From 58bef932e531dd9fe767e6f0f54a4c1570f91c2c Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 09:02:33 -0700 Subject: [PATCH 12/44] Apply suggestions from code review Co-authored-by: Niko Sirmpilatze --- movement/io/nwb.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8960227d..0b85fe53 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -11,13 +11,12 @@ def _create_pose_and_skeleton_objects( pose_estimation_kwargs: dict | None = None, skeleton_kwargs: dict | None = None, ) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: - """Creates PoseEstimation and Skeletons objects from a movement xarray - dataset. + """Create PoseEstimation and Skeletons objects from a movement dataset. Parameters ---------- ds : xr.Dataset - Movement dataset containing the data to be converted to NWB. + movement dataset containing the data to be converted to NWB. subject : str Name of the subject to be converted. pose_estimation_series_kwargs : dict, optional From 3ab9aa47222dad961c10814e92260e64271c6920 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 09:26:33 -0700 Subject: [PATCH 13/44] Update make pynwb and ndx-pose core dependencies --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0ddff4c6..226ca80b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "sleap-io", "xarray[accel,viz]", "PyYAML", + "pynwb", + "ndx-pose>=0.2", ] classifiers = [ @@ -58,7 +60,6 @@ dev = [ "types-PyYAML", "types-requests", ] -nwb = ["pynwb", "ndx-pose"] [project.scripts] movement = "movement.cli_entrypoint:main" From 910ce9005e81b473c0da5cf48e12ef7b6b8a00a7 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 10:58:07 -0700 Subject: [PATCH 14/44] Cleanup of docstrings and variable names from code review --- movement/io/nwb.py | 107 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 86 insertions(+), 21 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 0b85fe53..f95cfde1 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -1,8 +1,14 @@ +"""Functions to convert movement data to and from NWB format.""" + +from pathlib import Path + import ndx_pose import numpy as np import pynwb import xarray as xr +from movement.logging import log_error + def _create_pose_and_skeleton_objects( ds: xr.Dataset, @@ -11,14 +17,14 @@ def _create_pose_and_skeleton_objects( pose_estimation_kwargs: dict | None = None, skeleton_kwargs: dict | None = None, ) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: - """Create PoseEstimation and Skeletons objects from a movement dataset. + """Create PoseEstimation and Skeletons objects from a ``movement`` dataset. Parameters ---------- - ds : xr.Dataset + ds : xarray.Dataset movement dataset containing the data to be converted to NWB. subject : str - Name of the subject to be converted. + Name of the subject (individual) to be converted. pose_estimation_series_kwargs : dict, optional PoseEstimationSeries keyword arguments. See ndx_pose, by default None pose_estimation_kwargs : dict, optional @@ -107,25 +113,48 @@ def _create_pose_and_skeleton_objects( def convert_movement_to_nwb( nwbfiles: list[pynwb.NWBFile] | pynwb.NWBFile, - ds: xr.Dataset, + movement_dataset: xr.Dataset, pose_estimation_series_kwargs: dict | None = None, pose_estimation_kwargs: dict | None = None, skeletons_kwargs: dict | None = None, -): +) -> None: + """Convert a ``movement`` dataset to the ndx-pose extension format for NWB. + + Parameters + ---------- + nwbfiles : list[pynwb.NWBFile] | pynwb.NWBFile + NWBFile object(s) to which the data will be added. + movement_dataset : xr.Dataset + ``movement`` dataset containing the data to be converted to NWB. + pose_estimation_series_kwargs : dict, optional + PoseEstimationSeries keyword arguments. See ndx_pose, by default None + pose_estimation_kwargs : dict, optional + PoseEstimation keyword arguments. See ndx_pose, by default None + skeletons_kwargs : dict, optional + Skeleton keyword arguments. See ndx_pose, by default None + + Raises + ------ + ValueError + If the number of NWBFiles is not equal to the number of individuals + in the dataset. + + """ if isinstance(nwbfiles, pynwb.NWBFile): nwbfiles = [nwbfiles] - if len(nwbfiles) != len(ds.individuals): - raise ValueError( + if len(nwbfiles) != len(movement_dataset.individuals): + raise log_error( + ValueError, "Number of NWBFiles must be equal to the number of individuals. " - "NWB requires one file per individual." + "NWB requires one file per individual.", ) for nwbfile, subject in zip( - nwbfiles, ds.individuals.to_numpy(), strict=False + nwbfiles, movement_dataset.individuals.to_numpy(), strict=False ): pose_estimation, skeletons = _create_pose_and_skeleton_objects( - ds.sel(individuals=subject), + movement_dataset.sel(individuals=subject), subject, pose_estimation_series_kwargs, pose_estimation_kwargs, @@ -150,35 +179,56 @@ def convert_movement_to_nwb( print("PoseEstimation already exists. Skipping...") -def _convert_pse( - pes: ndx_pose.PoseEstimationSeries, +def _convert_pose_estimation_series( + pose_estimation_series: ndx_pose.PoseEstimationSeries, keypoint: str, subject_name: str, source_software: str, source_file: str | None = None, -): +) -> xr.Dataset: + """Convert to single-keypoint, single-individual ``movement`` dataset. + + Parameters + ---------- + pose_estimation_series : ndx_pose.PoseEstimationSeries + PoseEstimationSeries NWB object to be converted. + keypoint : str + Name of the keypoint - body part. + subject_name : str + Name of the subject (individual). + source_software : str + Name of the software used to estimate the pose. + source_file : Optional[str], optional + File from which the data was extracted, by default None + + Returns + ------- + movement_dataset : xr.Dataset + ``movement`` compatible dataset containing the pose estimation data. + + """ attrs = { - "fps": int(np.median(1 / np.diff(pes.timestamps))), - "time_units": pes.timestamps_unit, + "fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)), + "time_units": pose_estimation_series.timestamps_unit, "source_software": source_software, "source_file": source_file, } - n_space_dims = pes.data.shape[1] + n_space_dims = pose_estimation_series.data.shape[1] space_dims = ["x", "y", "z"] return xr.Dataset( data_vars={ "position": ( ["time", "individuals", "keypoints", "space"], - pes.data[:, np.newaxis, np.newaxis, :], + pose_estimation_series.data[:, np.newaxis, np.newaxis, :], ), "confidence": ( ["time", "individuals", "keypoints"], - pes.confidence[:, np.newaxis, np.newaxis], + pose_estimation_series.confidence[:, np.newaxis, np.newaxis], ), }, coords={ - "time": pes.timestamps, + "time": pose_estimation_series.timestamps, "individuals": [subject_name], "keypoints": [keypoint], "space": space_dims[:n_space_dims], @@ -187,7 +237,22 @@ def _convert_pse( ) -def convert_nwb_to_movement(nwb_filepaths: list[str]) -> xr.Dataset: +def convert_nwb_to_movement( + nwb_filepaths: list[str] | list[Path], +) -> xr.Dataset: + """Convert a list of NWB files to a single ``movement`` dataset. + + Parameters + ---------- + nwb_filepaths : Union[list[str], list[Path]] + List of paths to NWB files to be converted. + + Returns + ------- + movement_ds : xr.Dataset + ``movement`` dataset containing the pose estimation data. + + """ datasets = [] for path in nwb_filepaths: with pynwb.NWBHDF5IO(path, mode="r") as io: @@ -200,7 +265,7 @@ def convert_nwb_to_movement(nwb_filepaths: list[str]) -> xr.Dataset: for keypoint, pes in pose_estimation_series.items(): datasets.append( - _convert_pse( + _convert_pose_estimation_series( pes, keypoint, subject_name=nwbfile.identifier, From 3f9a53becfabc55aa57edca94cce3253cf62a299 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 11:15:20 -0700 Subject: [PATCH 15/44] Rename function for clarity --- movement/io/nwb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index f95cfde1..8def398e 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -111,14 +111,14 @@ def _create_pose_and_skeleton_objects( return pose_estimation, skeletons -def convert_movement_to_nwb( +def add_movement_dataset_to_nwb( nwbfiles: list[pynwb.NWBFile] | pynwb.NWBFile, movement_dataset: xr.Dataset, pose_estimation_series_kwargs: dict | None = None, pose_estimation_kwargs: dict | None = None, skeletons_kwargs: dict | None = None, ) -> None: - """Convert a ``movement`` dataset to the ndx-pose extension format for NWB. + """Add pose estimation data to NWB files for each individual. Parameters ---------- From 3cd991dae5fc2daeaa1da9caa4eed201ef25f7d0 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 11:16:54 -0700 Subject: [PATCH 16/44] Update with example converting back to movement --- examples/nwb_conversion.py | 39 +++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 29d9f4c5..86ad040a 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -1,21 +1,24 @@ -"""Export pose tracks to NWB +"""Converting movement dataset to NWB or loading from NWB to movement dataset. ============================ Export pose tracks to NWB """ +# %% Load the sample data import datetime -from pynwb import NWBFile +from pynwb import NWBHDF5IO, NWBFile from movement import sample_data -from movement.io.nwb import convert_movement_to_nwb +from movement.io.nwb import ( + add_movement_dataset_to_nwb, + convert_nwb_to_movement, +) -# Load the sample data -ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") +ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") -# The dataset has two individuals -# we will create two NWBFiles for each individual +# %%The dataset has two individuals. +# We will create two NWBFiles for each individual session_start_time = datetime.datetime.now(datetime.timezone.utc) nwbfile_individual1 = NWBFile( @@ -31,7 +34,21 @@ nwbfiles = [nwbfile_individual1, nwbfile_individual2] -# Convert the dataset to NWB -# This will create PoseEstimation and Skeleton objects for each individual -# and add them to the NWBFile -convert_movement_to_nwb(nwbfiles, ds) +# %% Convert the dataset to NWB +# This will create PoseEstimation and Skeleton objects for each +# individual and add them to the NWBFile +add_movement_dataset_to_nwb(nwbfiles, ds) + +# %% Save the NWBFiles +for file in nwbfiles: + with NWBHDF5IO(f"{file.identifier}.nwb", "w") as io: + io.write(file) + +# %% Convert the NWBFiles back to a movement dataset +# This will create a movement dataset with the same data as +# the original dataset from the NWBFiles + +# Convert the NWBFiles to a movement dataset +ds_from_nwb = convert_nwb_to_movement( + nwb_filepaths=["individual1.nwb", "individual2.nwb"] +) From 3aa1b11cf8de6e88cfbd4940836b9d3c860be091 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 11:29:58 -0700 Subject: [PATCH 17/44] Add file validation and handling for single path --- movement/io/nwb.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8def398e..8f39d0ba 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -7,6 +7,7 @@ import pynwb import xarray as xr +from movement.io.save_poses import _validate_file_path from movement.logging import log_error @@ -238,13 +239,13 @@ def _convert_pose_estimation_series( def convert_nwb_to_movement( - nwb_filepaths: list[str] | list[Path], + nwb_filepaths: str | list[str] | list[Path], ) -> xr.Dataset: """Convert a list of NWB files to a single ``movement`` dataset. Parameters ---------- - nwb_filepaths : Union[list[str], list[Path]] + nwb_filepaths : str | Path | list[str] | list[Path] List of paths to NWB files to be converted. Returns @@ -253,8 +254,12 @@ def convert_nwb_to_movement( ``movement`` dataset containing the pose estimation data. """ + if isinstance(nwb_filepaths, str | Path): + nwb_filepaths = [nwb_filepaths] + datasets = [] for path in nwb_filepaths: + _validate_file_path(path, expected_suffix=[".nwb"]) with pynwb.NWBHDF5IO(path, mode="r") as io: nwbfile = io.read() pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"] From e56cf6df0c395e24ae9a404233850557b74abbb6 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 11:46:01 -0700 Subject: [PATCH 18/44] Add preliminary tests --- tests/test_unit/test_nwb.py | 246 ++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 tests/test_unit/test_nwb.py diff --git a/tests/test_unit/test_nwb.py b/tests/test_unit/test_nwb.py new file mode 100644 index 00000000..7319ea48 --- /dev/null +++ b/tests/test_unit/test_nwb.py @@ -0,0 +1,246 @@ +import ndx_pose +import numpy as np +import pynwb +import pytest +import xarray as xr + +from movement import sample_data +from movement.io.nwb import ( + _convert_pose_estimation_series, + _create_pose_and_skeleton_objects, + add_movement_dataset_to_nwb, + convert_nwb_to_movement, +) + + +def test_create_pose_and_skeleton_objects(): + # Create a sample dataset + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + + # Call the function + pose_estimation, skeletons = _create_pose_and_skeleton_objects( + ds, + subject="subject1", + pose_estimation_series_kwargs=None, + pose_estimation_kwargs=None, + skeleton_kwargs=None, + ) + + # Assert the output types + assert isinstance(pose_estimation, list) + assert isinstance(skeletons, ndx_pose.Skeletons) + + # Assert the length of pose_estimation list + assert len(pose_estimation) == 1 + + # Assert the length of pose_estimation_series list + assert len(pose_estimation[0].pose_estimation_series) == 2 + + # Assert the name of the first PoseEstimationSeries + assert pose_estimation[0].pose_estimation_series[0].name == "keypoint1" + + # Assert the name of the second PoseEstimationSeries + assert pose_estimation[0].pose_estimation_series[1].name == "keypoint2" + + # Assert the name of the Skeleton + assert skeletons.skeletons[0].name == "subject1_skeleton" + + +def test__convert_pose_estimation_series(): + # Create a sample PoseEstimationSeries object + pose_estimation_series = ndx_pose.PoseEstimationSeries( + name="keypoint1", + data=np.random.rand(10, 3), + confidence=np.random.rand(10), + unit="pixels", + timestamps=np.arange(10), + ) + + # Call the function + movement_dataset = _convert_pose_estimation_series( + pose_estimation_series, + keypoint="keypoint1", + subject_name="subject1", + source_software="software1", + source_file="file1", + ) + + # Assert the dimensions of the movement dataset + assert movement_dataset.dims == { + "time": 10, + "individuals": 1, + "keypoints": 1, + "space": 3, + } + + # Assert the values of the position variable + np.testing.assert_array_equal( + movement_dataset["position"].values, + pose_estimation_series.data[:, np.newaxis, np.newaxis, :], + ) + + # Assert the values of the confidence variable + np.testing.assert_array_equal( + movement_dataset["confidence"].values, + pose_estimation_series.confidence[:, np.newaxis, np.newaxis], + ) + + # Assert the attributes of the movement dataset + assert movement_dataset.attrs == { + "fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)), + "time_units": pose_estimation_series.timestamps_unit, + "source_software": "software1", + "source_file": "file1", + } + + +def test_add_movement_dataset_to_nwb_single_file(): + # Create a sample NWBFile + nwbfile = pynwb.NWBFile( + "session_description", "identifier", "session_start_time" + ) + # Create a sample movement dataset + movement_dataset = xr.Dataset( + { + "keypoints": (["keypoints"], ["keypoint1", "keypoint2"]), + "position": (["time", "keypoints"], [[1, 2], [3, 4]]), + "confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]), + "time": [0, 1], + "individuals": ["subject1"], + } + ) + # Call the function + add_movement_dataset_to_nwb(nwbfile, movement_dataset) + # Assert the presence of PoseEstimation and Skeletons in the NWBFile + assert "PoseEstimation" in nwbfile.processing["behavior"] + assert "Skeletons" in nwbfile.processing["behavior"] + + +def test_add_movement_dataset_to_nwb_multiple_files(): + # Create sample NWBFiles + nwbfiles = [ + pynwb.NWBFile( + "session_description1", "identifier1", "session_start_time1" + ), + pynwb.NWBFile( + "session_description2", "identifier2", "session_start_time2" + ), + ] + # Create a sample movement dataset + movement_dataset = xr.Dataset( + { + "keypoints": (["keypoints"], ["keypoint1", "keypoint2"]), + "position": (["time", "keypoints"], [[1, 2], [3, 4]]), + "confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]), + "time": [0, 1], + "individuals": ["subject1", "subject2"], + } + ) + # Call the function + add_movement_dataset_to_nwb(nwbfiles, movement_dataset) + # Assert the presence of PoseEstimation and Skeletons in each NWBFile + for nwbfile in nwbfiles: + assert "PoseEstimation" in nwbfile.processing["behavior"] + assert "Skeletons" in nwbfile.processing["behavior"] + + +def test_convert_nwb_to_movement(): + # Create sample NWB files + nwb_filepaths = [ + "/path/to/file1.nwb", + "/path/to/file2.nwb", + "/path/to/file3.nwb", + ] + pose_estimation_series = { + "keypoint1": ndx_pose.PoseEstimationSeries( + name="keypoint1", + data=np.random.rand(10, 3), + confidence=np.random.rand(10), + unit="pixels", + timestamps=np.arange(10), + ), + "keypoint2": ndx_pose.PoseEstimationSeries( + name="keypoint2", + data=np.random.rand(10, 3), + confidence=np.random.rand(10), + unit="pixels", + timestamps=np.arange(10), + ), + } + + # Mock the NWBHDF5IO read method + def mock_read(filepath): + nwbfile = pynwb.NWBFile( + "session_description", "identifier", "session_start_time" + ) + + pose_estimation = ndx_pose.PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description="Pose estimation data", + source_software="software1", + skeleton=ndx_pose.Skeleton( + name="skeleton1", nodes=["node1", "node2"] + ), + ) + behavior_pm = pynwb.ProcessingModule( + name="behavior", description="Behavior data" + ) + behavior_pm.add(pose_estimation) + nwbfile.add_processing_module(behavior_pm) + return nwbfile + + # Patch the NWBHDF5IO read method with the mock + with pytest.patch("pynwb.NWBHDF5IO.read", side_effect=mock_read): + # Call the function + movement_dataset = convert_nwb_to_movement(nwb_filepaths) + + # Assert the dimensions of the movement dataset + assert movement_dataset.dims == { + "time": 10, + "individuals": 3, + "keypoints": 2, + "space": 3, + } + + # Assert the values of the position variable + np.testing.assert_array_equal( + movement_dataset["position"].values, + np.concatenate( + [ + pose_estimation_series["keypoint1"].data[ + :, np.newaxis, np.newaxis, : + ], + pose_estimation_series["keypoint2"].data[ + :, np.newaxis, np.newaxis, : + ], + ], + axis=1, + ), + ) + + # Assert the values of the confidence variable + np.testing.assert_array_equal( + movement_dataset["confidence"].values, + np.concatenate( + [ + pose_estimation_series["keypoint1"].confidence[ + :, np.newaxis, np.newaxis + ], + pose_estimation_series["keypoint2"].confidence[ + :, np.newaxis, np.newaxis + ], + ], + axis=1, + ), + ) + + # Assert the attributes of the movement dataset + assert movement_dataset.attrs == { + "fps": np.nanmedian( + 1 / np.diff(pose_estimation_series["keypoint1"].timestamps) + ), + "time_units": pose_estimation_series["keypoint1"].timestamps_unit, + "source_software": "software1", + "source_file": None, + } From 99a90c1fe06ef9af6e598895bbb3c66a2ebb7266 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sun, 9 Jun 2024 10:54:36 -0700 Subject: [PATCH 19/44] Convert to numpy array --- movement/io/nwb.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8f39d0ba..edb25618 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -7,7 +7,6 @@ import pynwb import xarray as xr -from movement.io.save_poses import _validate_file_path from movement.logging import log_error @@ -221,11 +220,15 @@ def _convert_pose_estimation_series( data_vars={ "position": ( ["time", "individuals", "keypoints", "space"], - pose_estimation_series.data[:, np.newaxis, np.newaxis, :], + np.asarray(pose_estimation_series.data)[ + :, np.newaxis, np.newaxis, : + ], ), "confidence": ( ["time", "individuals", "keypoints"], - pose_estimation_series.confidence[:, np.newaxis, np.newaxis], + np.asarray(pose_estimation_series.confidence)[ + :, np.newaxis, np.newaxis + ], ), }, coords={ @@ -259,7 +262,6 @@ def convert_nwb_to_movement( datasets = [] for path in nwb_filepaths: - _validate_file_path(path, expected_suffix=[".nwb"]) with pynwb.NWBHDF5IO(path, mode="r") as io: nwbfile = io.read() pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"] From 02b997559a1ca4814e10d37b8ba5ab80b6ae7946 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sun, 9 Jun 2024 11:00:17 -0700 Subject: [PATCH 20/44] Handle lack of confidence --- movement/io/nwb.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index edb25618..5a728490 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -216,19 +216,28 @@ def _convert_pose_estimation_series( n_space_dims = pose_estimation_series.data.shape[1] space_dims = ["x", "y", "z"] + position_array = np.asarray(pose_estimation_series.data)[ + :, np.newaxis, np.newaxis, : + ] + + if getattr(pose_estimation_series, "confidence", None) is None: + pose_estimation_series.confidence = np.full( + pose_estimation_series.data.shape[0], np.nan + ) + else: + confidence_array = np.asarray(pose_estimation_series.confidence)[ + :, np.newaxis, np.newaxis + ] + return xr.Dataset( data_vars={ "position": ( ["time", "individuals", "keypoints", "space"], - np.asarray(pose_estimation_series.data)[ - :, np.newaxis, np.newaxis, : - ], + position_array, ), "confidence": ( ["time", "individuals", "keypoints"], - np.asarray(pose_estimation_series.confidence)[ - :, np.newaxis, np.newaxis - ], + confidence_array, ), }, coords={ From a2ac0538656c463521c9b2fe00764a8347841fc0 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sun, 9 Jun 2024 11:00:59 -0700 Subject: [PATCH 21/44] Display xarray --- examples/nwb_conversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 86ad040a..f0001b32 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -52,3 +52,4 @@ ds_from_nwb = convert_nwb_to_movement( nwb_filepaths=["individual1.nwb", "individual2.nwb"] ) +ds_from_nwb From 84a495df10823dd6db129bf79503f270c2e6ed4d Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sun, 9 Jun 2024 11:56:26 -0700 Subject: [PATCH 22/44] Refactor tests --- tests/test_unit/test_nwb.py | 315 ++++++++++++++++++------------------ 1 file changed, 161 insertions(+), 154 deletions(-) diff --git a/tests/test_unit/test_nwb.py b/tests/test_unit/test_nwb.py index 7319ea48..060c2e99 100644 --- a/tests/test_unit/test_nwb.py +++ b/tests/test_unit/test_nwb.py @@ -1,8 +1,10 @@ +import datetime + import ndx_pose import numpy as np -import pynwb -import pytest -import xarray as xr +from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons +from pynwb import NWBHDF5IO, NWBFile +from pynwb.file import Subject from movement import sample_data from movement.io.nwb import ( @@ -19,8 +21,8 @@ def test_create_pose_and_skeleton_objects(): # Call the function pose_estimation, skeletons = _create_pose_and_skeleton_objects( - ds, - subject="subject1", + ds.sel(individuals="individual1"), + subject="individual1", pose_estimation_series_kwargs=None, pose_estimation_kwargs=None, skeleton_kwargs=None, @@ -34,43 +36,59 @@ def test_create_pose_and_skeleton_objects(): assert len(pose_estimation) == 1 # Assert the length of pose_estimation_series list - assert len(pose_estimation[0].pose_estimation_series) == 2 + assert len(pose_estimation[0].pose_estimation_series) == 12 # Assert the name of the first PoseEstimationSeries - assert pose_estimation[0].pose_estimation_series[0].name == "keypoint1" - - # Assert the name of the second PoseEstimationSeries - assert pose_estimation[0].pose_estimation_series[1].name == "keypoint2" + assert "snout" in pose_estimation[0].pose_estimation_series # Assert the name of the Skeleton - assert skeletons.skeletons[0].name == "subject1_skeleton" + assert "individual1_skeleton" in skeletons.skeletons + + +def create_test_pose_estimation_series( + n_time=100, n_dims=2, keypoint="front_left_paw" +): + data = np.random.rand( + n_time, n_dims + ) # num_frames x (x, y) but can be (x, y, z) + timestamps = np.linspace(0, 10, num=n_time) # a timestamp for every frame + confidence = np.ones((n_time,)) # a confidence value for every frame + reference_frame = "(0,0,0) corresponds to ..." + confidence_definition = "Softmax output of the deep neural network." + + return PoseEstimationSeries( + name=keypoint, + description="Marker placed around fingers of front left paw.", + data=data, + unit="pixels", + reference_frame=reference_frame, + timestamps=timestamps, + confidence=confidence, + confidence_definition=confidence_definition, + ) def test__convert_pose_estimation_series(): # Create a sample PoseEstimationSeries object - pose_estimation_series = ndx_pose.PoseEstimationSeries( - name="keypoint1", - data=np.random.rand(10, 3), - confidence=np.random.rand(10), - unit="pixels", - timestamps=np.arange(10), + pose_estimation_series = create_test_pose_estimation_series( + n_time=100, n_dims=2, keypoint="front_left_paw" ) # Call the function movement_dataset = _convert_pose_estimation_series( pose_estimation_series, - keypoint="keypoint1", - subject_name="subject1", + keypoint="leftear", + subject_name="individual1", source_software="software1", source_file="file1", ) # Assert the dimensions of the movement dataset - assert movement_dataset.dims == { - "time": 10, + assert movement_dataset.sizes == { + "time": 100, "individuals": 1, "keypoints": 1, - "space": 3, + "space": 2, } # Assert the values of the position variable @@ -92,155 +110,144 @@ def test__convert_pose_estimation_series(): "source_software": "software1", "source_file": "file1", } + pose_estimation_series = create_test_pose_estimation_series( + n_time=50, n_dims=3, keypoint="front_left_paw" + ) + + # Assert the dimensions of the movement dataset + assert movement_dataset.sizes == { + "time": 50, + "individuals": 1, + "keypoints": 1, + "space": 3, + } def test_add_movement_dataset_to_nwb_single_file(): - # Create a sample NWBFile - nwbfile = pynwb.NWBFile( - "session_description", "identifier", "session_start_time" - ) - # Create a sample movement dataset - movement_dataset = xr.Dataset( - { - "keypoints": (["keypoints"], ["keypoint1", "keypoint2"]), - "position": (["time", "keypoints"], [[1, 2], [3, 4]]), - "confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]), - "time": [0, 1], - "individuals": ["subject1"], - } + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + session_start_time = datetime.datetime.now(datetime.timezone.utc) + nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=session_start_time, + ) + add_movement_dataset_to_nwb( + nwbfile_individual1, ds.sel(individuals=["individual1"]) + ) + assert ( + "PoseEstimation" + in nwbfile_individual1.processing["behavior"].data_interfaces + ) + assert ( + "Skeletons" + in nwbfile_individual1.processing["behavior"].data_interfaces ) - # Call the function - add_movement_dataset_to_nwb(nwbfile, movement_dataset) - # Assert the presence of PoseEstimation and Skeletons in the NWBFile - assert "PoseEstimation" in nwbfile.processing["behavior"] - assert "Skeletons" in nwbfile.processing["behavior"] def test_add_movement_dataset_to_nwb_multiple_files(): - # Create sample NWBFiles - nwbfiles = [ - pynwb.NWBFile( - "session_description1", "identifier1", "session_start_time1" - ), - pynwb.NWBFile( - "session_description2", "identifier2", "session_start_time2" - ), - ] - # Create a sample movement dataset - movement_dataset = xr.Dataset( - { - "keypoints": (["keypoints"], ["keypoint1", "keypoint2"]), - "position": (["time", "keypoints"], [[1, 2], [3, 4]]), - "confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]), - "time": [0, 1], - "individuals": ["subject1", "subject2"], - } + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + session_start_time = datetime.datetime.now(datetime.timezone.utc) + nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=session_start_time, + ) + nwbfile_individual2 = NWBFile( + session_description="session_description", + identifier="individual2", + session_start_time=session_start_time, ) - # Call the function - add_movement_dataset_to_nwb(nwbfiles, movement_dataset) - # Assert the presence of PoseEstimation and Skeletons in each NWBFile - for nwbfile in nwbfiles: - assert "PoseEstimation" in nwbfile.processing["behavior"] - assert "Skeletons" in nwbfile.processing["behavior"] + nwbfiles = [nwbfile_individual1, nwbfile_individual2] + add_movement_dataset_to_nwb(nwbfiles, ds) -def test_convert_nwb_to_movement(): - # Create sample NWB files - nwb_filepaths = [ - "/path/to/file1.nwb", - "/path/to/file2.nwb", - "/path/to/file3.nwb", - ] - pose_estimation_series = { - "keypoint1": ndx_pose.PoseEstimationSeries( - name="keypoint1", - data=np.random.rand(10, 3), - confidence=np.random.rand(10), - unit="pixels", - timestamps=np.arange(10), - ), - "keypoint2": ndx_pose.PoseEstimationSeries( - name="keypoint2", - data=np.random.rand(10, 3), - confidence=np.random.rand(10), - unit="pixels", - timestamps=np.arange(10), - ), - } - # Mock the NWBHDF5IO read method - def mock_read(filepath): - nwbfile = pynwb.NWBFile( - "session_description", "identifier", "session_start_time" - ) - - pose_estimation = ndx_pose.PoseEstimation( - name="PoseEstimation", - pose_estimation_series=pose_estimation_series, - description="Pose estimation data", - source_software="software1", - skeleton=ndx_pose.Skeleton( - name="skeleton1", nodes=["node1", "node2"] - ), - ) - behavior_pm = pynwb.ProcessingModule( - name="behavior", description="Behavior data" - ) - behavior_pm.add(pose_estimation) - nwbfile.add_processing_module(behavior_pm) - return nwbfile +def create_test_pose_nwb(identifier="subject1", write_to_disk=False): + # initialize an NWBFile object + nwbfile = NWBFile( + session_description="session_description", + identifier=identifier, + session_start_time=datetime.datetime.now(datetime.timezone.utc), + ) - # Patch the NWBHDF5IO read method with the mock - with pytest.patch("pynwb.NWBHDF5IO.read", side_effect=mock_read): - # Call the function - movement_dataset = convert_nwb_to_movement(nwb_filepaths) + # add a subject to the NWB file + subject = Subject(subject_id=identifier, species="Mus musculus") + nwbfile.subject = subject - # Assert the dimensions of the movement dataset - assert movement_dataset.dims == { - "time": 10, - "individuals": 3, - "keypoints": 2, - "space": 3, - } + skeleton = Skeleton( + name="subject1_skeleton", + nodes=["front_left_paw", "body", "front_right_paw"], + edges=np.array([[0, 1], [1, 2]], dtype="uint8"), + subject=subject, + ) - # Assert the values of the position variable - np.testing.assert_array_equal( - movement_dataset["position"].values, - np.concatenate( - [ - pose_estimation_series["keypoint1"].data[ - :, np.newaxis, np.newaxis, : - ], - pose_estimation_series["keypoint2"].data[ - :, np.newaxis, np.newaxis, : - ], - ], - axis=1, - ), + skeletons = Skeletons(skeletons=[skeleton]) + + # create a device for the camera + camera1 = nwbfile.create_device( + name="camera1", + description="camera for recording behavior", + manufacturer="my manufacturer", ) - # Assert the values of the confidence variable - np.testing.assert_array_equal( - movement_dataset["confidence"].values, - np.concatenate( - [ - pose_estimation_series["keypoint1"].confidence[ - :, np.newaxis, np.newaxis - ], - pose_estimation_series["keypoint2"].confidence[ - :, np.newaxis, np.newaxis - ], - ], - axis=1, - ), + n_time = 100 + n_dims = 2 # 2D data + front_left_paw = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="front_left_paw" ) - # Assert the attributes of the movement dataset - assert movement_dataset.attrs == { - "fps": np.nanmedian( - 1 / np.diff(pose_estimation_series["keypoint1"].timestamps) + body = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="body" + ) + front_right_paw = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="front_right_paw" + ) + + # store all PoseEstimationSeries in a list + pose_estimation_series = [front_left_paw, body, front_right_paw] + + pose_estimation = PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description=( + "Estimated positions of front paws" "of subject1 using DeepLabCut." ), - "time_units": pose_estimation_series["keypoint1"].timestamps_unit, - "source_software": "software1", - "source_file": None, + original_videos=["path/to/camera1.mp4"], + labeled_videos=["path/to/camera1_labeled.mp4"], + dimensions=np.array( + [[640, 480]], dtype="uint16" + ), # pixel dimensions of the video + devices=[camera1], + scorer="DLC_resnet50_openfieldOct30shuffle1_1600", + source_software="DeepLabCut", + source_software_version="2.3.8", + skeleton=skeleton, # link to the skeleton object + ) + + behavior_pm = nwbfile.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + behavior_pm.add(skeletons) + behavior_pm.add(pose_estimation) + + # write the NWBFile to disk + if write_to_disk: + path = "test_pose.nwb" + with NWBHDF5IO(path, mode="w") as io: + io.write(nwbfile) + else: + return nwbfile + + +def test_convert_nwb_to_movement(): + create_test_pose_nwb(write_to_disk=True) + nwb_filepaths = ["test_pose.nwb"] + movement_dataset = convert_nwb_to_movement(nwb_filepaths) + + assert movement_dataset.sizes == { + "time": 100, + "individuals": 1, + "keypoints": 3, + "space": 2, } From e9e1cefa9eca5594388dd12065d778c0a67a1fb8 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Thu, 18 Apr 2024 09:54:41 -0500 Subject: [PATCH 23/44] Create nwb_export.py --- movement/io/nwb_export.py | 143 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 movement/io/nwb_export.py diff --git a/movement/io/nwb_export.py b/movement/io/nwb_export.py new file mode 100644 index 00000000..faac5717 --- /dev/null +++ b/movement/io/nwb_export.py @@ -0,0 +1,143 @@ +import xarray as xr + +try: + import ndx_pose + import pynwb +except ImportError: + ndx_pose = None + pynwb = None + + +def _create_pose_and_skeleton_objects( + ds: xr.Dataset, + pose_estimation_series_kwargs: dict = None, + pose_estimation_kwargs: dict = None, + skeleton_kwargs: dict = None, +) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: + """Creates PoseEstimation and Skeletons objects from a movement xarray dataset. + + Parameters + ---------- + ds : xr.Dataset + Movement dataset containing the data to be converted to NWB. + pose_estimation_series_kwargs : dict, optional + PoseEstimationSeries keyword arguments. See ndx_pose, by default None + pose_estimation_kwargs : dict, optional + PoseEstimation keyword arguments. See ndx_pose, by default None + skeleton_kwargs : dict, optional + Skeleton keyword arguments. See ndx_pose, by default None + + Returns + ------- + pose_estimation : list[ndx_pose.PoseEstimation] + List of PoseEstimation objects + skeletons : ndx_pose.Skeletons + Skeletons object containing all skeletons + """ + if pose_estimation_series_kwargs is None: + pose_estimation_series_kwargs = dict( + reference_frame="(0,0,0) corresponds to ...", + confidence_definition=None, + conversion=1.0, + resolution=-1.0, + offset=0.0, + starting_time=None, + comments="no comments", + description="no description", + control=None, + control_description=None, + ) + + if skeleton_kwargs is None: + skeleton_kwargs = dict(edges=None) + + if pose_estimation_kwargs is None: + pose_estimation_kwargs = dict( + original_videos=None, + labeled_videos=None, + dimensions=None, + devices=None, + scorer=None, + source_software_version=None, + ) + + skeleton_list = [] + pose_estimation = [] + + for subject in ds.individuals.to_numpy(): + pose_estimation_series = [] + + for keypoint in ds.keypoints.to_numpy(): + pose_estimation_series.append( + ndx_pose.PoseEstimationSeries( + name=keypoint, + data=ds.sel( + keypoints=keypoint, individuals=subject + ).position.to_numpy(), + confidence=ds.sel( + keypoints=keypoint, individuals=subject + ).confidence.to_numpy(), + unit="pixels", + timestamps=ds.sel( + keypoints=keypoint, individuals=subject + ).time.to_numpy(), + **pose_estimation_series_kwargs, + ) + ) + + skeleton_list.append( + ndx_pose.Skeleton( + name=f"{subject}_skeleton", + nodes=ds.sel(individuals=subject).keypoints.to_numpy().tolist(), + **skeleton_kwargs, + ) + ) + + bodyparts_str = ", ".join( + ds.sel(individuals=subject).keypoints.to_numpy().tolist() + ) + description = f"Estimated positions of {bodyparts_str} of {subject} using {ds.source_software}." + + pose_estimation.append( + ndx_pose.PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description=description, + source_software=ds.source_software, + skeleton=skeleton_list[-1], + **pose_estimation_kwargs, + ) + ) + + skeletons = ndx_pose.Skeletons(skeletons=skeleton_list) + + return pose_estimation, skeletons + + +def convert_movement_to_nwb( + nwbfile: pynwb.NWBFile, + ds: xr.Dataset, + pose_estimation_series_kwargs: dict = None, + pose_estimation_kwargs: dict = None, + skeletons_kwargs: dict = None, +): + pose_estimation, skeletons = _create_pose_and_skeleton_objects( + ds, pose_estimation_series_kwargs, pose_estimation_kwargs, skeletons_kwargs + ) + try: + behavior_pm = nwbfile.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + except ValueError: + print("Behavior processing module already exists. Skipping...") + behavior_pm = nwbfile.processing["behavior"] + + try: + behavior_pm.add(skeletons) + except ValueError: + print("Skeletons already exists. Skipping...") + try: + behavior_pm.add(pose_estimation) + except ValueError: + print("PoseEstimation already exists. Skipping...") From 3ccd71c141f0e54c2cbf2e28b1f7a64ae73779e6 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Thu, 18 Apr 2024 10:21:54 -0500 Subject: [PATCH 24/44] NWB requires one file per individual --- movement/io/nwb_export.py | 130 ++++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 62 deletions(-) diff --git a/movement/io/nwb_export.py b/movement/io/nwb_export.py index faac5717..0d5c51d9 100644 --- a/movement/io/nwb_export.py +++ b/movement/io/nwb_export.py @@ -1,4 +1,5 @@ import xarray as xr +from typing import Union try: import ndx_pose @@ -10,6 +11,7 @@ def _create_pose_and_skeleton_objects( ds: xr.Dataset, + subject: str, pose_estimation_series_kwargs: dict = None, pose_estimation_kwargs: dict = None, skeleton_kwargs: dict = None, @@ -20,6 +22,8 @@ def _create_pose_and_skeleton_objects( ---------- ds : xr.Dataset Movement dataset containing the data to be converted to NWB. + subject : str + Name of the subject to be converted. pose_estimation_series_kwargs : dict, optional PoseEstimationSeries keyword arguments. See ndx_pose, by default None pose_estimation_kwargs : dict, optional @@ -61,53 +65,41 @@ def _create_pose_and_skeleton_objects( source_software_version=None, ) - skeleton_list = [] - pose_estimation = [] - - for subject in ds.individuals.to_numpy(): - pose_estimation_series = [] - - for keypoint in ds.keypoints.to_numpy(): - pose_estimation_series.append( - ndx_pose.PoseEstimationSeries( - name=keypoint, - data=ds.sel( - keypoints=keypoint, individuals=subject - ).position.to_numpy(), - confidence=ds.sel( - keypoints=keypoint, individuals=subject - ).confidence.to_numpy(), - unit="pixels", - timestamps=ds.sel( - keypoints=keypoint, individuals=subject - ).time.to_numpy(), - **pose_estimation_series_kwargs, - ) - ) - - skeleton_list.append( - ndx_pose.Skeleton( - name=f"{subject}_skeleton", - nodes=ds.sel(individuals=subject).keypoints.to_numpy().tolist(), - **skeleton_kwargs, + pose_estimation_series = [] + + for keypoint in ds.keypoints.to_numpy(): + pose_estimation_series.append( + ndx_pose.PoseEstimationSeries( + name=keypoint, + data=ds.sel(keypoints=keypoint).position.to_numpy(), + confidence=ds.sel(keypoints=keypoint).confidence.to_numpy(), + unit="pixels", + timestamps=ds.sel(keypoints=keypoint).time.to_numpy(), + **pose_estimation_series_kwargs, ) ) - bodyparts_str = ", ".join( - ds.sel(individuals=subject).keypoints.to_numpy().tolist() + skeleton_list = [ + ndx_pose.Skeleton( + name=f"{subject}_skeleton", + nodes=ds.keypoints.to_numpy().tolist(), + **skeleton_kwargs, ) - description = f"Estimated positions of {bodyparts_str} of {subject} using {ds.source_software}." - - pose_estimation.append( - ndx_pose.PoseEstimation( - name="PoseEstimation", - pose_estimation_series=pose_estimation_series, - description=description, - source_software=ds.source_software, - skeleton=skeleton_list[-1], - **pose_estimation_kwargs, - ) + ] + + bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist()) + description = f"Estimated positions of {bodyparts_str} of {subject} using {ds.source_software}." + + pose_estimation = [ + ndx_pose.PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description=description, + source_software=ds.source_software, + skeleton=skeleton_list[-1], + **pose_estimation_kwargs, ) + ] skeletons = ndx_pose.Skeletons(skeletons=skeleton_list) @@ -115,29 +107,43 @@ def _create_pose_and_skeleton_objects( def convert_movement_to_nwb( - nwbfile: pynwb.NWBFile, + nwbfiles: Union[list[pynwb.NWBFile], pynwb.NWBFile], ds: xr.Dataset, pose_estimation_series_kwargs: dict = None, pose_estimation_kwargs: dict = None, skeletons_kwargs: dict = None, ): - pose_estimation, skeletons = _create_pose_and_skeleton_objects( - ds, pose_estimation_series_kwargs, pose_estimation_kwargs, skeletons_kwargs - ) - try: - behavior_pm = nwbfile.create_processing_module( - name="behavior", - description="processed behavioral data", + if isinstance(nwbfiles, pynwb.NWBFile): + nwbfiles = [nwbfiles] + + if len(nwbfiles) != len(ds.individuals): + raise ValueError( + "Number of NWBFiles must be equal to the number of individuals in the dataset. " + "NWB requires one file per individual." ) - except ValueError: - print("Behavior processing module already exists. Skipping...") - behavior_pm = nwbfile.processing["behavior"] - - try: - behavior_pm.add(skeletons) - except ValueError: - print("Skeletons already exists. Skipping...") - try: - behavior_pm.add(pose_estimation) - except ValueError: - print("PoseEstimation already exists. Skipping...") + + for nwbfile, subject in zip(nwbfiles, ds.individuals.to_numpy()): + pose_estimation, skeletons = _create_pose_and_skeleton_objects( + ds.sel(individuals=subject), + subject, + pose_estimation_series_kwargs, + pose_estimation_kwargs, + skeletons_kwargs, + ) + try: + behavior_pm = nwbfile.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + except ValueError: + print("Behavior processing module already exists. Skipping...") + behavior_pm = nwbfile.processing["behavior"] + + try: + behavior_pm.add(skeletons) + except ValueError: + print("Skeletons already exists. Skipping...") + try: + behavior_pm.add(pose_estimation) + except ValueError: + print("PoseEstimation already exists. Skipping...") From f906cd5787bfc56a53a28eeb058a465d513ac79e Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 09:47:01 -0500 Subject: [PATCH 25/44] Add script --- examples/nwb_conversion.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 examples/nwb_conversion.py diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py new file mode 100644 index 00000000..4499e951 --- /dev/null +++ b/examples/nwb_conversion.py @@ -0,0 +1,27 @@ +from movement import sample_data +from pynwb import NWBFile +import datetime +from movement.io.nwb_export import convert_movement_to_nwb + +# Load the sample data +ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") + +# The dataset has two individuals, we will create two NWBFiles for each individual + +nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=datetime.datetime.now(datetime.timezone.utc), +) +nwbfile_individual2 = NWBFile( + session_description="session_description", + identifier="individual2", + session_start_time=datetime.datetime.now(datetime.timezone.utc), +) + +nwbfiles = [nwbfile_individual1, nwbfile_individual2] + +# Convert the dataset to NWB +# This will create PoseEstimation and Skeleton objects for each individual +# and add them to the NWBFile +convert_movement_to_nwb(nwbfiles, ds) From d35d9c2b5c36d8a2818a48d6bb6226b8bd5bddc5 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 09:49:13 -0500 Subject: [PATCH 26/44] Remove import error handling --- movement/io/nwb_export.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/movement/io/nwb_export.py b/movement/io/nwb_export.py index 0d5c51d9..a9d3d6a2 100644 --- a/movement/io/nwb_export.py +++ b/movement/io/nwb_export.py @@ -1,12 +1,8 @@ -import xarray as xr from typing import Union -try: - import ndx_pose - import pynwb -except ImportError: - ndx_pose = None - pynwb = None +import ndx_pose +import pynwb +import xarray as xr def _create_pose_and_skeleton_objects( From e5726d4791d0c3c261306b71fc5becc009c055cf Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 09:49:24 -0500 Subject: [PATCH 27/44] Add nwb optional dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 27348c29..72ec120b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "types-PyYAML", "types-requests", ] +nwb = ["pynwb", "git+https://github.com/rly/ndx-pose.git"] [project.scripts] movement = "movement.cli_entrypoint:main" From 53f505b6198b8baf894533efe14d7cfe90421f85 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 10:03:36 -0500 Subject: [PATCH 28/44] Fix linting based on pre-commit hooks --- examples/nwb_conversion.py | 14 +++++++++----- movement/io/nwb_export.py | 24 ++++++++++++++---------- pyproject.toml | 2 +- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 4499e951..67ee5810 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -1,22 +1,26 @@ -from movement import sample_data -from pynwb import NWBFile import datetime + +from pynwb import NWBFile + +from movement import sample_data from movement.io.nwb_export import convert_movement_to_nwb # Load the sample data ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") -# The dataset has two individuals, we will create two NWBFiles for each individual +# The dataset has two individuals +# we will create two NWBFiles for each individual +session_start_time = datetime.datetime.now(datetime.timezone.utc) nwbfile_individual1 = NWBFile( session_description="session_description", identifier="individual1", - session_start_time=datetime.datetime.now(datetime.timezone.utc), + session_start_time=session_start_time, ) nwbfile_individual2 = NWBFile( session_description="session_description", identifier="individual2", - session_start_time=datetime.datetime.now(datetime.timezone.utc), + session_start_time=session_start_time, ) nwbfiles = [nwbfile_individual1, nwbfile_individual2] diff --git a/movement/io/nwb_export.py b/movement/io/nwb_export.py index a9d3d6a2..77e211ee 100644 --- a/movement/io/nwb_export.py +++ b/movement/io/nwb_export.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Optional, Union import ndx_pose import pynwb @@ -8,11 +8,12 @@ def _create_pose_and_skeleton_objects( ds: xr.Dataset, subject: str, - pose_estimation_series_kwargs: dict = None, - pose_estimation_kwargs: dict = None, - skeleton_kwargs: dict = None, + pose_estimation_series_kwargs: Optional[dict] = None, + pose_estimation_kwargs: Optional[dict] = None, + skeleton_kwargs: Optional[dict] = None, ) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: - """Creates PoseEstimation and Skeletons objects from a movement xarray dataset. + """Creates PoseEstimation and Skeletons objects from a movement xarray + dataset. Parameters ---------- @@ -84,7 +85,10 @@ def _create_pose_and_skeleton_objects( ] bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist()) - description = f"Estimated positions of {bodyparts_str} of {subject} using {ds.source_software}." + description = ( + f"Estimated positions of {bodyparts_str} of" + f"{subject} using {ds.source_software}." + ) pose_estimation = [ ndx_pose.PoseEstimation( @@ -105,16 +109,16 @@ def _create_pose_and_skeleton_objects( def convert_movement_to_nwb( nwbfiles: Union[list[pynwb.NWBFile], pynwb.NWBFile], ds: xr.Dataset, - pose_estimation_series_kwargs: dict = None, - pose_estimation_kwargs: dict = None, - skeletons_kwargs: dict = None, + pose_estimation_series_kwargs: Optional[dict] = None, + pose_estimation_kwargs: Optional[dict] = None, + skeletons_kwargs: Optional[dict] = None, ): if isinstance(nwbfiles, pynwb.NWBFile): nwbfiles = [nwbfiles] if len(nwbfiles) != len(ds.individuals): raise ValueError( - "Number of NWBFiles must be equal to the number of individuals in the dataset. " + "Number of NWBFiles must be equal to the number of individuals. " "NWB requires one file per individual." ) diff --git a/pyproject.toml b/pyproject.toml index 72ec120b..4405d1b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ dev = [ "types-PyYAML", "types-requests", ] -nwb = ["pynwb", "git+https://github.com/rly/ndx-pose.git"] +nwb = ["pynwb", "ndx-pose"] [project.scripts] movement = "movement.cli_entrypoint:main" From f1d480dd50adb6b0b98043589edee6b0f11fa17c Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 10:21:05 -0500 Subject: [PATCH 29/44] Add example docstring --- examples/nwb_conversion.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 67ee5810..bc7b80cd 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -1,3 +1,10 @@ +""" +Export pose tracks to NWB +============================ + +Export pose tracks to NWB +""" + import datetime from pynwb import NWBFile From 4b162cffc6b565e2534692344a13c91a5f71cfcd Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 13:31:34 -0500 Subject: [PATCH 30/44] Rename to fit module naming pattern --- examples/nwb_conversion.py | 2 +- movement/io/{nwb_export.py => export_nwb.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename movement/io/{nwb_export.py => export_nwb.py} (100%) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index bc7b80cd..396e578b 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -10,7 +10,7 @@ from pynwb import NWBFile from movement import sample_data -from movement.io.nwb_export import convert_movement_to_nwb +from movement.io.export_nwb import convert_movement_to_nwb # Load the sample data ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") diff --git a/movement/io/nwb_export.py b/movement/io/export_nwb.py similarity index 100% rename from movement/io/nwb_export.py rename to movement/io/export_nwb.py From 4b887be065de2e49e17393f2341b915dbf898c3a Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 19 Apr 2024 14:45:20 -0500 Subject: [PATCH 31/44] Add import from nwb --- examples/nwb_conversion.py | 2 +- movement/io/{export_nwb.py => nwb.py} | 63 +++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) rename movement/io/{export_nwb.py => nwb.py} (71%) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 396e578b..e796c222 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -10,7 +10,7 @@ from pynwb import NWBFile from movement import sample_data -from movement.io.export_nwb import convert_movement_to_nwb +from movement.io.nwb import convert_movement_to_nwb # Load the sample data ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") diff --git a/movement/io/export_nwb.py b/movement/io/nwb.py similarity index 71% rename from movement/io/export_nwb.py rename to movement/io/nwb.py index 77e211ee..8afce46a 100644 --- a/movement/io/export_nwb.py +++ b/movement/io/nwb.py @@ -1,6 +1,7 @@ from typing import Optional, Union import ndx_pose +import numpy as np import pynwb import xarray as xr @@ -147,3 +148,65 @@ def convert_movement_to_nwb( behavior_pm.add(pose_estimation) except ValueError: print("PoseEstimation already exists. Skipping...") + + +def _convert_pse( + pes: ndx_pose.PoseEstimationSeries, + keypoint: str, + subject_name: str, + source_software: str, + source_file: Optional[str] = None, +): + attrs = { + "fps": int(np.median(1 / np.diff(pes.timestamps))), + "time_units": pes.timestamps_unit, + "source_software": source_software, + "source_file": source_file, + } + n_space_dims = pes.data.shape[1] + space_dims = ["x", "y", "z"] + + return xr.Dataset( + data_vars={ + "position": ( + ["time", "individuals", "keypoints", "space"], + pes.data[:, np.newaxis, np.newaxis, :], + ), + "confidence": ( + ["time", "individuals", "keypoints"], + pes.confidence[:, np.newaxis, np.newaxis], + ), + }, + coords={ + "time": pes.timestamps, + "individuals": [subject_name], + "keypoints": [keypoint], + "space": space_dims[:n_space_dims], + }, + attrs=attrs, + ) + + +def convert_nwb_to_movement(nwb_filepaths: list[str]) -> xr.Dataset: + datasets = [] + for path in nwb_filepaths: + with pynwb.NWBHDF5IO(path, mode="r") as io: + nwbfile = io.read() + pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"] + source_software = pose_estimation.fields["source_software"] + pose_estimation_series = pose_estimation.fields[ + "pose_estimation_series" + ] + + for keypoint, pes in pose_estimation_series.items(): + datasets.append( + _convert_pse( + pes, + keypoint, + subject_name=nwbfile.identifier, + source_software=source_software, + source_file=None, + ) + ) + + return xr.merge(datasets) From 96ee7baecd2485c81874b35af9273d972c1a3462 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Apr 2024 13:29:05 +0000 Subject: [PATCH 32/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/nwb_conversion.py | 3 +-- movement/io/nwb.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index e796c222..29d9f4c5 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -1,5 +1,4 @@ -""" -Export pose tracks to NWB +"""Export pose tracks to NWB ============================ Export pose tracks to NWB diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8afce46a..722bfb0a 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -35,6 +35,7 @@ def _create_pose_and_skeleton_objects( List of PoseEstimation objects skeletons : ndx_pose.Skeletons Skeletons object containing all skeletons + """ if pose_estimation_series_kwargs is None: pose_estimation_series_kwargs = dict( From 2f2625d4930ae867f03953bc29d6ed49ceb18785 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 Jun 2024 16:01:33 +0000 Subject: [PATCH 33/44] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- movement/io/nwb.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 722bfb0a..8960227d 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - import ndx_pose import numpy as np import pynwb @@ -9,9 +7,9 @@ def _create_pose_and_skeleton_objects( ds: xr.Dataset, subject: str, - pose_estimation_series_kwargs: Optional[dict] = None, - pose_estimation_kwargs: Optional[dict] = None, - skeleton_kwargs: Optional[dict] = None, + pose_estimation_series_kwargs: dict | None = None, + pose_estimation_kwargs: dict | None = None, + skeleton_kwargs: dict | None = None, ) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: """Creates PoseEstimation and Skeletons objects from a movement xarray dataset. @@ -109,11 +107,11 @@ def _create_pose_and_skeleton_objects( def convert_movement_to_nwb( - nwbfiles: Union[list[pynwb.NWBFile], pynwb.NWBFile], + nwbfiles: list[pynwb.NWBFile] | pynwb.NWBFile, ds: xr.Dataset, - pose_estimation_series_kwargs: Optional[dict] = None, - pose_estimation_kwargs: Optional[dict] = None, - skeletons_kwargs: Optional[dict] = None, + pose_estimation_series_kwargs: dict | None = None, + pose_estimation_kwargs: dict | None = None, + skeletons_kwargs: dict | None = None, ): if isinstance(nwbfiles, pynwb.NWBFile): nwbfiles = [nwbfiles] @@ -124,7 +122,9 @@ def convert_movement_to_nwb( "NWB requires one file per individual." ) - for nwbfile, subject in zip(nwbfiles, ds.individuals.to_numpy()): + for nwbfile, subject in zip( + nwbfiles, ds.individuals.to_numpy(), strict=False + ): pose_estimation, skeletons = _create_pose_and_skeleton_objects( ds.sel(individuals=subject), subject, @@ -156,7 +156,7 @@ def _convert_pse( keypoint: str, subject_name: str, source_software: str, - source_file: Optional[str] = None, + source_file: str | None = None, ): attrs = { "fps": int(np.median(1 / np.diff(pes.timestamps))), From 1c7c2e36407b637f7fd245c71d253b5ccf84bb63 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 09:02:33 -0700 Subject: [PATCH 34/44] Apply suggestions from code review Co-authored-by: Niko Sirmpilatze --- movement/io/nwb.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8960227d..0b85fe53 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -11,13 +11,12 @@ def _create_pose_and_skeleton_objects( pose_estimation_kwargs: dict | None = None, skeleton_kwargs: dict | None = None, ) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: - """Creates PoseEstimation and Skeletons objects from a movement xarray - dataset. + """Create PoseEstimation and Skeletons objects from a movement dataset. Parameters ---------- ds : xr.Dataset - Movement dataset containing the data to be converted to NWB. + movement dataset containing the data to be converted to NWB. subject : str Name of the subject to be converted. pose_estimation_series_kwargs : dict, optional From 4191ae8721390e569a729b507eb06cb400f2fffe Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 09:26:33 -0700 Subject: [PATCH 35/44] Update make pynwb and ndx-pose core dependencies --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4405d1b5..1d227e01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ dependencies = [ "sleap-io", "xarray[accel,viz]", "PyYAML", + "pynwb", + "ndx-pose>=0.2", ] classifiers = [ @@ -59,7 +61,6 @@ dev = [ "types-PyYAML", "types-requests", ] -nwb = ["pynwb", "ndx-pose"] [project.scripts] movement = "movement.cli_entrypoint:main" From 4202ff69b7535372d7098f1e512364896b1f367b Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 10:58:07 -0700 Subject: [PATCH 36/44] Cleanup of docstrings and variable names from code review --- movement/io/nwb.py | 107 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 86 insertions(+), 21 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 0b85fe53..f95cfde1 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -1,8 +1,14 @@ +"""Functions to convert movement data to and from NWB format.""" + +from pathlib import Path + import ndx_pose import numpy as np import pynwb import xarray as xr +from movement.logging import log_error + def _create_pose_and_skeleton_objects( ds: xr.Dataset, @@ -11,14 +17,14 @@ def _create_pose_and_skeleton_objects( pose_estimation_kwargs: dict | None = None, skeleton_kwargs: dict | None = None, ) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]: - """Create PoseEstimation and Skeletons objects from a movement dataset. + """Create PoseEstimation and Skeletons objects from a ``movement`` dataset. Parameters ---------- - ds : xr.Dataset + ds : xarray.Dataset movement dataset containing the data to be converted to NWB. subject : str - Name of the subject to be converted. + Name of the subject (individual) to be converted. pose_estimation_series_kwargs : dict, optional PoseEstimationSeries keyword arguments. See ndx_pose, by default None pose_estimation_kwargs : dict, optional @@ -107,25 +113,48 @@ def _create_pose_and_skeleton_objects( def convert_movement_to_nwb( nwbfiles: list[pynwb.NWBFile] | pynwb.NWBFile, - ds: xr.Dataset, + movement_dataset: xr.Dataset, pose_estimation_series_kwargs: dict | None = None, pose_estimation_kwargs: dict | None = None, skeletons_kwargs: dict | None = None, -): +) -> None: + """Convert a ``movement`` dataset to the ndx-pose extension format for NWB. + + Parameters + ---------- + nwbfiles : list[pynwb.NWBFile] | pynwb.NWBFile + NWBFile object(s) to which the data will be added. + movement_dataset : xr.Dataset + ``movement`` dataset containing the data to be converted to NWB. + pose_estimation_series_kwargs : dict, optional + PoseEstimationSeries keyword arguments. See ndx_pose, by default None + pose_estimation_kwargs : dict, optional + PoseEstimation keyword arguments. See ndx_pose, by default None + skeletons_kwargs : dict, optional + Skeleton keyword arguments. See ndx_pose, by default None + + Raises + ------ + ValueError + If the number of NWBFiles is not equal to the number of individuals + in the dataset. + + """ if isinstance(nwbfiles, pynwb.NWBFile): nwbfiles = [nwbfiles] - if len(nwbfiles) != len(ds.individuals): - raise ValueError( + if len(nwbfiles) != len(movement_dataset.individuals): + raise log_error( + ValueError, "Number of NWBFiles must be equal to the number of individuals. " - "NWB requires one file per individual." + "NWB requires one file per individual.", ) for nwbfile, subject in zip( - nwbfiles, ds.individuals.to_numpy(), strict=False + nwbfiles, movement_dataset.individuals.to_numpy(), strict=False ): pose_estimation, skeletons = _create_pose_and_skeleton_objects( - ds.sel(individuals=subject), + movement_dataset.sel(individuals=subject), subject, pose_estimation_series_kwargs, pose_estimation_kwargs, @@ -150,35 +179,56 @@ def convert_movement_to_nwb( print("PoseEstimation already exists. Skipping...") -def _convert_pse( - pes: ndx_pose.PoseEstimationSeries, +def _convert_pose_estimation_series( + pose_estimation_series: ndx_pose.PoseEstimationSeries, keypoint: str, subject_name: str, source_software: str, source_file: str | None = None, -): +) -> xr.Dataset: + """Convert to single-keypoint, single-individual ``movement`` dataset. + + Parameters + ---------- + pose_estimation_series : ndx_pose.PoseEstimationSeries + PoseEstimationSeries NWB object to be converted. + keypoint : str + Name of the keypoint - body part. + subject_name : str + Name of the subject (individual). + source_software : str + Name of the software used to estimate the pose. + source_file : Optional[str], optional + File from which the data was extracted, by default None + + Returns + ------- + movement_dataset : xr.Dataset + ``movement`` compatible dataset containing the pose estimation data. + + """ attrs = { - "fps": int(np.median(1 / np.diff(pes.timestamps))), - "time_units": pes.timestamps_unit, + "fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)), + "time_units": pose_estimation_series.timestamps_unit, "source_software": source_software, "source_file": source_file, } - n_space_dims = pes.data.shape[1] + n_space_dims = pose_estimation_series.data.shape[1] space_dims = ["x", "y", "z"] return xr.Dataset( data_vars={ "position": ( ["time", "individuals", "keypoints", "space"], - pes.data[:, np.newaxis, np.newaxis, :], + pose_estimation_series.data[:, np.newaxis, np.newaxis, :], ), "confidence": ( ["time", "individuals", "keypoints"], - pes.confidence[:, np.newaxis, np.newaxis], + pose_estimation_series.confidence[:, np.newaxis, np.newaxis], ), }, coords={ - "time": pes.timestamps, + "time": pose_estimation_series.timestamps, "individuals": [subject_name], "keypoints": [keypoint], "space": space_dims[:n_space_dims], @@ -187,7 +237,22 @@ def _convert_pse( ) -def convert_nwb_to_movement(nwb_filepaths: list[str]) -> xr.Dataset: +def convert_nwb_to_movement( + nwb_filepaths: list[str] | list[Path], +) -> xr.Dataset: + """Convert a list of NWB files to a single ``movement`` dataset. + + Parameters + ---------- + nwb_filepaths : Union[list[str], list[Path]] + List of paths to NWB files to be converted. + + Returns + ------- + movement_ds : xr.Dataset + ``movement`` dataset containing the pose estimation data. + + """ datasets = [] for path in nwb_filepaths: with pynwb.NWBHDF5IO(path, mode="r") as io: @@ -200,7 +265,7 @@ def convert_nwb_to_movement(nwb_filepaths: list[str]) -> xr.Dataset: for keypoint, pes in pose_estimation_series.items(): datasets.append( - _convert_pse( + _convert_pose_estimation_series( pes, keypoint, subject_name=nwbfile.identifier, From 3188b0b1948eb1bc6fa8ff38c203a6eb9277f791 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 11:15:20 -0700 Subject: [PATCH 37/44] Rename function for clarity --- movement/io/nwb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index f95cfde1..8def398e 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -111,14 +111,14 @@ def _create_pose_and_skeleton_objects( return pose_estimation, skeletons -def convert_movement_to_nwb( +def add_movement_dataset_to_nwb( nwbfiles: list[pynwb.NWBFile] | pynwb.NWBFile, movement_dataset: xr.Dataset, pose_estimation_series_kwargs: dict | None = None, pose_estimation_kwargs: dict | None = None, skeletons_kwargs: dict | None = None, ) -> None: - """Convert a ``movement`` dataset to the ndx-pose extension format for NWB. + """Add pose estimation data to NWB files for each individual. Parameters ---------- From 4908040cc17ae482a66efd80614a5b18c1138236 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 11:16:54 -0700 Subject: [PATCH 38/44] Update with example converting back to movement --- examples/nwb_conversion.py | 39 +++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 29d9f4c5..86ad040a 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -1,21 +1,24 @@ -"""Export pose tracks to NWB +"""Converting movement dataset to NWB or loading from NWB to movement dataset. ============================ Export pose tracks to NWB """ +# %% Load the sample data import datetime -from pynwb import NWBFile +from pynwb import NWBHDF5IO, NWBFile from movement import sample_data -from movement.io.nwb import convert_movement_to_nwb +from movement.io.nwb import ( + add_movement_dataset_to_nwb, + convert_nwb_to_movement, +) -# Load the sample data -ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv") +ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") -# The dataset has two individuals -# we will create two NWBFiles for each individual +# %%The dataset has two individuals. +# We will create two NWBFiles for each individual session_start_time = datetime.datetime.now(datetime.timezone.utc) nwbfile_individual1 = NWBFile( @@ -31,7 +34,21 @@ nwbfiles = [nwbfile_individual1, nwbfile_individual2] -# Convert the dataset to NWB -# This will create PoseEstimation and Skeleton objects for each individual -# and add them to the NWBFile -convert_movement_to_nwb(nwbfiles, ds) +# %% Convert the dataset to NWB +# This will create PoseEstimation and Skeleton objects for each +# individual and add them to the NWBFile +add_movement_dataset_to_nwb(nwbfiles, ds) + +# %% Save the NWBFiles +for file in nwbfiles: + with NWBHDF5IO(f"{file.identifier}.nwb", "w") as io: + io.write(file) + +# %% Convert the NWBFiles back to a movement dataset +# This will create a movement dataset with the same data as +# the original dataset from the NWBFiles + +# Convert the NWBFiles to a movement dataset +ds_from_nwb = convert_nwb_to_movement( + nwb_filepaths=["individual1.nwb", "individual2.nwb"] +) From da43e87e80bbef10f93b0cacb86a37c9b7d21bce Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 11:29:58 -0700 Subject: [PATCH 39/44] Add file validation and handling for single path --- movement/io/nwb.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8def398e..8f39d0ba 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -7,6 +7,7 @@ import pynwb import xarray as xr +from movement.io.save_poses import _validate_file_path from movement.logging import log_error @@ -238,13 +239,13 @@ def _convert_pose_estimation_series( def convert_nwb_to_movement( - nwb_filepaths: list[str] | list[Path], + nwb_filepaths: str | list[str] | list[Path], ) -> xr.Dataset: """Convert a list of NWB files to a single ``movement`` dataset. Parameters ---------- - nwb_filepaths : Union[list[str], list[Path]] + nwb_filepaths : str | Path | list[str] | list[Path] List of paths to NWB files to be converted. Returns @@ -253,8 +254,12 @@ def convert_nwb_to_movement( ``movement`` dataset containing the pose estimation data. """ + if isinstance(nwb_filepaths, str | Path): + nwb_filepaths = [nwb_filepaths] + datasets = [] for path in nwb_filepaths: + _validate_file_path(path, expected_suffix=[".nwb"]) with pynwb.NWBHDF5IO(path, mode="r") as io: nwbfile = io.read() pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"] From 9d34939eabeebc75ec6dcd222f27058713272d73 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 8 Jun 2024 11:46:01 -0700 Subject: [PATCH 40/44] Add preliminary tests --- tests/test_unit/test_nwb.py | 246 ++++++++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 tests/test_unit/test_nwb.py diff --git a/tests/test_unit/test_nwb.py b/tests/test_unit/test_nwb.py new file mode 100644 index 00000000..7319ea48 --- /dev/null +++ b/tests/test_unit/test_nwb.py @@ -0,0 +1,246 @@ +import ndx_pose +import numpy as np +import pynwb +import pytest +import xarray as xr + +from movement import sample_data +from movement.io.nwb import ( + _convert_pose_estimation_series, + _create_pose_and_skeleton_objects, + add_movement_dataset_to_nwb, + convert_nwb_to_movement, +) + + +def test_create_pose_and_skeleton_objects(): + # Create a sample dataset + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + + # Call the function + pose_estimation, skeletons = _create_pose_and_skeleton_objects( + ds, + subject="subject1", + pose_estimation_series_kwargs=None, + pose_estimation_kwargs=None, + skeleton_kwargs=None, + ) + + # Assert the output types + assert isinstance(pose_estimation, list) + assert isinstance(skeletons, ndx_pose.Skeletons) + + # Assert the length of pose_estimation list + assert len(pose_estimation) == 1 + + # Assert the length of pose_estimation_series list + assert len(pose_estimation[0].pose_estimation_series) == 2 + + # Assert the name of the first PoseEstimationSeries + assert pose_estimation[0].pose_estimation_series[0].name == "keypoint1" + + # Assert the name of the second PoseEstimationSeries + assert pose_estimation[0].pose_estimation_series[1].name == "keypoint2" + + # Assert the name of the Skeleton + assert skeletons.skeletons[0].name == "subject1_skeleton" + + +def test__convert_pose_estimation_series(): + # Create a sample PoseEstimationSeries object + pose_estimation_series = ndx_pose.PoseEstimationSeries( + name="keypoint1", + data=np.random.rand(10, 3), + confidence=np.random.rand(10), + unit="pixels", + timestamps=np.arange(10), + ) + + # Call the function + movement_dataset = _convert_pose_estimation_series( + pose_estimation_series, + keypoint="keypoint1", + subject_name="subject1", + source_software="software1", + source_file="file1", + ) + + # Assert the dimensions of the movement dataset + assert movement_dataset.dims == { + "time": 10, + "individuals": 1, + "keypoints": 1, + "space": 3, + } + + # Assert the values of the position variable + np.testing.assert_array_equal( + movement_dataset["position"].values, + pose_estimation_series.data[:, np.newaxis, np.newaxis, :], + ) + + # Assert the values of the confidence variable + np.testing.assert_array_equal( + movement_dataset["confidence"].values, + pose_estimation_series.confidence[:, np.newaxis, np.newaxis], + ) + + # Assert the attributes of the movement dataset + assert movement_dataset.attrs == { + "fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)), + "time_units": pose_estimation_series.timestamps_unit, + "source_software": "software1", + "source_file": "file1", + } + + +def test_add_movement_dataset_to_nwb_single_file(): + # Create a sample NWBFile + nwbfile = pynwb.NWBFile( + "session_description", "identifier", "session_start_time" + ) + # Create a sample movement dataset + movement_dataset = xr.Dataset( + { + "keypoints": (["keypoints"], ["keypoint1", "keypoint2"]), + "position": (["time", "keypoints"], [[1, 2], [3, 4]]), + "confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]), + "time": [0, 1], + "individuals": ["subject1"], + } + ) + # Call the function + add_movement_dataset_to_nwb(nwbfile, movement_dataset) + # Assert the presence of PoseEstimation and Skeletons in the NWBFile + assert "PoseEstimation" in nwbfile.processing["behavior"] + assert "Skeletons" in nwbfile.processing["behavior"] + + +def test_add_movement_dataset_to_nwb_multiple_files(): + # Create sample NWBFiles + nwbfiles = [ + pynwb.NWBFile( + "session_description1", "identifier1", "session_start_time1" + ), + pynwb.NWBFile( + "session_description2", "identifier2", "session_start_time2" + ), + ] + # Create a sample movement dataset + movement_dataset = xr.Dataset( + { + "keypoints": (["keypoints"], ["keypoint1", "keypoint2"]), + "position": (["time", "keypoints"], [[1, 2], [3, 4]]), + "confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]), + "time": [0, 1], + "individuals": ["subject1", "subject2"], + } + ) + # Call the function + add_movement_dataset_to_nwb(nwbfiles, movement_dataset) + # Assert the presence of PoseEstimation and Skeletons in each NWBFile + for nwbfile in nwbfiles: + assert "PoseEstimation" in nwbfile.processing["behavior"] + assert "Skeletons" in nwbfile.processing["behavior"] + + +def test_convert_nwb_to_movement(): + # Create sample NWB files + nwb_filepaths = [ + "/path/to/file1.nwb", + "/path/to/file2.nwb", + "/path/to/file3.nwb", + ] + pose_estimation_series = { + "keypoint1": ndx_pose.PoseEstimationSeries( + name="keypoint1", + data=np.random.rand(10, 3), + confidence=np.random.rand(10), + unit="pixels", + timestamps=np.arange(10), + ), + "keypoint2": ndx_pose.PoseEstimationSeries( + name="keypoint2", + data=np.random.rand(10, 3), + confidence=np.random.rand(10), + unit="pixels", + timestamps=np.arange(10), + ), + } + + # Mock the NWBHDF5IO read method + def mock_read(filepath): + nwbfile = pynwb.NWBFile( + "session_description", "identifier", "session_start_time" + ) + + pose_estimation = ndx_pose.PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description="Pose estimation data", + source_software="software1", + skeleton=ndx_pose.Skeleton( + name="skeleton1", nodes=["node1", "node2"] + ), + ) + behavior_pm = pynwb.ProcessingModule( + name="behavior", description="Behavior data" + ) + behavior_pm.add(pose_estimation) + nwbfile.add_processing_module(behavior_pm) + return nwbfile + + # Patch the NWBHDF5IO read method with the mock + with pytest.patch("pynwb.NWBHDF5IO.read", side_effect=mock_read): + # Call the function + movement_dataset = convert_nwb_to_movement(nwb_filepaths) + + # Assert the dimensions of the movement dataset + assert movement_dataset.dims == { + "time": 10, + "individuals": 3, + "keypoints": 2, + "space": 3, + } + + # Assert the values of the position variable + np.testing.assert_array_equal( + movement_dataset["position"].values, + np.concatenate( + [ + pose_estimation_series["keypoint1"].data[ + :, np.newaxis, np.newaxis, : + ], + pose_estimation_series["keypoint2"].data[ + :, np.newaxis, np.newaxis, : + ], + ], + axis=1, + ), + ) + + # Assert the values of the confidence variable + np.testing.assert_array_equal( + movement_dataset["confidence"].values, + np.concatenate( + [ + pose_estimation_series["keypoint1"].confidence[ + :, np.newaxis, np.newaxis + ], + pose_estimation_series["keypoint2"].confidence[ + :, np.newaxis, np.newaxis + ], + ], + axis=1, + ), + ) + + # Assert the attributes of the movement dataset + assert movement_dataset.attrs == { + "fps": np.nanmedian( + 1 / np.diff(pose_estimation_series["keypoint1"].timestamps) + ), + "time_units": pose_estimation_series["keypoint1"].timestamps_unit, + "source_software": "software1", + "source_file": None, + } From 56a66729640087d8510a59732053eddc8abe8cc5 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sun, 9 Jun 2024 10:54:36 -0700 Subject: [PATCH 41/44] Convert to numpy array --- movement/io/nwb.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 8f39d0ba..edb25618 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -7,7 +7,6 @@ import pynwb import xarray as xr -from movement.io.save_poses import _validate_file_path from movement.logging import log_error @@ -221,11 +220,15 @@ def _convert_pose_estimation_series( data_vars={ "position": ( ["time", "individuals", "keypoints", "space"], - pose_estimation_series.data[:, np.newaxis, np.newaxis, :], + np.asarray(pose_estimation_series.data)[ + :, np.newaxis, np.newaxis, : + ], ), "confidence": ( ["time", "individuals", "keypoints"], - pose_estimation_series.confidence[:, np.newaxis, np.newaxis], + np.asarray(pose_estimation_series.confidence)[ + :, np.newaxis, np.newaxis + ], ), }, coords={ @@ -259,7 +262,6 @@ def convert_nwb_to_movement( datasets = [] for path in nwb_filepaths: - _validate_file_path(path, expected_suffix=[".nwb"]) with pynwb.NWBHDF5IO(path, mode="r") as io: nwbfile = io.read() pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"] From b37b2c6ce29746f7ac92b8996408b5f9669c4ce5 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sun, 9 Jun 2024 11:00:17 -0700 Subject: [PATCH 42/44] Handle lack of confidence --- movement/io/nwb.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/movement/io/nwb.py b/movement/io/nwb.py index edb25618..5a728490 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -216,19 +216,28 @@ def _convert_pose_estimation_series( n_space_dims = pose_estimation_series.data.shape[1] space_dims = ["x", "y", "z"] + position_array = np.asarray(pose_estimation_series.data)[ + :, np.newaxis, np.newaxis, : + ] + + if getattr(pose_estimation_series, "confidence", None) is None: + pose_estimation_series.confidence = np.full( + pose_estimation_series.data.shape[0], np.nan + ) + else: + confidence_array = np.asarray(pose_estimation_series.confidence)[ + :, np.newaxis, np.newaxis + ] + return xr.Dataset( data_vars={ "position": ( ["time", "individuals", "keypoints", "space"], - np.asarray(pose_estimation_series.data)[ - :, np.newaxis, np.newaxis, : - ], + position_array, ), "confidence": ( ["time", "individuals", "keypoints"], - np.asarray(pose_estimation_series.confidence)[ - :, np.newaxis, np.newaxis - ], + confidence_array, ), }, coords={ From f7d48ce933a657a6cc2dd4dcbc85a0530d54de14 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sun, 9 Jun 2024 11:00:59 -0700 Subject: [PATCH 43/44] Display xarray --- examples/nwb_conversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/nwb_conversion.py b/examples/nwb_conversion.py index 86ad040a..f0001b32 100644 --- a/examples/nwb_conversion.py +++ b/examples/nwb_conversion.py @@ -52,3 +52,4 @@ ds_from_nwb = convert_nwb_to_movement( nwb_filepaths=["individual1.nwb", "individual2.nwb"] ) +ds_from_nwb From 0606add21a1f6e9f4abb72c68a01bd108c1a71cc Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sun, 9 Jun 2024 11:56:26 -0700 Subject: [PATCH 44/44] Refactor tests --- tests/test_unit/test_nwb.py | 315 ++++++++++++++++++------------------ 1 file changed, 161 insertions(+), 154 deletions(-) diff --git a/tests/test_unit/test_nwb.py b/tests/test_unit/test_nwb.py index 7319ea48..060c2e99 100644 --- a/tests/test_unit/test_nwb.py +++ b/tests/test_unit/test_nwb.py @@ -1,8 +1,10 @@ +import datetime + import ndx_pose import numpy as np -import pynwb -import pytest -import xarray as xr +from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons +from pynwb import NWBHDF5IO, NWBFile +from pynwb.file import Subject from movement import sample_data from movement.io.nwb import ( @@ -19,8 +21,8 @@ def test_create_pose_and_skeleton_objects(): # Call the function pose_estimation, skeletons = _create_pose_and_skeleton_objects( - ds, - subject="subject1", + ds.sel(individuals="individual1"), + subject="individual1", pose_estimation_series_kwargs=None, pose_estimation_kwargs=None, skeleton_kwargs=None, @@ -34,43 +36,59 @@ def test_create_pose_and_skeleton_objects(): assert len(pose_estimation) == 1 # Assert the length of pose_estimation_series list - assert len(pose_estimation[0].pose_estimation_series) == 2 + assert len(pose_estimation[0].pose_estimation_series) == 12 # Assert the name of the first PoseEstimationSeries - assert pose_estimation[0].pose_estimation_series[0].name == "keypoint1" - - # Assert the name of the second PoseEstimationSeries - assert pose_estimation[0].pose_estimation_series[1].name == "keypoint2" + assert "snout" in pose_estimation[0].pose_estimation_series # Assert the name of the Skeleton - assert skeletons.skeletons[0].name == "subject1_skeleton" + assert "individual1_skeleton" in skeletons.skeletons + + +def create_test_pose_estimation_series( + n_time=100, n_dims=2, keypoint="front_left_paw" +): + data = np.random.rand( + n_time, n_dims + ) # num_frames x (x, y) but can be (x, y, z) + timestamps = np.linspace(0, 10, num=n_time) # a timestamp for every frame + confidence = np.ones((n_time,)) # a confidence value for every frame + reference_frame = "(0,0,0) corresponds to ..." + confidence_definition = "Softmax output of the deep neural network." + + return PoseEstimationSeries( + name=keypoint, + description="Marker placed around fingers of front left paw.", + data=data, + unit="pixels", + reference_frame=reference_frame, + timestamps=timestamps, + confidence=confidence, + confidence_definition=confidence_definition, + ) def test__convert_pose_estimation_series(): # Create a sample PoseEstimationSeries object - pose_estimation_series = ndx_pose.PoseEstimationSeries( - name="keypoint1", - data=np.random.rand(10, 3), - confidence=np.random.rand(10), - unit="pixels", - timestamps=np.arange(10), + pose_estimation_series = create_test_pose_estimation_series( + n_time=100, n_dims=2, keypoint="front_left_paw" ) # Call the function movement_dataset = _convert_pose_estimation_series( pose_estimation_series, - keypoint="keypoint1", - subject_name="subject1", + keypoint="leftear", + subject_name="individual1", source_software="software1", source_file="file1", ) # Assert the dimensions of the movement dataset - assert movement_dataset.dims == { - "time": 10, + assert movement_dataset.sizes == { + "time": 100, "individuals": 1, "keypoints": 1, - "space": 3, + "space": 2, } # Assert the values of the position variable @@ -92,155 +110,144 @@ def test__convert_pose_estimation_series(): "source_software": "software1", "source_file": "file1", } + pose_estimation_series = create_test_pose_estimation_series( + n_time=50, n_dims=3, keypoint="front_left_paw" + ) + + # Assert the dimensions of the movement dataset + assert movement_dataset.sizes == { + "time": 50, + "individuals": 1, + "keypoints": 1, + "space": 3, + } def test_add_movement_dataset_to_nwb_single_file(): - # Create a sample NWBFile - nwbfile = pynwb.NWBFile( - "session_description", "identifier", "session_start_time" - ) - # Create a sample movement dataset - movement_dataset = xr.Dataset( - { - "keypoints": (["keypoints"], ["keypoint1", "keypoint2"]), - "position": (["time", "keypoints"], [[1, 2], [3, 4]]), - "confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]), - "time": [0, 1], - "individuals": ["subject1"], - } + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + session_start_time = datetime.datetime.now(datetime.timezone.utc) + nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=session_start_time, + ) + add_movement_dataset_to_nwb( + nwbfile_individual1, ds.sel(individuals=["individual1"]) + ) + assert ( + "PoseEstimation" + in nwbfile_individual1.processing["behavior"].data_interfaces + ) + assert ( + "Skeletons" + in nwbfile_individual1.processing["behavior"].data_interfaces ) - # Call the function - add_movement_dataset_to_nwb(nwbfile, movement_dataset) - # Assert the presence of PoseEstimation and Skeletons in the NWBFile - assert "PoseEstimation" in nwbfile.processing["behavior"] - assert "Skeletons" in nwbfile.processing["behavior"] def test_add_movement_dataset_to_nwb_multiple_files(): - # Create sample NWBFiles - nwbfiles = [ - pynwb.NWBFile( - "session_description1", "identifier1", "session_start_time1" - ), - pynwb.NWBFile( - "session_description2", "identifier2", "session_start_time2" - ), - ] - # Create a sample movement dataset - movement_dataset = xr.Dataset( - { - "keypoints": (["keypoints"], ["keypoint1", "keypoint2"]), - "position": (["time", "keypoints"], [[1, 2], [3, 4]]), - "confidence": (["time", "keypoints"], [[0.9, 0.8], [0.7, 0.6]]), - "time": [0, 1], - "individuals": ["subject1", "subject2"], - } + ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv") + session_start_time = datetime.datetime.now(datetime.timezone.utc) + nwbfile_individual1 = NWBFile( + session_description="session_description", + identifier="individual1", + session_start_time=session_start_time, + ) + nwbfile_individual2 = NWBFile( + session_description="session_description", + identifier="individual2", + session_start_time=session_start_time, ) - # Call the function - add_movement_dataset_to_nwb(nwbfiles, movement_dataset) - # Assert the presence of PoseEstimation and Skeletons in each NWBFile - for nwbfile in nwbfiles: - assert "PoseEstimation" in nwbfile.processing["behavior"] - assert "Skeletons" in nwbfile.processing["behavior"] + nwbfiles = [nwbfile_individual1, nwbfile_individual2] + add_movement_dataset_to_nwb(nwbfiles, ds) -def test_convert_nwb_to_movement(): - # Create sample NWB files - nwb_filepaths = [ - "/path/to/file1.nwb", - "/path/to/file2.nwb", - "/path/to/file3.nwb", - ] - pose_estimation_series = { - "keypoint1": ndx_pose.PoseEstimationSeries( - name="keypoint1", - data=np.random.rand(10, 3), - confidence=np.random.rand(10), - unit="pixels", - timestamps=np.arange(10), - ), - "keypoint2": ndx_pose.PoseEstimationSeries( - name="keypoint2", - data=np.random.rand(10, 3), - confidence=np.random.rand(10), - unit="pixels", - timestamps=np.arange(10), - ), - } - # Mock the NWBHDF5IO read method - def mock_read(filepath): - nwbfile = pynwb.NWBFile( - "session_description", "identifier", "session_start_time" - ) - - pose_estimation = ndx_pose.PoseEstimation( - name="PoseEstimation", - pose_estimation_series=pose_estimation_series, - description="Pose estimation data", - source_software="software1", - skeleton=ndx_pose.Skeleton( - name="skeleton1", nodes=["node1", "node2"] - ), - ) - behavior_pm = pynwb.ProcessingModule( - name="behavior", description="Behavior data" - ) - behavior_pm.add(pose_estimation) - nwbfile.add_processing_module(behavior_pm) - return nwbfile +def create_test_pose_nwb(identifier="subject1", write_to_disk=False): + # initialize an NWBFile object + nwbfile = NWBFile( + session_description="session_description", + identifier=identifier, + session_start_time=datetime.datetime.now(datetime.timezone.utc), + ) - # Patch the NWBHDF5IO read method with the mock - with pytest.patch("pynwb.NWBHDF5IO.read", side_effect=mock_read): - # Call the function - movement_dataset = convert_nwb_to_movement(nwb_filepaths) + # add a subject to the NWB file + subject = Subject(subject_id=identifier, species="Mus musculus") + nwbfile.subject = subject - # Assert the dimensions of the movement dataset - assert movement_dataset.dims == { - "time": 10, - "individuals": 3, - "keypoints": 2, - "space": 3, - } + skeleton = Skeleton( + name="subject1_skeleton", + nodes=["front_left_paw", "body", "front_right_paw"], + edges=np.array([[0, 1], [1, 2]], dtype="uint8"), + subject=subject, + ) - # Assert the values of the position variable - np.testing.assert_array_equal( - movement_dataset["position"].values, - np.concatenate( - [ - pose_estimation_series["keypoint1"].data[ - :, np.newaxis, np.newaxis, : - ], - pose_estimation_series["keypoint2"].data[ - :, np.newaxis, np.newaxis, : - ], - ], - axis=1, - ), + skeletons = Skeletons(skeletons=[skeleton]) + + # create a device for the camera + camera1 = nwbfile.create_device( + name="camera1", + description="camera for recording behavior", + manufacturer="my manufacturer", ) - # Assert the values of the confidence variable - np.testing.assert_array_equal( - movement_dataset["confidence"].values, - np.concatenate( - [ - pose_estimation_series["keypoint1"].confidence[ - :, np.newaxis, np.newaxis - ], - pose_estimation_series["keypoint2"].confidence[ - :, np.newaxis, np.newaxis - ], - ], - axis=1, - ), + n_time = 100 + n_dims = 2 # 2D data + front_left_paw = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="front_left_paw" ) - # Assert the attributes of the movement dataset - assert movement_dataset.attrs == { - "fps": np.nanmedian( - 1 / np.diff(pose_estimation_series["keypoint1"].timestamps) + body = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="body" + ) + front_right_paw = create_test_pose_estimation_series( + n_time=n_time, n_dims=n_dims, keypoint="front_right_paw" + ) + + # store all PoseEstimationSeries in a list + pose_estimation_series = [front_left_paw, body, front_right_paw] + + pose_estimation = PoseEstimation( + name="PoseEstimation", + pose_estimation_series=pose_estimation_series, + description=( + "Estimated positions of front paws" "of subject1 using DeepLabCut." ), - "time_units": pose_estimation_series["keypoint1"].timestamps_unit, - "source_software": "software1", - "source_file": None, + original_videos=["path/to/camera1.mp4"], + labeled_videos=["path/to/camera1_labeled.mp4"], + dimensions=np.array( + [[640, 480]], dtype="uint16" + ), # pixel dimensions of the video + devices=[camera1], + scorer="DLC_resnet50_openfieldOct30shuffle1_1600", + source_software="DeepLabCut", + source_software_version="2.3.8", + skeleton=skeleton, # link to the skeleton object + ) + + behavior_pm = nwbfile.create_processing_module( + name="behavior", + description="processed behavioral data", + ) + behavior_pm.add(skeletons) + behavior_pm.add(pose_estimation) + + # write the NWBFile to disk + if write_to_disk: + path = "test_pose.nwb" + with NWBHDF5IO(path, mode="w") as io: + io.write(nwbfile) + else: + return nwbfile + + +def test_convert_nwb_to_movement(): + create_test_pose_nwb(write_to_disk=True) + nwb_filepaths = ["test_pose.nwb"] + movement_dataset = convert_nwb_to_movement(nwb_filepaths) + + assert movement_dataset.sizes == { + "time": 100, + "individuals": 1, + "keypoints": 3, + "space": 2, }