Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I/O support for the ndx-pose NWB extension #166

Closed
wants to merge 48 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
bd12d04
Create nwb_export.py
edeno Apr 18, 2024
c5319b9
NWB requires one file per individual
edeno Apr 18, 2024
d82fe30
Add script
edeno Apr 19, 2024
d889105
Remove import error handling
edeno Apr 19, 2024
72aea47
Add nwb optional dependencies
edeno Apr 19, 2024
12bf83f
Fix linting based on pre-commit hooks
edeno Apr 19, 2024
742bf86
Add example docstring
edeno Apr 19, 2024
a06d485
Rename to fit module naming pattern
edeno Apr 19, 2024
739c4d8
Add import from nwb
edeno Apr 19, 2024
aef9b0c
Merge branch 'main' into main
edeno Apr 22, 2024
ce28f90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
b9599a9
Merge remote-tracking branch 'upstream/main'
edeno Jun 8, 2024
2491cf6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2024
58bef93
Apply suggestions from code review
edeno Jun 8, 2024
3ab9aa4
Update make pynwb and ndx-pose core dependencies
edeno Jun 8, 2024
910ce90
Cleanup of docstrings and variable names from code review
edeno Jun 8, 2024
3f9a53b
Rename function for clarity
edeno Jun 8, 2024
3cd991d
Update with example converting back to movement
edeno Jun 8, 2024
3aa1b11
Add file validation and handling for single path
edeno Jun 8, 2024
e56cf6d
Add preliminary tests
edeno Jun 8, 2024
99a90c1
Convert to numpy array
edeno Jun 9, 2024
02b9975
Handle lack of confidence
edeno Jun 9, 2024
a2ac053
Display xarray
edeno Jun 9, 2024
84a495d
Refactor tests
edeno Jun 9, 2024
90c3287
Merge remote-tracking branch 'upstream/main'
edeno Jun 10, 2024
e9e1cef
Create nwb_export.py
edeno Apr 18, 2024
3ccd71c
NWB requires one file per individual
edeno Apr 18, 2024
f906cd5
Add script
edeno Apr 19, 2024
d35d9c2
Remove import error handling
edeno Apr 19, 2024
e5726d4
Add nwb optional dependencies
edeno Apr 19, 2024
53f505b
Fix linting based on pre-commit hooks
edeno Apr 19, 2024
f1d480d
Add example docstring
edeno Apr 19, 2024
4b162cf
Rename to fit module naming pattern
edeno Apr 19, 2024
4b887be
Add import from nwb
edeno Apr 19, 2024
96ee7ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
2f2625d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2024
1c7c2e3
Apply suggestions from code review
edeno Jun 8, 2024
4191ae8
Update make pynwb and ndx-pose core dependencies
edeno Jun 8, 2024
4202ff6
Cleanup of docstrings and variable names from code review
edeno Jun 8, 2024
3188b0b
Rename function for clarity
edeno Jun 8, 2024
4908040
Update with example converting back to movement
edeno Jun 8, 2024
da43e87
Add file validation and handling for single path
edeno Jun 8, 2024
9d34939
Add preliminary tests
edeno Jun 8, 2024
56a6672
Convert to numpy array
edeno Jun 9, 2024
b37b2c6
Handle lack of confidence
edeno Jun 9, 2024
f7d48ce
Display xarray
edeno Jun 9, 2024
0606add
Refactor tests
edeno Jun 9, 2024
dbf804a
Merge branch 'main' of https://github.com/edeno/movement
edeno Oct 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions examples/nwb_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Converting movement dataset to NWB or loading from NWB to movement dataset.
============================

