Skip to content

Commit

Permalink
Merge branch 'liezl/add-multiview-datastructures' into liezl/add-came…
Browse files Browse the repository at this point in the history
…ragroup-class
  • Loading branch information
roomrys authored Jan 21, 2025
2 parents 3e3b37a + b2984af commit 8233ee9
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 57 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
# Tests with pytest
tests:
timeout-minutes: 15
timeout-minutes: 25
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -92,7 +92,7 @@ jobs:
if: ${{ startsWith(matrix.os, 'ubuntu') }}
shell: bash -l {0}
run: |
sudo apt-get update && sudo apt-get install libglapi-mesa libegl-mesa0 libegl1 libopengl0 libgl1-mesa-glx
sudo apt-get update && sudo apt-get install libglapi-mesa libegl-mesa0 libegl1 libopengl0 libgl1 libglx-mesa0
- name: Test with pytest (with coverage)
shell: bash -l {0}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"attrs",
"h5py>=3.8.0",
"pynwb",
"ndx-pose<0.2.0",
"ndx-pose>=0.2.1",
"pandas",
"simplejson",
"imageio",
Expand Down
168 changes: 133 additions & 35 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Functions to write and read from the neurodata without borders (NWB) format."""

from copy import deepcopy
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict
from pathlib import Path
import datetime
import uuid
Expand All @@ -15,14 +15,14 @@
except ImportError:
ArrayLike = np.ndarray
from pynwb import NWBFile, NWBHDF5IO, ProcessingModule # type: ignore[import]
from ndx_pose import PoseEstimationSeries, PoseEstimation # type: ignore[import]
from ndx_pose import PoseEstimationSeries, PoseEstimation, Skeleton, Skeletons # type: ignore[import]

from sleap_io import (
Labels,
Video,
LabeledFrame,
Track,
Skeleton,
Skeleton as SleapSkeleton,
Instance,
PredictedInstance,
)
Expand Down Expand Up @@ -116,31 +116,40 @@ def read_nwb(path: str) -> Labels:
"""
with NWBHDF5IO(path, mode="r", load_namespaces=True) as io:
read_nwbfile = io.read()
nwb_file = read_nwbfile.processing
nwb_file_processing = read_nwbfile.processing

# Get list of videos
video_keys: List[str] = [key for key in nwb_file.keys() if "SLEAP_VIDEO" in key]
video_keys: List[str] = [
key for key in nwb_file_processing.keys() if "SLEAP_VIDEO" in key
]
video_tracks = dict()

# Get track keys
test_processing_module: ProcessingModule = nwb_file[video_keys[0]]
# Get track keys from first video's processing module
test_processing_module: ProcessingModule = nwb_file_processing[video_keys[0]]
track_keys: List[str] = list(test_processing_module.fields["data_interfaces"])

# Get track
# Get first track's skeleton
test_pose_estimation: PoseEstimation = test_processing_module[track_keys[0]]
node_names = test_pose_estimation.nodes[:]
edge_inds = test_pose_estimation.edges[:]
skeleton = test_pose_estimation.skeleton
skeleton_nodes = skeleton.nodes[:]
skeleton_edges = skeleton.edges[:]

for processing_module in nwb_file.values():
# Filtering out behavior module with skeletons
pose_estimation_container_modules = [
nwb_file_processing[key] for key in video_keys
]

for processing_module in pose_estimation_container_modules:
# Get track keys
_track_keys: List[str] = list(processing_module.fields["data_interfaces"])
is_tracked: bool = re.sub("[0-9]+", "", _track_keys[0]) == "track"

# Figure out the max number of frames and the canonical timestamps
timestamps = np.empty(())
for track_key in _track_keys:
for node_name in node_names:
pose_estimation_series = processing_module[track_key][node_name]
pose_estimation = processing_module[track_key]
for node_name in skeleton.nodes:
pose_estimation_series = pose_estimation[node_name]
timestamps = np.union1d(
timestamps, get_timestamps(pose_estimation_series)
)
Expand All @@ -149,13 +158,13 @@ def read_nwb(path: str) -> Labels:
# Recreate Labels numpy (same as output of Labels.numpy())
n_tracks = len(_track_keys)
n_frames = len(timestamps)
n_nodes = len(node_names)
n_nodes = len(skeleton.nodes)
tracks_numpy = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, np.float32)
confidence = np.full((n_frames, n_tracks, n_nodes), np.nan, np.float32)

for track_idx, track_key in enumerate(_track_keys):
pose_estimation = processing_module[track_key]

