diff --git a/movement/io/nwb.py b/movement/io/nwb.py index 1a559ba8..8a01632a 100644 --- a/movement/io/nwb.py +++ b/movement/io/nwb.py @@ -45,7 +45,6 @@ def _merge_kwargs(defaults, overrides): def _create_pose_and_skeleton_objects( ds: xr.Dataset, - subject: str, pose_estimation_series_kwargs: dict | None = None, pose_estimation_kwargs: dict | None = None, skeleton_kwargs: dict | None = None, @@ -55,15 +54,14 @@ def _create_pose_and_skeleton_objects( Parameters ---------- ds : xarray.Dataset - movement dataset containing the data to be converted to NWB. - subject : str - Name of the subject (individual) to be converted. + A single-individual ``movement`` poses dataset. pose_estimation_series_kwargs : dict, optional - PoseEstimationSeries keyword arguments. See ndx_pose, by default None + PoseEstimationSeries keyword arguments. + See ``ndx_pose``, by default None pose_estimation_kwargs : dict, optional - PoseEstimation keyword arguments. See ndx_pose, by default None + PoseEstimation keyword arguments. See ``ndx_pose``, by default None skeleton_kwargs : dict, optional - Skeleton keyword arguments. See ndx_pose, by default None + Skeleton keyword arguments. See ``ndx_pose``, by default None Returns ------- @@ -82,8 +80,11 @@ def _create_pose_and_skeleton_objects( ) skeleton_kwargs = _merge_kwargs(SKELETON_KWARGS, skeleton_kwargs) - pose_estimation_series = [] + # Extract individual name + individual = ds.individuals.values.item() + # Create a PoseEstimationSeries object for each keypoint + pose_estimation_series = [] for keypoint in ds.keypoints.to_numpy(): pose_estimation_series.append( ndx_pose.PoseEstimationSeries( @@ -95,10 +96,10 @@ def _create_pose_and_skeleton_objects( **pose_estimation_series_kwargs, ) ) - + # Create a Skeleton object for the chosen individual skeleton_list = [ ndx_pose.Skeleton( - name=f"{subject}_skeleton", + name=f"{individual}_skeleton", nodes=ds.keypoints.to_numpy().tolist(), **skeleton_kwargs, ) @@ -107,7 +108,7 @@ def _create_pose_and_skeleton_objects( bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist()) description = ( f"Estimated positions of {bodyparts_str} of" - f"{subject} using {ds.source_software}." + f"{individual} using {ds.source_software}." ) pose_estimation = [ @@ -178,12 +179,11 @@ def ds_to_nwb( "individuals, as NWB requires one file per individual (subject).", ) - for nwb_file, subject in zip( + for nwb_file, individual in zip( nwb_files, movement_dataset.individuals.values, strict=False ): pose_estimation, skeletons = _create_pose_and_skeleton_objects( - movement_dataset.sel(individuals=subject), - subject, + movement_dataset.sel(individuals=individual), pose_estimation_series_kwargs, pose_estimation_kwargs, skeletons_kwargs, diff --git a/tests/test_unit/test_nwb.py b/tests/test_unit/test_nwb.py index 9e7e72cd..4d95de42 100644 --- a/tests/test_unit/test_nwb.py +++ b/tests/test_unit/test_nwb.py @@ -23,7 +23,6 @@ def test_create_pose_and_skeleton_objects(): # Call the function pose_estimation, skeletons = _create_pose_and_skeleton_objects( ds.sel(individuals="individual1"), - subject="individual1", pose_estimation_series_kwargs=None, pose_estimation_kwargs=None, skeleton_kwargs=None,