Export pose tracks to NWB
"""

# %% Load the sample data
import datetime

from pynwb import NWBHDF5IO, NWBFile

from movement import sample_data
from movement.io.nwb import (
add_movement_dataset_to_nwb,
convert_nwb_to_movement,
)

ds = sample_data.fetch_dataset("DLC_two-mice.predictions.csv")

# %%The dataset has two individuals.
# We will create two NWBFiles for each individual

session_start_time = datetime.datetime.now(datetime.timezone.utc)
nwbfile_individual1 = NWBFile(
session_description="session_description",
identifier="individual1",
session_start_time=session_start_time,
)
nwbfile_individual2 = NWBFile(
session_description="session_description",
identifier="individual2",
session_start_time=session_start_time,
)

nwbfiles = [nwbfile_individual1, nwbfile_individual2]

# %% Convert the dataset to NWB
# This will create PoseEstimation and Skeleton objects for each
# individual and add them to the NWBFile
add_movement_dataset_to_nwb(nwbfiles, ds)

# %% Save the NWBFiles
for file in nwbfiles:
with NWBHDF5IO(f"{file.identifier}.nwb", "w") as io:
io.write(file)

# %% Convert the NWBFiles back to a movement dataset
# This will create a movement dataset with the same data as
# the original dataset from the NWBFiles

# Convert the NWBFiles to a movement dataset
ds_from_nwb = convert_nwb_to_movement(
nwb_filepaths=["individual1.nwb", "individual2.nwb"]
)
ds_from_nwb
293 changes: 293 additions & 0 deletions movement/io/nwb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
"""Functions to convert movement data to and from NWB format."""

from pathlib import Path

import ndx_pose
import numpy as np
import pynwb
import xarray as xr

from movement.logging import log_error


def _create_pose_and_skeleton_objects(
ds: xr.Dataset,
subject: str,
pose_estimation_series_kwargs: dict | None = None,
pose_estimation_kwargs: dict | None = None,
skeleton_kwargs: dict | None = None,
) -> tuple[list[ndx_pose.PoseEstimation], ndx_pose.Skeletons]:
"""Create PoseEstimation and Skeletons objects from a ``movement`` dataset.

Parameters
----------
ds : xarray.Dataset
movement dataset containing the data to be converted to NWB.
subject : str
Name of the subject (individual) to be converted.
pose_estimation_series_kwargs : dict, optional
PoseEstimationSeries keyword arguments. See ndx_pose, by default None
pose_estimation_kwargs : dict, optional
PoseEstimation keyword arguments. See ndx_pose, by default None
skeleton_kwargs : dict, optional
Skeleton keyword arguments. See ndx_pose, by default None

Returns
-------
pose_estimation : list[ndx_pose.PoseEstimation]
List of PoseEstimation objects
skeletons : ndx_pose.Skeletons
Skeletons object containing all skeletons

"""
if pose_estimation_series_kwargs is None:
pose_estimation_series_kwargs = dict(
reference_frame="(0,0,0) corresponds to ...",
confidence_definition=None,
conversion=1.0,
resolution=-1.0,
offset=0.0,
starting_time=None,
comments="no comments",
description="no description",
control=None,
control_description=None,
)

if skeleton_kwargs is None:
skeleton_kwargs = dict(edges=None)

if pose_estimation_kwargs is None:
pose_estimation_kwargs = dict(
original_videos=None,
labeled_videos=None,
dimensions=None,
devices=None,
scorer=None,
source_software_version=None,
)

pose_estimation_series = []

for keypoint in ds.keypoints.to_numpy():
pose_estimation_series.append(
ndx_pose.PoseEstimationSeries(
name=keypoint,
data=ds.sel(keypoints=keypoint).position.to_numpy(),
confidence=ds.sel(keypoints=keypoint).confidence.to_numpy(),
unit="pixels",
timestamps=ds.sel(keypoints=keypoint).time.to_numpy(),
**pose_estimation_series_kwargs,
)
)

skeleton_list = [
ndx_pose.Skeleton(
name=f"{subject}_skeleton",
nodes=ds.keypoints.to_numpy().tolist(),
**skeleton_kwargs,
)
]

bodyparts_str = ", ".join(ds.keypoints.to_numpy().tolist())
description = (
f"Estimated positions of {bodyparts_str} of"
f"{subject} using {ds.source_software}."
)

pose_estimation = [
ndx_pose.PoseEstimation(
name="PoseEstimation",
pose_estimation_series=pose_estimation_series,
description=description,
source_software=ds.source_software,
skeleton=skeleton_list[-1],
**pose_estimation_kwargs,
)
]

skeletons = ndx_pose.Skeletons(skeletons=skeleton_list)

return pose_estimation, skeletons


def add_movement_dataset_to_nwb(
nwbfiles: list[pynwb.NWBFile] | pynwb.NWBFile,
movement_dataset: xr.Dataset,
pose_estimation_series_kwargs: dict | None = None,
pose_estimation_kwargs: dict | None = None,
skeletons_kwargs: dict | None = None,
) -> None:
"""Add pose estimation data to NWB files for each individual.

Parameters
----------
nwbfiles : list[pynwb.NWBFile] | pynwb.NWBFile
NWBFile object(s) to which the data will be added.
movement_dataset : xr.Dataset
``movement`` dataset containing the data to be converted to NWB.
pose_estimation_series_kwargs : dict, optional
PoseEstimationSeries keyword arguments. See ndx_pose, by default None
pose_estimation_kwargs : dict, optional
PoseEstimation keyword arguments. See ndx_pose, by default None
skeletons_kwargs : dict, optional
Skeleton keyword arguments. See ndx_pose, by default None