for node_idx, node_name in enumerate(node_names):
for node_idx, node_name in enumerate(skeleton.nodes):
pose_estimation_series = pose_estimation[node_name]
frame_inds = np.searchsorted(
timestamps, get_timestamps(pose_estimation_series)
Expand All @@ -173,10 +182,10 @@ def read_nwb(path: str) -> Labels:
is_tracked,
)

# Create skeleton
skeleton = Skeleton(
nodes=node_names,
edges=edge_inds,
# Create SLEAP skeleton from NWB skeleton
sleap_skeleton = SleapSkeleton(
nodes=skeleton_nodes,
edges=skeleton_edges.tolist(),
)

# Add instances to labeled frames
Expand All @@ -185,6 +194,7 @@ def read_nwb(path: str) -> Labels:
video = Video(filename=video_fn)
n_frames, n_tracks, n_nodes, _ = tracks_numpy.shape
tracks = [Track(name=f"track{track_idx}") for track_idx in range(n_tracks)]

for frame_idx, (frame_pts, frame_confs) in enumerate(
zip(tracks_numpy, confidence)
):
Expand All @@ -199,19 +209,77 @@ def read_nwb(path: str) -> Labels:
points=inst_pts, # (n_nodes, 2)
point_scores=inst_confs, # (n_nodes,)
instance_score=inst_confs.mean(), # ()
skeleton=skeleton,
skeleton=sleap_skeleton,
track=track if is_tracked else None,
)
)
if len(insts) > 0:
lfs.append(
LabeledFrame(video=video, frame_idx=frame_idx, instances=insts)
)

labels = Labels(lfs)
labels.provenance["filename"] = path
return labels


def create_skeleton_container(
labels: Labels,
nwbfile: NWBFile,
) -> Dict[str, Skeleton]:
"""Create NWB skeleton containers from SLEAP skeletons.
Args:
labels: SLEAP Labels object containing skeleton definitions
nwbfile: NWB file to add skeletons to
Returns:
Dictionary mapping skeleton names to NWB Skeleton objects
"""
skeleton_map = {}
nwb_skeletons = []

# Get or create behavior processing module
behavior_pm = nwbfile.processing.get("behavior")
if behavior_pm is None:
behavior_pm = nwbfile.create_processing_module(
name="behavior", description="processed behavioral data"
)

# Check if Skeletons container already exists
existing_skeletons = None
if "Skeletons" in behavior_pm.data_interfaces:
existing_skeletons = behavior_pm.data_interfaces["Skeletons"]
# Add existing skeletons to our map
for skeleton_name in existing_skeletons.skeletons:
nwb_skeleton = existing_skeletons.skeletons[skeleton_name]
skeleton_map[skeleton_name] = nwb_skeleton

# Create new skeletons for ones that don't exist yet
for sleap_skeleton in labels.skeletons:
if sleap_skeleton.name not in skeleton_map:
nwb_skeleton = Skeleton(
name=sleap_skeleton.name,
nodes=sleap_skeleton.node_names,
edges=np.array(sleap_skeleton.edge_inds, dtype="uint8"),
)
nwb_skeletons.append(nwb_skeleton)
skeleton_map[sleap_skeleton.name] = nwb_skeleton

# If we have new skeletons to add
if nwb_skeletons:
if existing_skeletons is None:
# Create new Skeletons container if none exists
skeletons_container = Skeletons(skeletons=nwb_skeletons)
behavior_pm.add(skeletons_container)
else:
# Add new skeletons to existing container
for skeleton in nwb_skeletons:
existing_skeletons.add_skeleton(skeleton)

return skeleton_map


