Skip to content

Commit

Permalink
use individual instead of subject
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Dec 18, 2024
1 parent 51bb1af commit 179963d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 15 deletions.
28 changes: 14 additions & 14 deletions movement/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def _merge_kwargs(defaults, overrides):

def _create_pose_and_skeleton_objects(
ds: xr.Dataset,
subject: str,
pose_estimation_series_kwargs: dict | None = None,
pose_estimation_kwargs: dict | None = None,
skeleton_kwargs: dict | None = None,
Expand All @@ -55,15 +54,14 @@ def _create_pose_and_skeleton_objects(
Parameters
----------
ds : xarray.Dataset
movement dataset containing the data to be converted to NWB.
subject : str
Name of the subject (individual) to be converted.
A single-individual ``movement`` poses dataset.
pose_estimation_series_kwargs : dict, optional
PoseEstimationSeries keyword arguments. See ndx_pose, by default None
PoseEstimationSeries keyword arguments.
See ``ndx_pose``, by default None
pose_estimation_kwargs : dict, optional
PoseEstimation keyword arguments. See ndx_pose, by default None
PoseEstimation keyword arguments. See ``ndx_pose``, by default None
skeleton_kwargs : dict, optional
Skeleton keyword arguments. See ndx_pose, by default None
Skeleton keyword arguments. See ``ndx_pose``, by default None
Returns
-------
Expand All @@ -82,8 +80,11 @@ def _create_pose_and_skeleton_objects(
)
skeleton_kwargs = _merge_kwargs(SKELETON_KWARGS, skeleton_kwargs)

pose_estimation_series = []
# Extract individual name
individual = ds.individuals.values.item()

# Create a PoseEstimationSeries object for each keypoint
pose_estimation_series = []
for keypoint in ds.keypoints.to_numpy():
pose_estimation_series.append(
ndx_pose.PoseEstimationSeries(
Expand All @@ -95,10 +96,10 @@ def _create_pose_and_skeleton_objects(
**pose_estimation_series_kwargs,
)
)

# Create a Skeleton object for the chosen individual
skeleton_list = [
ndx_pose.Skeleton(
name=f"{subject}_skeleton",
name=f"{individual}_skeleton",
nodes=ds.keypoints.to_numpy().tolist(),
**skeleton_kwargs,
)
Expand All @@ -107,7 +108,7 @@ def _create_pose_and_skeleton_objects(
bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist())
description = (
f"Estimated positions of {bodyparts_str} of"
f"{subject} using {ds.source_software}."
f"{individual} using {ds.source_software}."
)

pose_estimation = [
Expand Down Expand Up @@ -178,12 +179,11 @@ def ds_to_nwb(
"individuals, as NWB requires one file per individual (subject).",
)

for nwb_file, subject in zip(
for nwb_file, individual in zip(
nwb_files, movement_dataset.individuals.values, strict=False
):
pose_estimation, skeletons = _create_pose_and_skeleton_objects(
movement_dataset.sel(individuals=subject),
subject,
movement_dataset.sel(individuals=individual),
pose_estimation_series_kwargs,
pose_estimation_kwargs,
skeletons_kwargs,
Expand Down
1 change: 0 additions & 1 deletion tests/test_unit/test_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def test_create_pose_and_skeleton_objects():
# Call the function
pose_estimation, skeletons = _create_pose_and_skeleton_objects(
ds.sel(individuals="individual1"),
subject="individual1",
pose_estimation_series_kwargs=None,
pose_estimation_kwargs=None,
skeleton_kwargs=None,
Expand Down

0 comments on commit 179963d

Please sign in to comment.