Raises
------
ValueError
If the number of NWBFiles is not equal to the number of individuals
in the dataset.

"""
if isinstance(nwbfiles, pynwb.NWBFile):
nwbfiles = [nwbfiles]

if len(nwbfiles) != len(movement_dataset.individuals):
raise log_error(
ValueError,
"Number of NWBFiles must be equal to the number of individuals. "
"NWB requires one file per individual.",
)

for nwbfile, subject in zip(
nwbfiles, movement_dataset.individuals.to_numpy(), strict=False
):
pose_estimation, skeletons = _create_pose_and_skeleton_objects(
movement_dataset.sel(individuals=subject),
subject,
pose_estimation_series_kwargs,
pose_estimation_kwargs,
skeletons_kwargs,
)
try:
behavior_pm = nwbfile.create_processing_module(
name="behavior",
description="processed behavioral data",
)
except ValueError:
print("Behavior processing module already exists. Skipping...")
behavior_pm = nwbfile.processing["behavior"]

try:
behavior_pm.add(skeletons)
except ValueError:
print("Skeletons already exists. Skipping...")
try:
behavior_pm.add(pose_estimation)
except ValueError:
print("PoseEstimation already exists. Skipping...")


def _convert_pose_estimation_series(
pose_estimation_series: ndx_pose.PoseEstimationSeries,
keypoint: str,
subject_name: str,
source_software: str,
source_file: str | None = None,
) -> xr.Dataset:
"""Convert to single-keypoint, single-individual ``movement`` dataset.

Parameters
----------
pose_estimation_series : ndx_pose.PoseEstimationSeries
PoseEstimationSeries NWB object to be converted.
keypoint : str
Name of the keypoint - body part.
subject_name : str
Name of the subject (individual).
source_software : str
Name of the software used to estimate the pose.
source_file : Optional[str], optional
File from which the data was extracted, by default None

Returns
-------
movement_dataset : xr.Dataset
``movement`` compatible dataset containing the pose estimation data.

"""
attrs = {
"fps": np.nanmedian(1 / np.diff(pose_estimation_series.timestamps)),
"time_units": pose_estimation_series.timestamps_unit,
"source_software": source_software,
"source_file": source_file,
}
n_space_dims = pose_estimation_series.data.shape[1]
space_dims = ["x", "y", "z"]

position_array = np.asarray(pose_estimation_series.data)[
:, np.newaxis, np.newaxis, :
]

if getattr(pose_estimation_series, "confidence", None) is None:
pose_estimation_series.confidence = np.full(
pose_estimation_series.data.shape[0], np.nan
)
else:
confidence_array = np.asarray(pose_estimation_series.confidence)[
:, np.newaxis, np.newaxis
]

return xr.Dataset(
data_vars={
"position": (
["time", "individuals", "keypoints", "space"],
position_array,
),
"confidence": (
["time", "individuals", "keypoints"],
confidence_array,
),
},
coords={
"time": pose_estimation_series.timestamps,
"individuals": [subject_name],
"keypoints": [keypoint],
"space": space_dims[:n_space_dims],
},
attrs=attrs,
)


def convert_nwb_to_movement(
nwb_filepaths: str | list[str] | list[Path],
) -> xr.Dataset:
"""Convert a list of NWB files to a single ``movement`` dataset.

Parameters
----------
nwb_filepaths : str | Path | list[str] | list[Path]
List of paths to NWB files to be converted.

Returns
-------
movement_ds : xr.Dataset
``movement`` dataset containing the pose estimation data.

"""
if isinstance(nwb_filepaths, str | Path):
nwb_filepaths = [nwb_filepaths]

datasets = []
for path in nwb_filepaths:
with pynwb.NWBHDF5IO(path, mode="r") as io:
nwbfile = io.read()
pose_estimation = nwbfile.processing["behavior"]["PoseEstimation"]
source_software = pose_estimation.fields["source_software"]
pose_estimation_series = pose_estimation.fields[
"pose_estimation_series"
]

for keypoint, pes in pose_estimation_series.items():
datasets.append(
_convert_pose_estimation_series(
pes,
keypoint,
subject_name=nwbfile.identifier,
source_software=source_software,
source_file=None,
)
)

return xr.merge(datasets)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ dependencies = [
"sleap-io",
"xarray[accel,viz]",
"PyYAML",
"pynwb",
"ndx-pose>=0.2",
]

classifiers = [
Expand Down
Loading
Loading