def write_nwb(
labels: Labels,
nwbfile_path: str,
Expand Down Expand Up @@ -266,14 +334,22 @@ def write_nwb(
)

nwbfile = NWBFile(**nwb_file_kwargs)
nwbfile = append_nwb_data(labels, nwbfile, pose_estimation_metadata)

# Create skeleton containers first
skeleton_map = create_skeleton_container(labels, nwbfile)

# Then append pose data
nwbfile = append_nwb_data(labels, nwbfile, pose_estimation_metadata, skeleton_map)

with NWBHDF5IO(str(nwbfile_path), "w") as io:
io.write(nwbfile)


def append_nwb_data(
labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
labels: Labels,
nwbfile: NWBFile,
pose_estimation_metadata: Optional[dict] = None,
skeleton_map: Optional[Dict[str, Skeleton]] = None,
) -> NWBFile:
"""Append data from a Labels object to an in-memory nwb file.
Expand All @@ -294,11 +370,14 @@ def append_nwb_data(
2) The other use of this dictionary is to ovewrite sleap-io default
arguments for the PoseEstimation container.
see https://github.com/rly/ndx-pose for a full list or arguments.
skeleton_map: Mapping of skeleton names to NWB Skeleton objects.
Returns:
An in-memory nwbfile with the data from the labels object appended.
"""
pose_estimation_metadata = pose_estimation_metadata or dict()
if skeleton_map is None:
skeleton_map = create_skeleton_container(labels=labels, nwbfile=nwbfile)

# Extract default metadata
provenance = labels.provenance
Expand All @@ -316,6 +395,16 @@ def append_nwb_data(
processing_module_name, nwbfile
)

device_name = f"camera_{video_index}"
if device_name in nwbfile.devices:
device = nwbfile.devices[device_name]
else:
device = nwbfile.create_device(
name=device_name,
description=f"Camera for {video_path.name}",
manufacturer="Unknown",
)

# Propagate video metadata
default_metadata["original_videos"] = [f"{video.filename}"] # type: ignore
default_metadata["labeled_videos"] = [f"{video.filename}"] # type: ignore
Expand All @@ -337,6 +426,8 @@ def append_nwb_data(
track_name,
video,
default_metadata,
skeleton_map,
devices=[device],
)
nwb_processing_module.add(pose_estimation_container)

Expand Down Expand Up @@ -395,6 +486,8 @@ def build_pose_estimation_container_for_track(
track_name: str,
video: Video,
pose_estimation_metadata: dict,
skeleton_map: Dict[str, Skeleton],
devices: Optional[List] = None,
) -> PoseEstimation:
"""Create a PoseEstimation container for a track.
Expand All @@ -404,7 +497,10 @@ def build_pose_estimation_container_for_track(
labels (Labels): A general labels object
track_name (str): The name of the track in labels.tracks
video (Video): The video to which data belongs to
pose_estimation_metadata: (dict) Metadata for pose estimation. See `append_nwb_data`
skeleton_map: Mapping of skeleton names to NWB Skeleton objects
skeleton_map: Mapping of skeleton names to NWB Skeleton objects
devices: Optional list of recording devices
Returns:
PoseEstimation: A PoseEstimation multicontainer where the time series
of all the node trajectories in the track are stored. One time series per
Expand All @@ -422,13 +518,15 @@ def build_pose_estimation_container_for_track(

# Assuming only one skeleton per track
skeleton_name = all_track_skeletons[0]
skeleton = next(
sleap_skeleton = next(
skeleton for skeleton in labels.skeletons if skeleton.name == skeleton_name
)
nwb_skeleton = skeleton_map[skeleton_name]

# Get track data
track_data_df = labels_data_df[
video.filename,
skeleton.name,
sleap_skeleton.name,
track_name,
]

Expand All @@ -448,12 +546,12 @@ def build_pose_estimation_container_for_track(
# Arrange and mix metadata
pose_estimation_container_kwargs = dict(
name=f"track={track_name}",
description=f"Estimated positions of {skeleton.name} in video {video_path.name}",
description=f"Estimated positions of {sleap_skeleton.name} in video {video_path.name}",
pose_estimation_series=pose_estimation_series_list,
nodes=skeleton.node_names,
edges=np.array(skeleton.edge_inds).astype("uint64"),
skeleton=nwb_skeleton,
source_software="SLEAP",
# dimensions=np.array([[video.backend.height, video.backend.width]]),
# dimensions=np.array([[video.height, video.width]], dtype="uint16"),
devices=devices or [],
)

pose_estimation_container_kwargs.update(**pose_estimation_metadata_copy)
Expand All @@ -468,12 +566,12 @@ def build_track_pose_estimation_list(
"""Build a list of PoseEstimationSeries from tracks.
Args:
track_data_df (pd.DataFrame): A pandas DataFrame object containing the
trajectories for all the nodes associated with a specific track.
track_data_df: A pandas DataFrame containing the trajectories
for all the nodes associated with a specific track.
timestamps: Array of timestamps for the data points
Returns:
List[PoseEstimationSeries]: The list of all the PoseEstimationSeries.
One for each node.
List of PoseEstimationSeries, one for each node.
"""
name_of_nodes_in_track = track_data_df.columns.get_level_values(
"node_name"
Expand All @@ -490,7 +588,7 @@ def build_track_pose_estimation_list(
reference_frame = (
"The coordinates are in (x, y) relative to the top-left of the image. "
"Coordinates refer to the midpoint of the pixel. "
"That is, t the midpoint of the top-left pixel is at (0, 0), whereas "
"That is, the midpoint of the top-left pixel is at (0, 0), whereas "
"the top-left corner of that same pixel is at (-0.5, -0.5)."
)

Expand Down
Loading

0 comments on commit 8233ee9

Please sign in to comment.