From 855bc15d5d9b32ed01ddd9b2045deb62a1365fbe Mon Sep 17 00:00:00 2001 From: niksirbi Date: Tue, 4 Jul 2023 14:40:17 +0100 Subject: [PATCH 01/79] added sleap-io as dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index bdcbbea4..18d803f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "pydantic", "pooch", "tqdm", + "sleap-io", ] classifiers = [ From e69862b72caa3b6def802d83fa7c9932a9d7cc37 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Tue, 4 Jul 2023 17:25:17 +0100 Subject: [PATCH 02/79] added function for converting SLEAP poses into DLC-style df --- movement/io/converters.py | 68 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 movement/io/converters.py diff --git a/movement/io/converters.py b/movement/io/converters.py new file mode 100644 index 00000000..c05a2dd7 --- /dev/null +++ b/movement/io/converters.py @@ -0,0 +1,68 @@ +""" +Functions to convert between different formats, +e.g. from DeepLabCut to SLEAP and vice versa. +""" +import logging + +import pandas as pd + +# get logger +logger = logging.getLogger(__name__) + + +def sleap_poses_to_dlc_df(pose_tracks: dict) -> pd.DataFrame: + """Convert pose tracking data from SLEAP labels to a DeepLabCut-style + DataFrame with multi-index columns. See Notes for details. + + Parameters + ---------- + pose_tracks : dict + Dictionary containing `pose_tracks`, `node_names` and `track_names`. + This dictionary is returned by `io.load_poses.from_sleap`. + + Returns + ------- + pandas DataFrame + DataFrame containing pose tracks in DLC style, with the multi-index + columns ("scorer", "individuals", "bodyparts", "coords"). + + Notes + ----- + Correspondence between SLEAP and DLC terminology: + - DLC "scorer" has no equivalent in SLEAP, so we assign it to "SLEAP" + - DLC "individuals" are the names of SLEAP "tracks" + - DLC "bodyparts" are the names of SLEAP "nodes" (i.e. the keypoints) + - DLC "coords" are referred to in SLEAP as "dims" + (i.e. "x" coord + "y" coord + "confidence/likelihood") + - DLC reports "likelihood" while SLEAP reports "confidence". + These both measure the point-wise prediction confidence but do not + have the same range and cannot be compared between the two frameworks. 
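+
+    Examples
+    --------
+    A minimal sketch (illustrative only; assumes ``poses`` is the dict
+    returned by ``io.load_poses.from_sleap``, with "tracks", "track_names"
+    and "node_names" keys):
+
+    >>> from movement.io import load_poses
+    >>> from movement.io.converters import sleap_poses_to_dlc_df
+    >>> poses = load_poses.from_sleap("path/to/labels.predictions.slp")
+    >>> df = sleap_poses_to_dlc_df(poses)
+    >>> df.columns.names
+    ['scorer', 'individuals', 'bodyparts', 'coords']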
+    """
+
+    # Get the number of frames, tracks, nodes and dimensions
+    n_frames, n_tracks, n_nodes, n_dims = pose_tracks["tracks"].shape
+    # Use the DLC terminology: scorer, individuals, bodyparts, coords
+    # The assigned scorer is always "SLEAP"
+    scorer = ["SLEAP"]
+    individuals = pose_tracks["track_names"]
+    bodyparts = pose_tracks["node_names"]
+    coords = ["x", "y", "likelihood"]
+
+    # Create the DLC-style multi-index dataframe
+    index_levels = ["scorer", "individuals", "bodyparts", "coords"]
+    columns = pd.MultiIndex.from_product(
+        [scorer, individuals, bodyparts, coords], names=index_levels
+    )
+    df = pd.DataFrame(
+        data=pose_tracks["tracks"].reshape(n_frames, -1),
+        index=pd.RangeIndex(0, n_frames),
+        columns=columns,
+        dtype=float,
+    )
+
+    # Log the conversion
+    logger.info(
+        f"Converted SLEAP pose tracks to DLC-style DataFrame "
+        f"with shape {df.shape}"
+    )
+    return df

From 044cff45dc5da1f4f02c6bc9ce94218e674eb27d Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 4 Jul 2023 17:27:30 +0100
Subject: [PATCH 03/79] added functions for loading SLEAP pose tracks

---
 movement/io/load_poses.py | 149 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 149 insertions(+)

diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index e7b92f6f..bd6df1e5 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -2,7 +2,10 @@
 from pathlib import Path
 from typing import Optional, Union
 
+import h5py
+import numpy as np
 import pandas as pd
+from sleap_io.io.slp import read_labels
 
 from movement.io.validators import DeepLabCutPosesFile
 
@@ -90,3 +93,149 @@ def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame:
     )
     df.columns.rename(level_names, inplace=True)
     return df
+
+
+def from_sleap(file_path: Union[Path, str]) -> dict:
+    """Load pose tracking data from a SLEAP labels file.
+
+    Parameters
+    ----------
+    file_path : pathlib Path or str
+        Path to the file containing the SLEAP predictions, either in ".slp"
+        or ".h5" (analysis) format. See Notes for more information.
+
+    Returns
+    -------
+    dict
+        Dictionary containing `pose_tracks`, `node_names` and `track_names`.
+        - `pose_tracks` is an array containing the predicted poses.
+          Shape: (n_frames, n_tracks, n_nodes, n_dims). The last axis
+          contains the spatial coordinates "x" and "y", as well as the
+          point-wise confidence values.
+        - `node_names` is a list of the node names.
+        - `track_names` is a list of the track names.
+
+    Notes
+    -----
+    The SLEAP inference procedure normally produces a file suffixed with
+    ".slp" containing the predictions, e.g. "myproject.predictions.slp".
+    This can be converted to an ".h5" (analysis) file using the command line
+    tool `sleap-convert` with the "--format analysis" option enabled,
+    or alternatively by choosing “Export Analysis HDF5…” from the “File” menu
+    of the SLEAP GUI [1]_.
+
+    This function will only load the predicted instances in the ".slp" file,
+    not the user-labeled ones.
+
+    movement expects the tracks to be proofread before loading them.
+    There should be as many tracks as there are instances (animals) in the
+    video, without identity switches. Follow the SLEAP guide for
+    tracking and proofreading [2]_.
+
+    References
+    ----------
+    .. [1] https://sleap.ai/tutorials/analysis.html
+    ..
[2] https://sleap.ai/guides/proofreading.html + + Examples + -------- + >>> from movement.io import load_poses + >>> poses = load_poses.from_sleap("path/to/labels.predictions.slp") + """ + + if not isinstance(file_path, Path): + file_path = Path(file_path) + + if file_path.suffix == ".h5": + # Load the SLEAP predictions from an analysis file + poses = _load_sleap_analysis_file(file_path) + elif file_path.suffix == ".slp": + # Load the SLEAP predictions from a labels file + poses = _load_sleap_labels_file(file_path) + else: + error_msg = ( + f"Expected file suffix to be '.h5' or '.slp', " + f"but got '{file_path.suffix}'. Make sure the file is " + "a SLEAP labels file with suffix '.slp' or SLEAP analysis " + "file with suffix '.h5'." + ) + logger.error(error_msg) + raise ValueError(error_msg) + + n_frames, n_tracks, n_nodes, n_dims = poses["tracks"].shape + logger.info(f"Loaded poses from {file_path}.") + logger.debug( + f"Shape: ({n_frames} frames, {n_tracks} tracks, " + f"{n_nodes} nodes, {n_dims - 1} spatial coords " + "+ 1 confidence score)" + ) + logger.info(f"Track names: {poses['track_names']}") + logger.info(f"Node names: {poses['node_names']}") + return poses + + +def _load_sleap_analysis_file(file_path: Path) -> dict: + """Load pose tracking data from a SLEAP analysis file. + + Parameters + ---------- + file_path : pathlib Path + Path to the file containing the SLEAP predictions, in ".h5" format. + + Returns + ------- + dict + Dictionary containing `pose_tracks`, `node_names` and `track_names`. + """ + + # Load the SLEAP poses + with h5py.File(file_path, "r") as f: + # First, load and reshape the pose tracks + tracks = f["tracks"][:].T + n_frames, n_nodes, n_dims, n_tracks = tracks.shape + tracks = tracks.reshape((n_frames, n_tracks, n_nodes, n_dims)) + + # If present, read the point-wise confidence scores + # and add them to the "tracks" array + confidence = np.full( + (n_frames, n_tracks, n_nodes, 3), np.nan, dtype="float32" + ) + if "point_scores" in f.keys(): + confidence = f["point_scores"][:].T + confidence = confidence.reshape((n_frames, n_tracks, n_nodes)) + tracks = np.concatenate( + [tracks, confidence[:, :, :, np.newaxis]], axis=3 + ) + + # Create the dictionary to be returned + poses = { + "tracks": tracks, + "node_names": [n.decode() for n in f["node_names"][:]], + "track_names": [n.decode() for n in f["track_names"][:]], + } + return poses + + +def _load_sleap_labels_file(file_path: Path) -> dict: + """Load pose tracking data from a SLEAP labels file. + + Parameters + ---------- + file_path : pathlib Path + Path to the file containing the SLEAP predictions, in ".slp" format. + + Returns + ------- + dict + Dictionary containing `pose_tracks`, `node_names` and `track_names`. 
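+
+    Examples
+    --------
+    A minimal sketch (illustrative only; this is a private helper that is
+    normally called via ``from_sleap``, and the file name is hypothetical):
+
+    >>> poses = _load_sleap_labels_file(Path("labels.predictions.slp"))
+    >>> sorted(poses.keys())
+    ['node_names', 'track_names', 'tracks']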
+ """ + labels = read_labels(file_path.as_posix()) + poses = { + "tracks": labels.numpy(return_confidence=True), + "node_names": [node.name for node in labels.skeletons[0].nodes], + "track_names": [track.name for track in labels.tracks], + } + # return_confidence=True adds the point-wise confidence scores + # as an extra coord dimension to the "tracks" array + + return poses From 6b8fdbbe7b703fd20d45da1f4aaf1031fad42eeb Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 17 Jul 2023 14:14:08 +0100 Subject: [PATCH 04/79] renamed converters module to convert --- movement/io/{converters.py => convert.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename movement/io/{converters.py => convert.py} (100%) diff --git a/movement/io/converters.py b/movement/io/convert.py similarity index 100% rename from movement/io/converters.py rename to movement/io/convert.py From e8de7702b7a57749df744b762718631864f1828b Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 17 Jul 2023 14:18:24 +0100 Subject: [PATCH 05/79] renamed converters module to convert --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 18d803f7..6f125741 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "pooch", "tqdm", "sleap-io", + "xarray", ] classifiers = [ From ce28d1daecf1b0f5a09c16146c783de62c5e8703 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 17 Jul 2023 14:38:50 +0100 Subject: [PATCH 06/79] add attrs as dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6f125741..01ddc1ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "tqdm", "sleap-io", "xarray", + "attrs", ] classifiers = [ From 5cdd989f039f335666bcf3495e385d4df5d3aead Mon Sep 17 00:00:00 2001 From: niksirbi Date: Tue, 18 Jul 2023 19:24:23 +0100 Subject: [PATCH 07/79] Implemented PoseTracks class with import functions from SLEAP --- movement/io/PoseTracks.py | 227 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 movement/io/PoseTracks.py diff --git a/movement/io/PoseTracks.py b/movement/io/PoseTracks.py new file mode 100644 index 00000000..a0b60a3f --- /dev/null +++ b/movement/io/PoseTracks.py @@ -0,0 +1,227 @@ +from pathlib import Path +from typing import ClassVar, Optional + +import h5py +import numpy as np +import pandas as pd +import xarray as xr +from sleap_io.io.slp import read_labels + + +class PoseTracks(xr.Dataset): + """Pose tracking data with point-wise confidence scores. 
+ + This is a subclass of `xarray.Dataset`, with the following dimensions: + - `frames`: the number of frames in the video + - `individuals`: the number of individuals in the video + - `keypoints`: the number of keypoints in the skeleton + - `space`: the number of spatial dimensions, either 2 (x,y) or 3 (x,y,z) + + The dataset contains two data variables: + - `pose_tracks`: with shape (`frames`, `individuals`, `keypoints`, `space`) + - `confidence_scores`: with shape (`frames`, `individuals`, `keypoints`) + + The dataset may also contain following attributes as metadata: + - `fps`: the number of frames per second in the video + - `source_software`: the software from which the pose tracks were loaded + - `source_file`: the file from which the pose tracks were loaded + """ + + dim_names: ClassVar[tuple] = ( + "frames", + "individuals", + "keypoints", + "space", + ) + + __slots__ = ("fps", "source_software", "source_file") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def from_dict( + cls, + dict_: dict, + ): + """Create an xarray.Dataset from a dictionary of pose tracks, + confidence scores, and metadata. + + Parameters + ---------- + dict_ : dict + A dictionary with the following keys: + - "pose_tracks": np.ndarray of shape (n_frames, n_individuals, + n_keypoints, n_space_dims) + - "confidence_scores": np.ndarray of shape (n_frames, + n_individuals, n_keypoints) + - "individual_names": list of strings, with length individuals + - "keypoint_names": list of strings, with length n_keypoints + - "fps": float, the number of frames per second in the video. + If None, the "time" coordinate will not be added. + + Returns + ------- + xarray.Dataset + xarray.Dataset containing `pose_tracks` and `confidence_scores`. + """ + + # Convert the pose tracks and confidence scores to xarray.DataArray + tracks_da = xr.DataArray(dict_["pose_tracks"], dims=cls.dim_names) + scores_da = xr.DataArray( + dict_["confidence_scores"], dims=cls.dim_names[:-1] + ) + + # Combine the DataArrays into a Dataset, with common coordinates + ds = cls( + data_vars={ + "pose_tracks": tracks_da, + "confidence_scores": scores_da, + }, + coords={ + cls.dim_names[0]: np.arange( + dict_["pose_tracks"].shape[0], dtype=int + ), + cls.dim_names[1]: dict_["individual_names"], + cls.dim_names[2]: dict_["keypoint_names"], + cls.dim_names[3]: ["x", "y", "z"][ + : dict_["pose_tracks"].shape[-1] + ], + }, + attrs={"fps": dict_["fps"]}, + ) + + # If fps is given, create "time" coords for 1st ("frames") dimension + if dict_["fps"] is not None: + times = pd.TimedeltaIndex( + ds.coords["frames"] / dict_["fps"], unit="s" + ) + ds.coords["time"] = (cls.dim_names[0], times) + + return ds + + @classmethod + def from_sleap(cls, file_path: Path, fps: Optional[float] = None): + """Load pose tracking data from a SLEAP labels or analysis file. + + Parameters + ---------- + file_path : pathlib Path or str + Path to the file containing the SLEAP predictions, either in ".slp" + or ".h5" (analysis) format. See Notes for more information. + fps : float, optional + The number of frames per second in the video. If None (default), + the "time" coordinate will not be created. + + Notes + ----- + The SLEAP inference procedure normally produces a file suffixed with + ".slp" containing the predictions, e.g. "myproject.predictions.slp". 
+ This can be converted to an ".h5" (analysis) file using the command + line tool `sleap-convert` with the "--format analysis" option enabled, + or alternatively by choosing "Export Analysis HDF5…" from the "File" + menu of the SLEAP GUI [1]_. + + This function will only load the predicted instances in the ".slp", + file not the user-labeled ones. + + `movement` expects the tracks to be proofread before loading them. + There should be as many tracks as there are instances (animals) in the + video, without identity switches. Follow the SLEAP guide for + tracking and proofreading [2]_. + + References + ---------- + .. [1] https://sleap.ai/tutorials/analysis.html + .. [2] https://sleap.ai/guides/proofreading.html + + Examples + -------- + >>> from movement.io import PoseTracks + >>> poses = PoseTracks.from_sleap("path/to/labels.predictions.slp")""" + + if not isinstance(file_path, Path): + file_path = Path(file_path) + + if file_path.suffix == ".h5": + with h5py.File(file_path, "r") as f: + tracks = f["tracks"][:].T + n_frames, n_nodes, n_space_dims, n_tracks = tracks.shape + tracks = tracks.reshape( + (n_frames, n_tracks, n_nodes, n_space_dims) + ) + # Create an array of NaNs for the confidence scores + scores = np.full( + (n_frames, n_tracks, n_nodes), np.nan, dtype="float32" + ) + # If present, read the point-wise scores, and reshape them + if "point_scores" in f.keys(): + scores = f["point_scores"][:].reshape( + (n_frames, n_tracks, n_nodes) + ) + individual_names = [n.decode() for n in f["track_names"][:]] + keypoint_names = [n.decode() for n in f["node_names"][:]] + elif file_path.suffix == ".slp": + labels = read_labels(file_path.as_posix()) + tracks_with_scores = labels.numpy( + return_confidence=True, untracked=False + ) + tracks = tracks_with_scores[:, :, :, :-1] + scores = tracks_with_scores[:, :, :, -1] + individual_names = [track.name for track in labels.tracks] + keypoint_names = [node.name for node in labels.skeletons[0].nodes] + else: + error_msg = ( + f"Expected file suffix to be '.h5' or '.slp', " + f"but got '{file_path.suffix}'. Make sure the file is " + "a SLEAP labels file with suffix '.slp' or SLEAP analysis " + "file with suffix '.h5'." 
+ ) + # logger.error(error_msg) + raise ValueError(error_msg) + + ds = cls.from_dict( + { + "pose_tracks": tracks, + "confidence_scores": scores, + "individual_names": individual_names, + "keypoint_names": keypoint_names, + "fps": fps, + } + ) + # Add metadata to the dataset.attrs dictionary + ds.attrs["source_software"] = "SLEAP" + ds.attrs["source_file"] = file_path.as_posix() + return ds + + +if __name__ == "__main__": + from movement.datasets import fetch_pose_data_path + + h5_file = fetch_pose_data_path("SLEAP_single-mouse_EPM.analysis.h5") + slp_file = fetch_pose_data_path("SLEAP_single-mouse_EPM.predictions.slp") + + h5_poses = PoseTracks.from_sleap(h5_file, fps=60) + slp_poses = PoseTracks.from_sleap(slp_file, fps=60) + + # Plot the trajectories - 2 subplots: h5 and slp + from matplotlib import pyplot as plt + + titles = ["h5", "slp"] + fig, axes = plt.subplots(1, 2, figsize=(10, 5)) + for i, poses in enumerate([h5_poses, slp_poses]): + tracks = poses.pose_tracks[:, 0, 0, :].to_pandas() + # Plot the trajectories of the first individual, first keypoint + tracks.plot( + title=titles[i], + x="x", + y="y", + s=1, + c=tracks.index.values, + kind="scatter", + backend="matplotlib", + cmap="viridis", + ax=axes[i], + colorbar=False, + ) + plt.show() From 3e86805ea4af5dba4bb7610a992372f4b13a2940 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 19 Jul 2023 15:23:44 +0100 Subject: [PATCH 08/79] modified docstrings in PoseTracks --- movement/io/PoseTracks.py | 71 +++++++++++++-------------------------- 1 file changed, 23 insertions(+), 48 deletions(-) diff --git a/movement/io/PoseTracks.py b/movement/io/PoseTracks.py index a0b60a3f..3e1c10d9 100644 --- a/movement/io/PoseTracks.py +++ b/movement/io/PoseTracks.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import h5py import numpy as np @@ -9,15 +9,24 @@ class PoseTracks(xr.Dataset): - """Pose tracking data with point-wise confidence scores. + """Dataset containing pose tracks and point-wise confidence scores. - This is a subclass of `xarray.Dataset`, with the following dimensions: + This is a `xarray.Dataset` object, with the following dimensions: - `frames`: the number of frames in the video - `individuals`: the number of individuals in the video - `keypoints`: the number of keypoints in the skeleton - - `space`: the number of spatial dimensions, either 2 (x,y) or 3 (x,y,z) + - `space`: the number of spatial dimensions, either 2 or 3 - The dataset contains two data variables: + Each dimension is assigned appropriate coordinates: + - frame indices (int) for `frames` + - list of unique names (str) for `individuals` and `keypoints` + - `x`, `y` (and `z`) for `space` + + If `fps` is supplied, the `frames` dimension is also assigned a `time` + coordinate. If `fps` is None, the temporal dimension can only be + accessed through frame indices. + + The dataset contains two data variables (`xarray.DataArray` objects): - `pose_tracks`: with shape (`frames`, `individuals`, `keypoints`, `space`) - `confidence_scores`: with shape (`frames`, `individuals`, `keypoints`) @@ -44,7 +53,7 @@ def from_dict( cls, dict_: dict, ): - """Create an xarray.Dataset from a dictionary of pose tracks, + """Create a `PosteTracks` dataset from a dictionary of pose tracks, confidence scores, and metadata. Parameters @@ -59,11 +68,6 @@ def from_dict( - "keypoint_names": list of strings, with length n_keypoints - "fps": float, the number of frames per second in the video. 
If None, the "time" coordinate will not be added. - - Returns - ------- - xarray.Dataset - xarray.Dataset containing `pose_tracks` and `confidence_scores`. """ # Convert the pose tracks and confidence scores to xarray.DataArray @@ -101,7 +105,9 @@ def from_dict( return ds @classmethod - def from_sleap(cls, file_path: Path, fps: Optional[float] = None): + def from_sleap( + cls, file_path: Union[Path, str], fps: Optional[float] = None + ): """Load pose tracking data from a SLEAP labels or analysis file. Parameters @@ -111,7 +117,7 @@ def from_sleap(cls, file_path: Path, fps: Optional[float] = None): or ".h5" (analysis) format. See Notes for more information. fps : float, optional The number of frames per second in the video. If None (default), - the "time" coordinate will not be created. + the `time` coordinate will not be created. Notes ----- @@ -122,8 +128,8 @@ def from_sleap(cls, file_path: Path, fps: Optional[float] = None): or alternatively by choosing "Export Analysis HDF5…" from the "File" menu of the SLEAP GUI [1]_. - This function will only load the predicted instances in the ".slp", - file not the user-labeled ones. + If the ".slp" file contains both user-labeled and predicted instances, + this function will only load the ones predicted by the SLEAP model `movement` expects the tracks to be proofread before loading them. There should be as many tracks as there are instances (animals) in the @@ -138,7 +144,8 @@ def from_sleap(cls, file_path: Path, fps: Optional[float] = None): Examples -------- >>> from movement.io import PoseTracks - >>> poses = PoseTracks.from_sleap("path/to/labels.predictions.slp")""" + >>> poses = PoseTracks.from_sleap("path/to/v1.predictions.slp", fps=30) + """ if not isinstance(file_path, Path): file_path = Path(file_path) @@ -193,35 +200,3 @@ def from_sleap(cls, file_path: Path, fps: Optional[float] = None): ds.attrs["source_software"] = "SLEAP" ds.attrs["source_file"] = file_path.as_posix() return ds - - -if __name__ == "__main__": - from movement.datasets import fetch_pose_data_path - - h5_file = fetch_pose_data_path("SLEAP_single-mouse_EPM.analysis.h5") - slp_file = fetch_pose_data_path("SLEAP_single-mouse_EPM.predictions.slp") - - h5_poses = PoseTracks.from_sleap(h5_file, fps=60) - slp_poses = PoseTracks.from_sleap(slp_file, fps=60) - - # Plot the trajectories - 2 subplots: h5 and slp - from matplotlib import pyplot as plt - - titles = ["h5", "slp"] - fig, axes = plt.subplots(1, 2, figsize=(10, 5)) - for i, poses in enumerate([h5_poses, slp_poses]): - tracks = poses.pose_tracks[:, 0, 0, :].to_pandas() - # Plot the trajectories of the first individual, first keypoint - tracks.plot( - title=titles[i], - x="x", - y="y", - s=1, - c=tracks.index.values, - kind="scatter", - backend="matplotlib", - cmap="viridis", - ax=axes[i], - colorbar=False, - ) - plt.show() From 6aa9df491b5448b4d2b02ffdef214405df6f646b Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 19 Jul 2023 16:13:00 +0100 Subject: [PATCH 09/79] refactored from_sleap() classmethod --- movement/io/PoseTracks.py | 89 ++++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 34 deletions(-) diff --git a/movement/io/PoseTracks.py b/movement/io/PoseTracks.py index 3e1c10d9..003c67d3 100644 --- a/movement/io/PoseTracks.py +++ b/movement/io/PoseTracks.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import ClassVar, Optional, Union @@ -7,6 +8,9 @@ import xarray as xr from sleap_io.io.slp import read_labels +# get logger +logger = logging.getLogger(__name__) + class 
PoseTracks(xr.Dataset): """Dataset containing pose tracks and point-wise confidence scores. @@ -150,33 +154,11 @@ def from_sleap( if not isinstance(file_path, Path): file_path = Path(file_path) + # Load data into a dictionary if file_path.suffix == ".h5": - with h5py.File(file_path, "r") as f: - tracks = f["tracks"][:].T - n_frames, n_nodes, n_space_dims, n_tracks = tracks.shape - tracks = tracks.reshape( - (n_frames, n_tracks, n_nodes, n_space_dims) - ) - # Create an array of NaNs for the confidence scores - scores = np.full( - (n_frames, n_tracks, n_nodes), np.nan, dtype="float32" - ) - # If present, read the point-wise scores, and reshape them - if "point_scores" in f.keys(): - scores = f["point_scores"][:].reshape( - (n_frames, n_tracks, n_nodes) - ) - individual_names = [n.decode() for n in f["track_names"][:]] - keypoint_names = [n.decode() for n in f["node_names"][:]] + dict_ = cls._load_dict_from_sleap_analysis_file(file_path) elif file_path.suffix == ".slp": - labels = read_labels(file_path.as_posix()) - tracks_with_scores = labels.numpy( - return_confidence=True, untracked=False - ) - tracks = tracks_with_scores[:, :, :, :-1] - scores = tracks_with_scores[:, :, :, -1] - individual_names = [track.name for track in labels.tracks] - keypoint_names = [node.name for node in labels.skeletons[0].nodes] + dict_ = cls._load_dict_from_sleap_labels_file(file_path) else: error_msg = ( f"Expected file suffix to be '.h5' or '.slp', " @@ -187,16 +169,55 @@ def from_sleap( # logger.error(error_msg) raise ValueError(error_msg) - ds = cls.from_dict( - { - "pose_tracks": tracks, - "confidence_scores": scores, - "individual_names": individual_names, - "keypoint_names": keypoint_names, - "fps": fps, - } + logger.debug( + f"Loaded pose tracks from {file_path.as_posix()} into a dict." 
) - # Add metadata to the dataset.attrs dictionary + + # Initialize a PoseTracks dataset from the dictionary + ds = cls.from_dict({**dict_, "fps": fps}) + + # Add metadata as attrs ds.attrs["source_software"] = "SLEAP" ds.attrs["source_file"] = file_path.as_posix() return ds + + @staticmethod + def _load_dict_from_sleap_analysis_file(file_path: Path): + """Load pose tracks and confidence scores from a SLEAP analysis + file into a dictionary.""" + + with h5py.File(file_path, "r") as f: + tracks = f["tracks"][:].T + n_frames, n_keypoints, n_space, n_tracks = tracks.shape + tracks = tracks.reshape((n_frames, n_tracks, n_keypoints, n_space)) + # Create an array of NaNs for the confidence scores + scores = np.full( + (n_frames, n_tracks, n_keypoints), np.nan, dtype="float32" + ) + # If present, read the point-wise scores, and reshape them + if "point_scores" in f.keys(): + scores = f["point_scores"][:].reshape( + (n_frames, n_tracks, n_keypoints) + ) + + return { + "pose_tracks": tracks, + "confidence_scores": scores, + "individual_names": [n.decode() for n in f["track_names"][:]], + "keypoint_names": [n.decode() for n in f["node_names"][:]], + } + + @staticmethod + def _load_dict_from_sleap_labels_file(file_path: Path): + """Load pose tracks and confidence scores from a SLEAP labels file + into a dictionary.""" + + labels = read_labels(file_path.as_posix()) + tracks_with_scores = labels.numpy(return_confidence=True) + + return { + "pose_tracks": tracks_with_scores[:, :, :, :-1], + "confidence_scores": tracks_with_scores[:, :, :, -1], + "individual_names": [track.name for track in labels.tracks], + "keypoint_names": [kp.name for kp in labels.skeletons[0].nodes], + } From 633a9ebf62b30403461e80bd5fc2e8f5138a7d14 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 19 Jul 2023 16:37:42 +0100 Subject: [PATCH 10/79] Ensure that PoseTracks class is imported with the io module --- movement/io/__init__.py | 1 + movement/io/{PoseTracks.py => pose_tracks.py} | 0 2 files changed, 1 insertion(+) rename movement/io/{PoseTracks.py => pose_tracks.py} (100%) diff --git a/movement/io/__init__.py b/movement/io/__init__.py index e69de29b..855b4d62 100644 --- a/movement/io/__init__.py +++ b/movement/io/__init__.py @@ -0,0 +1 @@ +from .pose_tracks import PoseTracks diff --git a/movement/io/PoseTracks.py b/movement/io/pose_tracks.py similarity index 100% rename from movement/io/PoseTracks.py rename to movement/io/pose_tracks.py From 1fbc7132decec0be76bd79e2da2021fcabef3a4f Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 19 Jul 2023 17:42:12 +0100 Subject: [PATCH 11/79] added method to import pose tracks from DeepLabCut --- movement/io/pose_tracks.py | 171 +++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index 003c67d3..dec9c0af 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -8,6 +8,8 @@ import xarray as xr from sleap_io.io.slp import read_labels +from movement.io.validators import DeepLabCutPosesFile + # get logger logger = logging.getLogger(__name__) @@ -108,6 +110,42 @@ def from_dict( return ds + @classmethod + def from_dataframe(cls, df: pd.DataFrame, fps: Optional[float] = None): + """Create a `PoseTracks` dataset from a pandas DataFrame. + + Parameters + ---------- + df : pandas DataFrame + DataFrame containing the pose tracks and confidence scores. Must + be formatted as in DeepLabCut output files (see Notes). + fps : float, optional + The number of frames per second in the video. 
If None (default), + the `time` coordinate will not be created. + + Notes + ----- + The DataFrame must have a multi-index column with the following levels: + "scorer", ("individuals"), "bodyparts", "coords". + The "individuals level may be omitted if there is only one individual + in the video. The "coords" level contains the spatial coordinates "x", + "y", as well as "likelihood" (point-wise confidence scores). The row + index corresponds to the frame number. + + Examples + -------- + >>> from movement.io import PoseTracks + >>> df = pd.read_csv("path/to/poses.csv") + >>> poses = PoseTracks.from_dataframe(df, fps=30) + """ + + # Convert the DataFrame to a dictionary + dict_ = cls.dataframe_to_dict(df) + + # Initialize a PoseTracks dataset from the dictionary + ds = cls.from_dict({**dict_, "fps": fps}) + return ds + @classmethod def from_sleap( cls, file_path: Union[Path, str], fps: Optional[float] = None @@ -181,6 +219,57 @@ def from_sleap( ds.attrs["source_file"] = file_path.as_posix() return ds + @classmethod + def from_dlc( + cls, file_path: Union[Path, str], fps: Optional[float] = None + ): + """Load pose tracking data from a DeepLabCut (DLC) output file. + + Parameters + ---------- + file_path : pathlib Path or str + Path to the file containing the DLC poses, either in ".slp" + or ".h5" (analysis) format. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinate will not be created. + + + Examples + -------- + >>> from movement.io import PoseTracks + >>> poses = PoseTracks.from_dlc("path/to/video_model.h5", fps=30) + """ + + # Validate the input file path + dlc_poses_file = DeepLabCutPosesFile(file_path=file_path) + file_suffix = dlc_poses_file.file_path.suffix + + # Load the DLC poses into a DataFrame + try: + if file_suffix == ".csv": + df = cls._parse_dlc_csv_to_dataframe(dlc_poses_file.file_path) + else: # file can only be .h5 at this point + df = pd.read_hdf(dlc_poses_file.file_path) + # above line does not necessarily return a DataFrame + df = pd.DataFrame(df) + except (OSError, TypeError, ValueError) as e: + error_msg = ( + f"Could not load poses from {file_path}. " + "Please check that the file is valid and readable." + ) + logger.error(error_msg) + raise OSError from e + logger.info(f"Loaded poses from {file_path} into a DataFrame.") + + # Convert the DataFrame to a PoseTracks dataset + ds = cls.from_dataframe(df=df, fps=fps) + + # Add metadata as attrs + ds.attrs["source_software"] = "DeepLabCut" + ds.attrs["source_file"] = dlc_poses_file.filepath.as_posix() + return ds + @staticmethod def _load_dict_from_sleap_analysis_file(file_path: Path): """Load pose tracks and confidence scores from a SLEAP analysis @@ -221,3 +310,85 @@ def _load_dict_from_sleap_labels_file(file_path: Path): "individual_names": [track.name for track in labels.tracks], "keypoint_names": [kp.name for kp in labels.skeletons[0].nodes], } + + def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: + """If poses are loaded from a DeepLabCut.csv file, the DataFrame + lacks the multi-index columns that are present in the .h5 file. This + function parses the csv file to a DataFrame with multi-index columns, + i.e. the same format as in the .h5 file. + + Parameters + ---------- + file_path : pathlib Path + Path to the file containing the DLC poses, in .csv format. + + Returns + ------- + pandas DataFrame + DataFrame containing the DLC poses, with multi-index columns. 
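+
+        Examples
+        --------
+        A minimal sketch (illustrative only; the file name is hypothetical):
+
+        >>> df = PoseTracks._parse_dlc_csv_to_dataframe(Path("poses.csv"))
+        >>> df.columns.names  # e.g. for a multi-animal DLC file
+        ['scorer', 'individuals', 'bodyparts', 'coords']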
+ """ + + possible_level_names = ["scorer", "individuals", "bodyparts", "coords"] + with open(file_path, "r") as f: + # if line starts with a possible level name, split it into a list + # of strings, and add it to the list of header lines + header_lines = [ + line.strip().split(",") + for line in f.readlines() + if line.split(",")[0] in possible_level_names + ] + + # Form multi-index column names from the header lines + level_names = [line[0] for line in header_lines] + column_tuples = list(zip(*[line[1:] for line in header_lines])) + columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names) + + # Import the DLC poses as a DataFrame + df = pd.read_csv( + file_path, skiprows=len(header_lines), index_col=0, names=columns + ) + df.columns.rename(level_names, inplace=True) + return df + + @staticmethod + def dataframe_to_dict(df: pd.DataFrame) -> dict: + """Convert a DeepLabCut-style DataFrame containing pose tracks and + likelihood scores into a dictionary. + + Parameters + ---------- + df : pandas DataFrame + DataFrame formatted as in DeepLabCut output files. + + Returns + ------- + dict + Dictionary containing the pose tracks, confidence scores, and + metadata. + """ + + # read names of individuals and keypoints from the DataFrame + # retain the order of their appearance in the DataFrame + if "individuals" in df.columns.names: + ind_names = ( + df.columns.get_level_values("individuals").unique().to_list() + ) + else: + ind_names = ["individual_0"] + + kp_names = df.columns.get_level_values("bodyparts").unique().to_list() + print(ind_names) + print(kp_names) + + # reshape the data into (n_frames, n_individuals, n_keypoints, 3) + # where the last axis contains "x", "y", "likelihood" + tracks_with_scores = df.to_numpy().reshape( + (-1, len(ind_names), len(kp_names), 3) + ) + + return { + "pose_tracks": tracks_with_scores[:, :, :, :-1], + "confidence_scores": tracks_with_scores[:, :, :, -1], + "individual_names": ind_names, + "keypoint_names": kp_names, + } From e7a3ee80dfa5e7e07b35d4dbad1ba003e651addc Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 19 Jul 2023 18:31:11 +0100 Subject: [PATCH 12/79] transferred functionality of from_dict method to __init__ --- movement/io/pose_tracks.py | 190 +++++++++++++++++++------------------ 1 file changed, 96 insertions(+), 94 deletions(-) diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index dec9c0af..bc4886d6 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -51,68 +51,88 @@ class PoseTracks(xr.Dataset): __slots__ = ("fps", "source_software", "source_file") - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - @classmethod - def from_dict( - cls, - dict_: dict, + def __init__( + self, + pose_tracks: np.ndarray, + confidence_scores: Optional[np.ndarray] = None, + individual_names: Optional[list[str]] = None, + keypoint_names: Optional[list[str]] = None, + fps: Optional[float] = None, ): - """Create a `PosteTracks` dataset from a dictionary of pose tracks, - confidence scores, and metadata. + """Create a `PoseTracks` dataset. Parameters ---------- - dict_ : dict - A dictionary with the following keys: - - "pose_tracks": np.ndarray of shape (n_frames, n_individuals, - n_keypoints, n_space_dims) - - "confidence_scores": np.ndarray of shape (n_frames, - n_individuals, n_keypoints) - - "individual_names": list of strings, with length individuals - - "keypoint_names": list of strings, with length n_keypoints - - "fps": float, the number of frames per second in the video. 
- If None, the "time" coordinate will not be added. + pose_tracks : np.ndarray + Array of shape (n_frames, n_individuals, n_keypoints, n_space) + containing the pose tracks. + confidence_scores : np.ndarray, optional + Array of shape (n_frames, n_individuals, n_keypoints) containing + the point-wise confidence scores. If None (default), the + confidence scores will be set to an array of NaNs. + individual_names : list of str, optional + List of unique names for the individuals in the video. If None + (default), the individuals will be named "individual_0", + "individual_1", etc. + keypoint_names : list of str, optional + List of unique names for the keypoints in the skeleton. If None + (default), the keypoints will be named "keypoint_0", "keypoint_1", + etc. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinate will not be created. """ + n_frames, n_individuals, n_keypoints, n_space = pose_tracks.shape + if confidence_scores is None: + confidence_scores = np.full( + (n_frames, n_individuals, n_keypoints), np.nan, dtype="float32" + ) + if individual_names is None: + individual_names = [ + f"individual_{i}" for i in range(n_individuals) + ] + if keypoint_names is None: + keypoint_names = [f"keypoint_{i}" for i in range(n_keypoints)] + if (fps is not None) and (fps <= 0): + logger.warning( + f"Expected fps to be a positive number, but got {fps}. " + "Setting fps to None." + ) + fps = None + # Convert the pose tracks and confidence scores to xarray.DataArray - tracks_da = xr.DataArray(dict_["pose_tracks"], dims=cls.dim_names) - scores_da = xr.DataArray( - dict_["confidence_scores"], dims=cls.dim_names[:-1] - ) + tracks_da = xr.DataArray(pose_tracks, dims=self.dim_names) + scores_da = xr.DataArray(confidence_scores, dims=self.dim_names[:-1]) # Combine the DataArrays into a Dataset, with common coordinates - ds = cls( + super().__init__( data_vars={ "pose_tracks": tracks_da, "confidence_scores": scores_da, }, coords={ - cls.dim_names[0]: np.arange( - dict_["pose_tracks"].shape[0], dtype=int - ), - cls.dim_names[1]: dict_["individual_names"], - cls.dim_names[2]: dict_["keypoint_names"], - cls.dim_names[3]: ["x", "y", "z"][ - : dict_["pose_tracks"].shape[-1] - ], + self.dim_names[0]: np.arange(n_frames, dtype=int), + self.dim_names[1]: individual_names, + self.dim_names[2]: keypoint_names, + self.dim_names[3]: ["x", "y", "z"][:n_space], }, - attrs={"fps": dict_["fps"]}, + attrs={"fps": fps, "source_software": None, "source_file": None}, ) - # If fps is given, create "time" coords for 1st ("frames") dimension - if dict_["fps"] is not None: - times = pd.TimedeltaIndex( - ds.coords["frames"] / dict_["fps"], unit="s" - ) - ds.coords["time"] = (cls.dim_names[0], times) + if fps is not None: + self._add_time_coord() - return ds + def _add_time_coord(self): + """Add a `time` coordinate to the dataset, based on the `frames` + dimension and the value of the `fps` attribute. + """ + times = pd.TimedeltaIndex(self.coords["frames"] / self.fps, unit="s") + self.coords["time"] = (self.dim_names[0], times) @classmethod def from_dataframe(cls, df: pd.DataFrame, fps: Optional[float] = None): - """Create a `PoseTracks` dataset from a pandas DataFrame. + """Create a `PoseTracks` dataset from a DLC_style pandas DataFrame. 
Parameters ---------- @@ -139,12 +159,31 @@ def from_dataframe(cls, df: pd.DataFrame, fps: Optional[float] = None): >>> poses = PoseTracks.from_dataframe(df, fps=30) """ - # Convert the DataFrame to a dictionary - dict_ = cls.dataframe_to_dict(df) + # read names of individuals and keypoints from the DataFrame + if "individuals" in df.columns.names: + individual_names = ( + df.columns.get_level_values("individuals").unique().to_list() + ) + else: + individual_names = ["individual_0"] - # Initialize a PoseTracks dataset from the dictionary - ds = cls.from_dict({**dict_, "fps": fps}) - return ds + keypoint_names = ( + df.columns.get_level_values("bodyparts").unique().to_list() + ) + + # reshape the data into (n_frames, n_individuals, n_keypoints, 3) + # where the last axis contains "x", "y", "likelihood" + tracks_with_scores = df.to_numpy().reshape( + (-1, len(individual_names), len(keypoint_names), 3) + ) + + return cls( + pose_tracks=tracks_with_scores[:, :, :, :-1], + confidence_scores=tracks_with_scores[:, :, :, -1], + individual_names=individual_names, + keypoint_names=keypoint_names, + fps=fps, + ) @classmethod def from_sleap( @@ -194,9 +233,9 @@ def from_sleap( # Load data into a dictionary if file_path.suffix == ".h5": - dict_ = cls._load_dict_from_sleap_analysis_file(file_path) + data_dict = cls._load_dict_from_sleap_analysis_file(file_path) elif file_path.suffix == ".slp": - dict_ = cls._load_dict_from_sleap_labels_file(file_path) + data_dict = cls._load_dict_from_sleap_labels_file(file_path) else: error_msg = ( f"Expected file suffix to be '.h5' or '.slp', " @@ -212,11 +251,14 @@ def from_sleap( ) # Initialize a PoseTracks dataset from the dictionary - ds = cls.from_dict({**dict_, "fps": fps}) + ds = cls(**data_dict, fps=fps) # Add metadata as attrs ds.attrs["source_software"] = "SLEAP" ds.attrs["source_file"] = file_path.as_posix() + + logger.info(f"Loaded pose tracks from {ds.source_file}:") + logger.info(ds) return ds @classmethod @@ -260,14 +302,17 @@ def from_dlc( ) logger.error(error_msg) raise OSError from e - logger.info(f"Loaded poses from {file_path} into a DataFrame.") + logger.debug(f"Loaded poses from {file_path} into a DataFrame.") # Convert the DataFrame to a PoseTracks dataset ds = cls.from_dataframe(df=df, fps=fps) # Add metadata as attrs ds.attrs["source_software"] = "DeepLabCut" - ds.attrs["source_file"] = dlc_poses_file.filepath.as_posix() + ds.attrs["source_file"] = dlc_poses_file.file_path.as_posix() + + logger.info(f"Loaded pose tracks from {ds.source_file}:") + logger.info(ds) return ds @staticmethod @@ -349,46 +394,3 @@ def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: ) df.columns.rename(level_names, inplace=True) return df - - @staticmethod - def dataframe_to_dict(df: pd.DataFrame) -> dict: - """Convert a DeepLabCut-style DataFrame containing pose tracks and - likelihood scores into a dictionary. - - Parameters - ---------- - df : pandas DataFrame - DataFrame formatted as in DeepLabCut output files. - - Returns - ------- - dict - Dictionary containing the pose tracks, confidence scores, and - metadata. 
- """ - - # read names of individuals and keypoints from the DataFrame - # retain the order of their appearance in the DataFrame - if "individuals" in df.columns.names: - ind_names = ( - df.columns.get_level_values("individuals").unique().to_list() - ) - else: - ind_names = ["individual_0"] - - kp_names = df.columns.get_level_values("bodyparts").unique().to_list() - print(ind_names) - print(kp_names) - - # reshape the data into (n_frames, n_individuals, n_keypoints, 3) - # where the last axis contains "x", "y", "likelihood" - tracks_with_scores = df.to_numpy().reshape( - (-1, len(ind_names), len(kp_names), 3) - ) - - return { - "pose_tracks": tracks_with_scores[:, :, :, :-1], - "confidence_scores": tracks_with_scores[:, :, :, -1], - "individual_names": ind_names, - "keypoint_names": kp_names, - } From d53398995d68d0dc44e06b4a24a3bbea90ce4722 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 20 Jul 2023 11:30:20 +0100 Subject: [PATCH 13/79] shortened some docstrings --- movement/io/pose_tracks.py | 61 ++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index bc4886d6..ac3e3785 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -17,20 +17,17 @@ class PoseTracks(xr.Dataset): """Dataset containing pose tracks and point-wise confidence scores. - This is a `xarray.Dataset` object, with the following dimensions: + This is an `xarray.Dataset` object, with the following dimensions: - `frames`: the number of frames in the video - `individuals`: the number of individuals in the video - `keypoints`: the number of keypoints in the skeleton - `space`: the number of spatial dimensions, either 2 or 3 - Each dimension is assigned appropriate coordinates: - - frame indices (int) for `frames` - - list of unique names (str) for `individuals` and `keypoints` - - `x`, `y` (and `z`) for `space` - - If `fps` is supplied, the `frames` dimension is also assigned a `time` - coordinate. If `fps` is None, the temporal dimension can only be - accessed through frame indices. + Appropriate coordinate labels are assigned to each dimension: + frame indices (int) for `frames`. list of unique names (str) for + `individuals` and `keypoints`, ['x','y',('z')] for `space`. If `fps` + is supplied, the `frames` dimension is also assigned a `time` + coordinate. The dataset contains two data variables (`xarray.DataArray` objects): - `pose_tracks`: with shape (`frames`, `individuals`, `keypoints`, `space`) @@ -146,17 +143,11 @@ def from_dataframe(cls, df: pd.DataFrame, fps: Optional[float] = None): Notes ----- The DataFrame must have a multi-index column with the following levels: - "scorer", ("individuals"), "bodyparts", "coords". - The "individuals level may be omitted if there is only one individual - in the video. The "coords" level contains the spatial coordinates "x", - "y", as well as "likelihood" (point-wise confidence scores). The row - index corresponds to the frame number. - - Examples - -------- - >>> from movement.io import PoseTracks - >>> df = pd.read_csv("path/to/poses.csv") - >>> poses = PoseTracks.from_dataframe(df, fps=30) + "scorer", ("individuals"), "bodyparts", "coords". The "individuals" + level may be omitted if there is only one individual in the video. + The "coords" level contains the spatial coordinates "x", "y", + as well as "likelihood" (point-wise confidence scores). + The row index corresponds to the frame number. 
""" # read names of individuals and keypoints from the DataFrame @@ -202,20 +193,18 @@ def from_sleap( Notes ----- - The SLEAP inference procedure normally produces a file suffixed with - ".slp" containing the predictions, e.g. "myproject.predictions.slp". - This can be converted to an ".h5" (analysis) file using the command - line tool `sleap-convert` with the "--format analysis" option enabled, - or alternatively by choosing "Export Analysis HDF5…" from the "File" - menu of the SLEAP GUI [1]_. + The SLEAP predictions are normally saved in a ".slp" file, e.g. + "v1.predictions.slp". If this file contains both user-labeled and + predicted instances, only the predicted iones will be loaded. - If the ".slp" file contains both user-labeled and predicted instances, - this function will only load the ones predicted by the SLEAP model + An analysis file, suffixed with ".h5" can be exported from the ".slp" + file, using either the command line tool `sleap-convert` (with the + "--format analysis" option enabled) or the SLEAP GUI (Choose + "Export Analysis HDF5…" from the "File" menu) [1]_. - `movement` expects the tracks to be proofread before loading them. - There should be as many tracks as there are instances (animals) in the - video, without identity switches. Follow the SLEAP guide for - tracking and proofreading [2]_. + `movement` expects the tracks to be proofread before loading them, + meaning each track is interpreted as a single individual/animal. + Follow the SLEAP guide for tracking and proofreading [2]_. References ---------- @@ -270,8 +259,8 @@ def from_dlc( Parameters ---------- file_path : pathlib Path or str - Path to the file containing the DLC poses, either in ".slp" - or ".h5" (analysis) format. + Path to the file containing the DLC poses, either in ".h5" + or ".csv" format. fps : float, optional The number of frames per second in the video. If None (default), the `time` coordinate will not be created. @@ -359,8 +348,8 @@ def _load_dict_from_sleap_labels_file(file_path: Path): def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: """If poses are loaded from a DeepLabCut.csv file, the DataFrame lacks the multi-index columns that are present in the .h5 file. This - function parses the csv file to a DataFrame with multi-index columns, - i.e. the same format as in the .h5 file. + function parses the csv file to a pandas DataFrame with multi-index + columns, i.e. the same format as in the .h5 file. Parameters ---------- From c7075f1f02bbf053465f271951a544e6727fcc34 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 20 Jul 2023 12:00:42 +0100 Subject: [PATCH 14/79] deleted superceded load_poses module --- movement/io/load_poses.py | 241 -------------------------------------- 1 file changed, 241 deletions(-) delete mode 100644 movement/io/load_poses.py diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py deleted file mode 100644 index bd6df1e5..00000000 --- a/movement/io/load_poses.py +++ /dev/null @@ -1,241 +0,0 @@ -import logging -from pathlib import Path -from typing import Optional, Union - -import h5py -import numpy as np -import pandas as pd -from sleap_io.io.slp import read_labels - -from movement.io.validators import DeepLabCutPosesFile - -# get logger -logger = logging.getLogger(__name__) - - -def from_dlc(file_path: Union[Path, str]) -> Optional[pd.DataFrame]: - """Load pose estimation results from a DeepLabCut (DLC) files. - Files must be in .h5 format or .csv format. 
- - Parameters - ---------- - file_path : pathlib Path or str - Path to the file containing the DLC poses. - - Returns - ------- - pandas DataFrame - DataFrame containing the DLC poses - - Examples - -------- - >>> from movement.io import load_poses - >>> poses = load_poses.from_dlc("path/to/file.h5") - """ - - # Validate the input file path - dlc_poses_file = DeepLabCutPosesFile(file_path=file_path) # type: ignore - file_suffix = dlc_poses_file.file_path.suffix - - # Load the DLC poses - try: - if file_suffix == ".csv": - df = _parse_dlc_csv_to_dataframe(dlc_poses_file.file_path) - else: # file can only be .h5 at this point - df = pd.read_hdf(dlc_poses_file.file_path) - # above line does not necessarily return a DataFrame - df = pd.DataFrame(df) - except (OSError, TypeError, ValueError) as e: - error_msg = ( - f"Could not load poses from {file_path}. " - "Please check that the file is valid and readable." - ) - logger.error(error_msg) - raise OSError from e - logger.info(f"Loaded poses from {file_path}") - return df - - -def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: - """If poses are loaded from a DeepLabCut.csv file, the resulting DataFrame - lacks the multi-index columns that are present in the .h5 file. This - function parses the csv file to a DataFrame with multi-index columns. - - Parameters - ---------- - file_path : pathlib Path - Path to the file containing the DLC poses, in .csv format. - - Returns - ------- - pandas DataFrame - DataFrame containing the DLC poses, with multi-index columns. - """ - - possible_level_names = ["scorer", "individuals", "bodyparts", "coords"] - with open(file_path, "r") as f: - # if line starts with a possible level name, split it into a list - # of strings, and add it to the list of header lines - header_lines = [ - line.strip().split(",") - for line in f.readlines() - if line.split(",")[0] in possible_level_names - ] - - # Form multi-index column names from the header lines - level_names = [line[0] for line in header_lines] - column_tuples = list(zip(*[line[1:] for line in header_lines])) - columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names) - - # Import the DLC poses as a DataFrame - df = pd.read_csv( - file_path, skiprows=len(header_lines), index_col=0, names=columns - ) - df.columns.rename(level_names, inplace=True) - return df - - -def from_sleap(file_path: Union[Path, str]) -> dict: - """Load pose tracking data from a SLEAP labels file. - - Parameters - ---------- - file_path : pathlib Path or str - Path to the file containing the SLEAP predictions, either in ".slp" - or ".h5" (analysis) format. See Notes for more information. - - Returns - ------- - dict - Dictionary containing `pose_tracks`, `node_names` and `track_names`. - - `pose_tracks` is an array containing the predicted poses. - Shape: (n_frames, n_tracks, n_nodes, n_dims). The last axis - contains the spatial coordinates "x" and "y", as well as the - point-wise confidence values. - - `node_names` is a list of the node names. - - `track_names` is a list of the track names. - - Notes - ----- - The SLEAP inference procedure normally produces a file suffixed with ".slp" - containing the predictions, e.g. "myproject.predictions.slp". - This can be converted to an ".h5" (analysis) file using the command line - tool `sleap-convert` with the "--format analysis" option enabled, - or alternatively by choosing “Export Analysis HDF5…” from the “File” menu - of the SLEAP GUI [1]_. 
- - This function will only the predicted instances in the ".slp" file, - not the user-labeled ones. - - movement expects the tracks to be proofread before loading them. - There should be as many tracks as there are instances (animals) in the - video, without identity switches. Follow the SLEAP guide for - tracking and proofreading [2]_. - - References - ---------- - .. [1] https://sleap.ai/tutorials/analysis.html - .. [2] https://sleap.ai/guides/proofreading.html - - Examples - -------- - >>> from movement.io import load_poses - >>> poses = load_poses.from_sleap("path/to/labels.predictions.slp") - """ - - if not isinstance(file_path, Path): - file_path = Path(file_path) - - if file_path.suffix == ".h5": - # Load the SLEAP predictions from an analysis file - poses = _load_sleap_analysis_file(file_path) - elif file_path.suffix == ".slp": - # Load the SLEAP predictions from a labels file - poses = _load_sleap_labels_file(file_path) - else: - error_msg = ( - f"Expected file suffix to be '.h5' or '.slp', " - f"but got '{file_path.suffix}'. Make sure the file is " - "a SLEAP labels file with suffix '.slp' or SLEAP analysis " - "file with suffix '.h5'." - ) - logger.error(error_msg) - raise ValueError(error_msg) - - n_frames, n_tracks, n_nodes, n_dims = poses["tracks"].shape - logger.info(f"Loaded poses from {file_path}.") - logger.debug( - f"Shape: ({n_frames} frames, {n_tracks} tracks, " - f"{n_nodes} nodes, {n_dims - 1} spatial coords " - "+ 1 confidence score)" - ) - logger.info(f"Track names: {poses['track_names']}") - logger.info(f"Node names: {poses['node_names']}") - return poses - - -def _load_sleap_analysis_file(file_path: Path) -> dict: - """Load pose tracking data from a SLEAP analysis file. - - Parameters - ---------- - file_path : pathlib Path - Path to the file containing the SLEAP predictions, in ".h5" format. - - Returns - ------- - dict - Dictionary containing `pose_tracks`, `node_names` and `track_names`. - """ - - # Load the SLEAP poses - with h5py.File(file_path, "r") as f: - # First, load and reshape the pose tracks - tracks = f["tracks"][:].T - n_frames, n_nodes, n_dims, n_tracks = tracks.shape - tracks = tracks.reshape((n_frames, n_tracks, n_nodes, n_dims)) - - # If present, read the point-wise confidence scores - # and add them to the "tracks" array - confidence = np.full( - (n_frames, n_tracks, n_nodes, 3), np.nan, dtype="float32" - ) - if "point_scores" in f.keys(): - confidence = f["point_scores"][:].T - confidence = confidence.reshape((n_frames, n_tracks, n_nodes)) - tracks = np.concatenate( - [tracks, confidence[:, :, :, np.newaxis]], axis=3 - ) - - # Create the dictionary to be returned - poses = { - "tracks": tracks, - "node_names": [n.decode() for n in f["node_names"][:]], - "track_names": [n.decode() for n in f["track_names"][:]], - } - return poses - - -def _load_sleap_labels_file(file_path: Path) -> dict: - """Load pose tracking data from a SLEAP labels file. - - Parameters - ---------- - file_path : pathlib Path - Path to the file containing the SLEAP predictions, in ".slp" format. - - Returns - ------- - dict - Dictionary containing `pose_tracks`, `node_names` and `track_names`. 
- """ - labels = read_labels(file_path.as_posix()) - poses = { - "tracks": labels.numpy(return_confidence=True), - "node_names": [node.name for node in labels.skeletons[0].nodes], - "track_names": [track.name for track in labels.tracks], - } - # return_confidence=True adds the point-wise confidence scores - # as an extra coord dimension to the "tracks" array - - return poses From 846527c52492782d2a9000901e95bb0279701623 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 20 Jul 2023 12:37:55 +0100 Subject: [PATCH 15/79] renamed numpy arrays for pose tracks and and scores to avoid clashiing with xarray.DataArray names --- movement/io/pose_tracks.py | 41 ++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index ac3e3785..73770b67 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -50,23 +50,26 @@ class PoseTracks(xr.Dataset): def __init__( self, - pose_tracks: np.ndarray, - confidence_scores: Optional[np.ndarray] = None, + tracks_array: np.ndarray, + scores_array: Optional[np.ndarray] = None, individual_names: Optional[list[str]] = None, keypoint_names: Optional[list[str]] = None, fps: Optional[float] = None, ): - """Create a `PoseTracks` dataset. + """Create a `PoseTracks` dataset containing pose tracks and + point-wise confidence scores. Parameters ---------- - pose_tracks : np.ndarray + tracks_array : np.ndarray Array of shape (n_frames, n_individuals, n_keypoints, n_space) - containing the pose tracks. - confidence_scores : np.ndarray, optional + containing the pose tracks. It will be converted to a + `xarray.DataArray` object named "pose_tracks". + scores_array : np.ndarray, optional Array of shape (n_frames, n_individuals, n_keypoints) containing - the point-wise confidence scores. If None (default), the - confidence scores will be set to an array of NaNs. + the point-wise confidence scores. It will be converted to a + `xarray.DataArray` object named "confidence_scores". + If None (default), the scores will be set to an array of NaNs. individual_names : list of str, optional List of unique names for the individuals in the video. If None (default), the individuals will be named "individual_0", @@ -80,9 +83,9 @@ def __init__( the `time` coordinate will not be created. 
""" - n_frames, n_individuals, n_keypoints, n_space = pose_tracks.shape - if confidence_scores is None: - confidence_scores = np.full( + n_frames, n_individuals, n_keypoints, n_space = tracks_array.shape + if scores_array is None: + scores_array = np.full( (n_frames, n_individuals, n_keypoints), np.nan, dtype="float32" ) if individual_names is None: @@ -99,8 +102,8 @@ def __init__( fps = None # Convert the pose tracks and confidence scores to xarray.DataArray - tracks_da = xr.DataArray(pose_tracks, dims=self.dim_names) - scores_da = xr.DataArray(confidence_scores, dims=self.dim_names[:-1]) + tracks_da = xr.DataArray(tracks_array, dims=self.dim_names) + scores_da = xr.DataArray(scores_array, dims=self.dim_names[:-1]) # Combine the DataArrays into a Dataset, with common coordinates super().__init__( @@ -169,8 +172,8 @@ def from_dataframe(cls, df: pd.DataFrame, fps: Optional[float] = None): ) return cls( - pose_tracks=tracks_with_scores[:, :, :, :-1], - confidence_scores=tracks_with_scores[:, :, :, -1], + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], individual_names=individual_names, keypoint_names=keypoint_names, fps=fps, @@ -324,8 +327,8 @@ def _load_dict_from_sleap_analysis_file(file_path: Path): ) return { - "pose_tracks": tracks, - "confidence_scores": scores, + "tracks_array": tracks, + "scores_array": scores, "individual_names": [n.decode() for n in f["track_names"][:]], "keypoint_names": [n.decode() for n in f["node_names"][:]], } @@ -339,8 +342,8 @@ def _load_dict_from_sleap_labels_file(file_path: Path): tracks_with_scores = labels.numpy(return_confidence=True) return { - "pose_tracks": tracks_with_scores[:, :, :, :-1], - "confidence_scores": tracks_with_scores[:, :, :, -1], + "tracks_array": tracks_with_scores[:, :, :, :-1], + "scores_array": tracks_with_scores[:, :, :, -1], "individual_names": [track.name for track in labels.tracks], "keypoint_names": [kp.name for kp in labels.skeletons[0].nodes], } From 04ef7372059d7d6e0cc619bc89db921b2df798c9 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 20 Jul 2023 15:07:13 +0100 Subject: [PATCH 16/79] implemented existing converter function as a PoseTracks.to_dlc_df() method --- movement/io/convert.py | 68 -------------------------------------- movement/io/pose_tracks.py | 66 +++++++++++++++++++++++++++++------- 2 files changed, 54 insertions(+), 80 deletions(-) delete mode 100644 movement/io/convert.py diff --git a/movement/io/convert.py b/movement/io/convert.py deleted file mode 100644 index c05a2dd7..00000000 --- a/movement/io/convert.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Functions to convert between different formats, -e.g. from DeepLabCut to SLEAP and vice versa. -""" -import logging - -import pandas as pd - -# get logger -logger = logging.getLogger(__name__) - - -def sleap_poses_to_dlc_df(pose_tracks: dict) -> pd.DataFrame: - """Convert pose tracking data from SLEAP labels to a DeepLabCut-style - DataFrame with multi-index columns. See Notes for details. - - Parameters - ---------- - pose_tracks : dict - Dictionary containing `pose_tracks`, `node_names` and `track_names`. - This dictionary is returned by `io.load_poses.from_sleap`. - - Returns - ------- - pandas DataFrame - DataFrame containing pose tracks in DLC style, with the multi-index - columns ("scorer", "individuals", "bodyparts", "coords"). 
- - Notes - ----- - Correspondence between SLEAP and DLC terminology: - - DLC "scorer" has no equivalent in SLEAP, so we assign it to "SLEAP" - - DLC "individuals" are the names of SLEAP "tracks" - - DLC "bodyparts" are the names of SLEAP "nodes" (i.e. the keypoints) - - DLC "coords" are referred to in SLEAP as "dims" - (i.e. "x" coord + "y" coord + "confidence/likelihood") - - DLC reports "likelihood" while SLEAP reports "confidence". - These both measure the point-wise prediction confidence but do not - have the same range and cannot be compared between the two frameworks. - """ - - # Get the number of frames, tracks, nodes and dimensions - n_frames, n_tracks, n_nodes, n_dims = pose_tracks["tracks"].shape - # Use the DLC terminology: scorer, individuals, bodyparts, coords - # The assigned scorer is always "DeepLabCut" - scorer = ["SLEAP"] - individuals = pose_tracks["track_names"] - bodyparts = pose_tracks["node_names"] - coords = ["x", "y", "likelihood"] - - # Create the DLC-style multi-index dataframe - index_levels = ["scorer", "individuals", "bodyparts", "coords"] - columns = pd.MultiIndex.from_product( - [scorer, individuals, bodyparts, coords], names=index_levels - ) - df = pd.DataFrame( - data=pose_tracks["tracks"].reshape(n_frames, -1), - index=pd.RangeIndex(0, n_frames), - columns=columns, - dtype=float, - ) - - # Log the conversion - logger.info( - f"Converted SLEAP pose tracks to DLC-style DataFrame " - f"with shape {df.shape}" - ) - return df diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index 73770b67..720ed53f 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -121,17 +121,11 @@ def __init__( ) if fps is not None: - self._add_time_coord() - - def _add_time_coord(self): - """Add a `time` coordinate to the dataset, based on the `frames` - dimension and the value of the `fps` attribute. - """ - times = pd.TimedeltaIndex(self.coords["frames"] / self.fps, unit="s") - self.coords["time"] = (self.dim_names[0], times) + times = pd.TimedeltaIndex(self.coords["frames"] / fps, unit="s") + self.coords["time"] = (self.dim_names[0], times) @classmethod - def from_dataframe(cls, df: pd.DataFrame, fps: Optional[float] = None): + def from_dlc_df(cls, df: pd.DataFrame, fps: Optional[float] = None): """Create a `PoseTracks` dataset from a DLC_style pandas DataFrame. Parameters @@ -249,7 +243,7 @@ def from_sleap( ds.attrs["source_software"] = "SLEAP" ds.attrs["source_file"] = file_path.as_posix() - logger.info(f"Loaded pose tracks from {ds.source_file}:") + logger.info(f"Loaded pose tracks from {file_path}:") logger.info(ds) return ds @@ -297,16 +291,64 @@ def from_dlc( logger.debug(f"Loaded poses from {file_path} into a DataFrame.") # Convert the DataFrame to a PoseTracks dataset - ds = cls.from_dataframe(df=df, fps=fps) + ds = cls.from_dlc_df(df=df, fps=fps) # Add metadata as attrs ds.attrs["source_software"] = "DeepLabCut" ds.attrs["source_file"] = dlc_poses_file.file_path.as_posix() - logger.info(f"Loaded pose tracks from {ds.source_file}:") + logger.info(f"Loaded pose tracks from {dlc_poses_file.file_path}:") logger.info(ds) return ds + def to_dlc_df(self) -> pd.DataFrame: + """Convert the PoseTracks dataset to a DeepLabCut-style pandas + DataFrame with multi-index columns. + See the Notes section of the `from_dlc_df()` method for details. 
+
+        Returns
+        -------
+        pandas DataFrame
+
+        Notes
+        -----
+        The DataFrame will have multi-index columns with the following levels:
+        "scorer", "individuals", "bodyparts", "coords" (even if there is only
+        one individual present). Regardless of the provenance of the
+        point-wise confidence scores, they will be referred to as
+        "likelihood", and stored in the "coords" level (as DeepLabCut expects).
+        """
+
+        # Concatenate the pose tracks and confidence scores into one array
+        tracks_with_scores = np.concatenate(
+            (
+                self.pose_tracks.data,
+                self.confidence_scores.data[..., np.newaxis],
+            ),
+            axis=-1,
+        )
+
+        # Create the DLC-style multi-index columns
+        # Use the DLC terminology: scorer, individuals, bodyparts, coords
+        scorer = ["movement"]
+        individuals = self.coords["individuals"].data.tolist()
+        bodyparts = self.coords["keypoints"].data.tolist()
+        # The confidence scores in DLC are referred to as "likelihood"
+        coords = self.coords["space"].data.tolist() + ["likelihood"]
+
+        index_levels = ["scorer", "individuals", "bodyparts", "coords"]
+        columns = pd.MultiIndex.from_product(
+            [scorer, individuals, bodyparts, coords], names=index_levels
+        )
+        df = pd.DataFrame(
+            data=tracks_with_scores.reshape(self.dims["frames"], -1),
+            index=self.coords["frames"].data,
+            columns=columns,
+            dtype=float,
+        )
+        logger.info("Converted PoseTracks dataset to DLC-style DataFrame.")
+        return df
+
     @staticmethod
     def _load_dict_from_sleap_analysis_file(file_path: Path):
         """Load pose tracks and confidence scores from a SLEAP analysis

From 7ddab36df74aaf74afc01d43a40bf5ca50e28312 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Thu, 20 Jul 2023 15:10:44 +0100
Subject: [PATCH 17/79] renamed from_dlc and from_sleap methods to
 from_dlc_file and from_sleap_file

---
 movement/io/pose_tracks.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py
index 720ed53f..d65231a3 100644
--- a/movement/io/pose_tracks.py
+++ b/movement/io/pose_tracks.py
@@ -174,7 +174,7 @@ def from_dlc_df(cls, df: pd.DataFrame, fps: Optional[float] = None):
         )
 
     @classmethod
-    def from_sleap(
+    def from_sleap_file(
         cls, file_path: Union[Path, str], fps: Optional[float] = None
     ):
         """Load pose tracking data from a SLEAP labels or analysis file.
@@ -211,7 +211,7 @@ def from_sleap(
         Examples
         --------
         >>> from movement.io import PoseTracks
-        >>> poses = PoseTracks.from_sleap("path/to/v1.predictions.slp", fps=30)
+        >>> poses = PoseTracks.from_sleap_file("path/to/file.slp", fps=30)
         """
 
         if not isinstance(file_path, Path):
@@ -248,7 +248,7 @@ def from_sleap(
         return ds
 
     @classmethod
-    def from_dlc(
+    def from_dlc_file(
         cls, file_path: Union[Path, str], fps: Optional[float] = None
     ):
         """Load pose tracking data from a DeepLabCut (DLC) output file.
@@ -266,7 +266,7 @@ def from_dlc( Examples -------- >>> from movement.io import PoseTracks - >>> poses = PoseTracks.from_dlc("path/to/video_model.h5", fps=30) + >>> poses = PoseTracks.from_dlc_file("path/to/file.h5", fps=30) """ # Validate the input file path From 8ea6e70e693725824f5cdb83253068085a8744b2 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 20 Jul 2023 18:20:48 +0100 Subject: [PATCH 18/79] change "frames" dim to "time" --- movement/io/pose_tracks.py | 68 +++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index d65231a3..b3e73c0d 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -8,8 +8,6 @@ import xarray as xr from sleap_io.io.slp import read_labels -from movement.io.validators import DeepLabCutPosesFile - # get logger logger = logging.getLogger(__name__) @@ -18,35 +16,35 @@ class PoseTracks(xr.Dataset): """Dataset containing pose tracks and point-wise confidence scores. This is an `xarray.Dataset` object, with the following dimensions: - - `frames`: the number of frames in the video + - `time`: the number of frames in the video - `individuals`: the number of individuals in the video - `keypoints`: the number of keypoints in the skeleton - `space`: the number of spatial dimensions, either 2 or 3 Appropriate coordinate labels are assigned to each dimension: - frame indices (int) for `frames`. list of unique names (str) for - `individuals` and `keypoints`, ['x','y',('z')] for `space`. If `fps` - is supplied, the `frames` dimension is also assigned a `time` - coordinate. + list of unique names (str) for `individuals` and `keypoints`, + ['x','y',('z')] for `space`. The coordinates of the `time` dimension are + in seconds if `fps` is provided, otherwise they are in frame numbers. The dataset contains two data variables (`xarray.DataArray` objects): - - `pose_tracks`: with shape (`frames`, `individuals`, `keypoints`, `space`) - - `confidence_scores`: with shape (`frames`, `individuals`, `keypoints`) + - `pose_tracks`: with shape (`time`, `individuals`, `keypoints`, `space`) + - `confidence_scores`: with shape (`time`, `individuals`, `keypoints`) The dataset may also contain following attributes as metadata: - `fps`: the number of frames per second in the video + - `time_unit`: the unit of the `time` coordinates, frames or seconds - `source_software`: the software from which the pose tracks were loaded - `source_file`: the file from which the pose tracks were loaded """ dim_names: ClassVar[tuple] = ( - "frames", + "time", "individuals", "keypoints", "space", ) - __slots__ = ("fps", "source_software", "source_file") + __slots__ = ("fps", "time_unit", "source_software", "source_file") def __init__( self, @@ -80,7 +78,7 @@ def __init__( etc. fps : float, optional The number of frames per second in the video. If None (default), - the `time` coordinate will not be created. + the `time` coordinates will be in frame numbers. 
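+
+        Examples
+        --------
+        A minimal construction sketch (the array values are illustrative):
+
+        >>> import numpy as np
+        >>> from movement.io import PoseTracks
+        >>> ds = PoseTracks(np.random.rand(100, 2, 3, 2), fps=30)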
""" n_frames, n_individuals, n_keypoints, n_space = tracks_array.shape @@ -105,6 +103,13 @@ def __init__( tracks_da = xr.DataArray(tracks_array, dims=self.dim_names) scores_da = xr.DataArray(scores_array, dims=self.dim_names[:-1]) + # Create the time coordinate, depending on the value of fps + time_coords = np.arange(n_frames, dtype=int) + time_unit = "frames" + if fps is not None: + time_coords = time_coords / fps + time_unit = "seconds" + # Combine the DataArrays into a Dataset, with common coordinates super().__init__( data_vars={ @@ -112,18 +117,19 @@ def __init__( "confidence_scores": scores_da, }, coords={ - self.dim_names[0]: np.arange(n_frames, dtype=int), + self.dim_names[0]: time_coords, self.dim_names[1]: individual_names, self.dim_names[2]: keypoint_names, self.dim_names[3]: ["x", "y", "z"][:n_space], }, - attrs={"fps": fps, "source_software": None, "source_file": None}, + attrs={ + "fps": fps, + "time_unit": time_unit, + "source_software": None, + "source_file": None, + }, ) - if fps is not None: - times = pd.TimedeltaIndex(self.coords["frames"] / fps, unit="s") - self.coords["time"] = (self.dim_names[0], times) - @classmethod def from_dlc_df(cls, df: pd.DataFrame, fps: Optional[float] = None): """Create a `PoseTracks` dataset from a DLC_style pandas DataFrame. @@ -135,7 +141,7 @@ def from_dlc_df(cls, df: pd.DataFrame, fps: Optional[float] = None): be formatted as in DeepLabCut output files (see Notes). fps : float, optional The number of frames per second in the video. If None (default), - the `time` coordinate will not be created. + the `time` coordinates will be in frame numbers. Notes ----- @@ -184,9 +190,8 @@ def from_sleap_file( file_path : pathlib Path or str Path to the file containing the SLEAP predictions, either in ".slp" or ".h5" (analysis) format. See Notes for more information. - fps : float, optional - The number of frames per second in the video. If None (default), - the `time` coordinate will not be created. + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. Notes ----- @@ -260,7 +265,7 @@ def from_dlc_file( or ".csv" format. fps : float, optional The number of frames per second in the video. If None (default), - the `time` coordinate will not be created. + the `time` coordinates will be in frame numbers. 
Examples @@ -270,15 +275,16 @@ def from_dlc_file( """ # Validate the input file path - dlc_poses_file = DeepLabCutPosesFile(file_path=file_path) - file_suffix = dlc_poses_file.file_path.suffix + if not isinstance(file_path, Path): + file_path = Path(file_path) + file_suffix = file_path.suffix # Load the DLC poses into a DataFrame try: if file_suffix == ".csv": - df = cls._parse_dlc_csv_to_dataframe(dlc_poses_file.file_path) + df = cls._parse_dlc_csv_to_dataframe(file_path) else: # file can only be .h5 at this point - df = pd.read_hdf(dlc_poses_file.file_path) + df = pd.read_hdf(file_path) # above line does not necessarily return a DataFrame df = pd.DataFrame(df) except (OSError, TypeError, ValueError) as e: @@ -295,9 +301,9 @@ def from_dlc_file( # Add metadata as attrs ds.attrs["source_software"] = "DeepLabCut" - ds.attrs["source_file"] = dlc_poses_file.file_path.as_posix() + ds.attrs["source_file"] = file_path.as_posix() - logger.info(f"Loaded pose tracks from {dlc_poses_file.file_path}:") + logger.info(f"Loaded pose tracks from {file_path}:") logger.info(ds) return ds @@ -341,8 +347,8 @@ def to_dlc_df(self) -> pd.DataFrame: [scorer, individuals, bodyparts, coords], names=index_levels ) df = pd.DataFrame( - data=tracks_with_scores.reshape(self.dims["frames"], -1), - index=self.coords["frames"].data, + data=tracks_with_scores.reshape(self.dims["time"], -1), + index=np.arange(self.dims["time"], dtype=int), columns=columns, dtype=float, ) From 191eed861aa17f5fb18ea0c3f14a4e600ff66125 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 20 Jul 2023 18:30:46 +0100 Subject: [PATCH 19/79] started adapting unit tests for PoseTracks object --- tests/test_unit/test_io.py | 172 +++++++++++++++++++++++++------------ 1 file changed, 116 insertions(+), 56 deletions(-) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index fec91eca..90e7970a 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -1,38 +1,89 @@ import os import h5py +import numpy as np import pandas as pd import pytest -from pandas.testing import assert_frame_equal -from pydantic import ValidationError -from tables import HDF5ExtError +from xarray.testing import assert_allclose from movement.datasets import fetch_pose_data_path -from movement.io import load_poses +from movement.io import PoseTracks -class TestLoadPoses: - """Test the load_poses module.""" +class TestPoseTracksIO: + """Test the IO functionalities of the PoseTracks class.""" @pytest.fixture - def valid_dlc_files(self): - """Return the paths to valid DLC poses files, - in .h5 format. - - Returns - ------- - dict - Dictionary containing the paths. - - h5_path: pathlib Path to a valid .h5 file - - h5_str: path as str to a valid .h5 file - """ - h5_file = fetch_pose_data_path("DLC_single-wasp.predictions.h5") - csv_file = fetch_pose_data_path("DLC_single-wasp.predictions.csv") + def dlc_file_h5_single(self): + """Return the path to a valid DLC h5 file containing pose data + for a single animal.""" + return fetch_pose_data_path("DLC_single-wasp.predictions.h5") + + @pytest.fixture + def dlc_file_csv_single(self): + """Return the path to a valid DLC .csv file containing pose data + for a single animal. 
The underlying data is the same as in the + `dlc_file_h5_single` fixture.""" + return fetch_pose_data_path("DLC_single-wasp.predictions.csv") + + @pytest.fixture + def dlc_file_csv_multi(self): + """Return the path to a valid DLC .csv file containing pose data + for multiple animals.""" + return fetch_pose_data_path("DLC_two-mice.predictions.csv") + + @pytest.fixture + def sleap_file_h5_single(self): + """Return the path to a valid SLEAP "analysis" .h5 file containing + pose data for a single animal.""" + return fetch_pose_data_path("SLEAP_single-mouse_EPM.analysis.h5") + + @pytest.fixture + def sleap_file_slp_single(self): + """Return the path to a valid SLEAP .slp file containing + predicted poses (labels) for a single animal.""" + return fetch_pose_data_path("SLEAP_single-mouse_EPM.predictions.slp") + + @pytest.fixture + def sleap_file_h5_multi(self): + """Return the path to a valid SLEAP "analysis" .h5 file containing + pose data for multiple animals.""" + return fetch_pose_data_path( + "SLEAP_three-mice_Aeon_proofread.analysis.h5" + ) + + @pytest.fixture + def sleap_file_slp_multi(self): + """Return the path to a valid SLEAP .slp file containing + predicted poses (labels) for multiple animals.""" + return fetch_pose_data_path( + "SLEAP_three-mice_Aeon_proofread.predictions.slp" + ) + + @pytest.fixture + def valid_dlc_files( + dlc_file_h5_single, dlc_file_csv_single, dlc_file_csv_multi + ): + """Aggregate all valid DLC files in a dictionary, for convenience.""" + return { + "h5_single": dlc_file_h5_single, + "csv_single": dlc_file_csv_single, + "csv_multi": dlc_file_csv_multi, + } + + @pytest.fixture + def valid_sleap_files( + sleap_file_h5_single, + sleap_file_slp_single, + sleap_file_h5_multi, + sleap_file_slp_multi, + ): + """Aggregate all valid SLEAP files in a dictionary, for convenience.""" return { - "h5_path": h5_file, - "h5_str": h5_file.as_posix(), - "csv_path": csv_file, - "csv_str": csv_file.as_posix(), + "h5_single": sleap_file_h5_single, + "slp_single": sleap_file_slp_single, + "h5_multi": sleap_file_h5_multi, + "slp_multi": sleap_file_slp_multi, } @pytest.fixture @@ -59,37 +110,46 @@ def invalid_files(self, tmp_path): "nonexistent": nonexistent_file, } - def test_load_valid_dlc_files(self, valid_dlc_files): - """Test loading valid DLC poses files.""" - for file_type, file_path in valid_dlc_files.items(): - df = load_poses.from_dlc(file_path) - assert isinstance(df, pd.DataFrame) - assert not df.empty - - def test_load_invalid_dlc_files(self, invalid_files): - """Test loading invalid DLC poses files.""" - for file_type, file_path in invalid_files.items(): - if file_type == "nonexistent": - with pytest.raises(FileNotFoundError): - load_poses.from_dlc(file_path) - elif file_type == "wrong_ext": - with pytest.raises(ValueError): - load_poses.from_dlc(file_path) - else: - with pytest.raises((OSError, HDF5ExtError)): - load_poses.from_dlc(file_path) - - @pytest.mark.parametrize("file_path", [1, 1.0, True, None, [], {}]) - def test_load_from_dlc_with_incorrect_file_path_types(self, file_path): - """Test loading poses from a file_path with an incorrect type.""" - with pytest.raises(ValidationError): - load_poses.from_dlc(file_path) - - def test_load_from_dlc_csv_or_h5_file_returns_same_df( - self, valid_dlc_files + @pytest.fixture + def dlc_style_df(self, dlc_file_h5_single): + """Return a valid DLC-style DataFrame.""" + df = pd.read_hdf(dlc_file_h5_single) + return df + + def test_load_from_dlc_file_csv_or_h5_file_returns_same( + self, dlc_file_h5_single, dlc_file_csv_single ): 
-        """Test that loading poses from DLC .csv and .h5 files
-        return the same DataFrame."""
-        df_from_h5 = load_poses.from_dlc(valid_dlc_files["h5_path"])
-        df_from_csv = load_poses.from_dlc(valid_dlc_files["csv_path"])
-        assert_frame_equal(df_from_h5, df_from_csv)
+        """Test that loading pose tracks from DLC .csv and .h5 files
+        returns the same Dataset."""
+        ds_from_h5 = PoseTracks.from_dlc_file(dlc_file_h5_single)
+        ds_from_csv = PoseTracks.from_dlc_file(dlc_file_csv_single)
+        assert_allclose(ds_from_h5, ds_from_csv)
+
+    @pytest.mark.parametrize("fps", [None, -5, 0, 30, 60.0])
+    def test_fps_and_time_coords(self, sleap_file_h5_multi, fps):
+        """Test that time coordinates are set according to the fps."""
+        ds = PoseTracks.from_sleap_file(sleap_file_h5_multi, fps=fps)
+        if (fps is None) or (fps <= 0):
+            assert ds.fps is None
+            assert ds.time_unit == "frames"
+        else:
+            assert ds.fps == fps
+            assert ds.time_unit == "seconds"
+            assert np.allclose(
+                ds.coords["time"].data,
+                np.arange(ds.dims["time"], dtype=int) / ds.attrs["fps"],
+            )
+
+    def test_from_and_to_dlc_df(self, dlc_style_df):
+        """Test that loading pose tracks from a DLC-style DataFrame and
+        converting back to a DataFrame returns the same data values."""
+        ds = PoseTracks.from_dlc_df(dlc_style_df)
+        df = ds.to_dlc_df()
+        assert np.allclose(df.values, dlc_style_df.values)
+
+    def test_load_from_str_path(self, sleap_file_h5_single):
+        """Test that file paths provided as strings are accepted as input."""
+        assert_allclose(
+            PoseTracks.from_sleap_file(sleap_file_h5_single),
+            PoseTracks.from_sleap_file(sleap_file_h5_single.as_posix()),
+        )

From 2248a6c34c1baf59716715eda2cd0fecceb580fc Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Fri, 21 Jul 2023 16:24:56 +0100
Subject: [PATCH 20/79] added tests for PoseTracks initialisation

---
 tests/test_unit/test_io.py | 66 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py
index 90e7970a..3c20bab6 100644
--- a/tests/test_unit/test_io.py
+++ b/tests/test_unit/test_io.py
@@ -153,3 +153,69 @@ def test_load_from_str_path(self, sleap_file_h5_single):
             PoseTracks.from_sleap_file(sleap_file_h5_single),
             PoseTracks.from_sleap_file(sleap_file_h5_single.as_posix()),
         )
+
+    @pytest.mark.parametrize(
+        "scores_array", [None, np.zeros((10, 2, 2)), np.zeros((10, 2, 3))]
+    )
+    def test_init_scores(self, scores_array):
+        """Test that confidence scores are correctly initialized."""
+        tracks = np.random.rand(10, 2, 2, 2)
+
+        if scores_array is None:
+            ds = PoseTracks(tracks, scores_array=scores_array)
+            assert ds.confidence_scores.shape == (10, 2, 2)
+            assert np.all(np.isnan(ds.confidence_scores.data))
+        elif scores_array.shape == (10, 2, 2):
+            ds = PoseTracks(tracks, scores_array=scores_array)
+            assert np.allclose(ds.confidence_scores.data, scores_array)
+        else:
+            with pytest.raises(ValueError):
+                ds = PoseTracks(tracks, scores_array=scores_array)
+
+    @pytest.mark.parametrize(
+        "individual_names",
+        [None, ["animal_1", "animal_2"], ["animal_1", "animal_2", "animal_3"]],
+    )
+    def test_init_individual_names(self, individual_names):
+        """Test that individual names are correctly initialized."""
+        tracks = np.random.rand(10, 2, 2, 2)
+
+        if individual_names is None:
+            ds = PoseTracks(tracks, individual_names=individual_names)
+            assert ds.dims["individuals"] == 2
+            assert all(
+                [
+                    f"individual_{i}" in ds.coords["individuals"]
+                    for i in range(2)
+                ]
+            )
+        elif len(individual_names) == 2:
+            ds = PoseTracks(tracks, 
individual_names=individual_names) + assert ds.dims["individuals"] == 2 + assert all( + [n in ds.coords["individuals"] for n in individual_names] + ) + else: + with pytest.raises(ValueError): + ds = PoseTracks(tracks, individual_names=individual_names) + + @pytest.mark.parametrize( + "keypoint_names", [None, ["kp_1", "kp_2"], ["kp_1", "kp_2", "kp_3"]] + ) + def test_init_keypoint_names(self, keypoint_names): + """Test that keypoint names are correctly initialized.""" + tracks = np.random.rand(10, 2, 2, 2) + + if keypoint_names is None: + ds = PoseTracks(tracks, keypoint_names=keypoint_names) + assert ds.dims["keypoints"] == 2 + assert all( + [f"keypoint_{i}" in ds.coords["keypoints"] for i in range(2)] + ) + elif len(keypoint_names) == 2: + ds = PoseTracks(tracks, keypoint_names=keypoint_names) + assert ds.dims["keypoints"] == 2 + assert all([n in ds.coords["keypoints"] for n in keypoint_names]) + else: + with pytest.raises(ValueError): + ds = PoseTracks(tracks, keypoint_names=keypoint_names) From c8444c861981609e529e5cf732ea02458691ebf3 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 21 Jul 2023 16:41:37 +0100 Subject: [PATCH 21/79] removed attrs dependency for now --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 01ddc1ee..6f125741 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "tqdm", "sleap-io", "xarray", - "attrs", ] classifiers = [ From 64fe008e421135e592d11f86483a5b27035ef09c Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 21 Jul 2023 16:46:05 +0100 Subject: [PATCH 22/79] using pydantic 2.0 or greater --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6f125741..4874af0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "numpy", "pandas", "h5py", - "pydantic", + "pydantic>=2.0", "pooch", "tqdm", "sleap-io", From 49932dbcde52837f9cd0847b7d3d4785698dbf9b Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 21 Jul 2023 16:50:01 +0100 Subject: [PATCH 23/79] make _parse_dlc_csv_to_dataframe a static method --- movement/io/pose_tracks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index b3e73c0d..40046046 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -396,6 +396,7 @@ def _load_dict_from_sleap_labels_file(file_path: Path): "keypoint_names": [kp.name for kp in labels.skeletons[0].nodes], } + @staticmethod def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: """If poses are loaded from a DeepLabCut.csv file, the DataFrame lacks the multi-index columns that are present in the .h5 file. 
This From 2d56a11beff6d6d36ccff1c4f62b453ff2cca670 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 21 Jul 2023 17:55:30 +0100 Subject: [PATCH 24/79] added test for loading a variety of valid pose files --- tests/test_unit/test_io.py | 60 +++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index 3c20bab6..d6dd8812 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import pytest +from xarray import DataArray from xarray.testing import assert_allclose from movement.datasets import fetch_pose_data_path @@ -61,29 +62,25 @@ def sleap_file_slp_multi(self): ) @pytest.fixture - def valid_dlc_files( - dlc_file_h5_single, dlc_file_csv_single, dlc_file_csv_multi - ): - """Aggregate all valid DLC files in a dictionary, for convenience.""" - return { - "h5_single": dlc_file_h5_single, - "csv_single": dlc_file_csv_single, - "csv_multi": dlc_file_csv_multi, - } - - @pytest.fixture - def valid_sleap_files( + def valid_files( + self, + dlc_file_h5_single, + dlc_file_csv_single, + dlc_file_csv_multi, sleap_file_h5_single, sleap_file_slp_single, sleap_file_h5_multi, sleap_file_slp_multi, ): - """Aggregate all valid SLEAP files in a dictionary, for convenience.""" + """Aggregate all valid files in a dictionary, for convenience.""" return { - "h5_single": sleap_file_h5_single, - "slp_single": sleap_file_slp_single, - "h5_multi": sleap_file_h5_multi, - "slp_multi": sleap_file_slp_multi, + "DLC_h5_single": dlc_file_h5_single, + "DLC_csv_single": dlc_file_csv_single, + "DLC_csv_multi": dlc_file_csv_multi, + "SLEAP_h5_single": sleap_file_h5_single, + "SLEAP_slp_single": sleap_file_slp_single, + "SLEAP_h5_multi": sleap_file_h5_multi, + "SLEAP_slp_multi": sleap_file_slp_multi, } @pytest.fixture @@ -116,6 +113,35 @@ def dlc_style_df(self, dlc_file_h5_single): df = pd.read_hdf(dlc_file_h5_single) return df + def test_load_from_valid_files(self, valid_files): + """Test that loading pose tracks from a wide variety of valid files + returns a proper Dataset.""" + abbrev_expand = {"DLC": "DeepLabCut", "SLEAP": "SLEAP"} + + for file_type, file_path in valid_files.items(): + if file_type.startswith("DLC"): + ds = PoseTracks.from_dlc_file(file_path) + elif file_type.startswith("SLEAP"): + ds = PoseTracks.from_sleap_file(file_path) + + assert isinstance(ds, PoseTracks) + # Expected variables are present and of right shape/type + for var in ["pose_tracks", "confidence_scores"]: + assert var in ds.data_vars + assert isinstance(ds[var], DataArray) + assert ds.pose_tracks.ndim == 4 + assert ds.confidence_scores.shape == ds.pose_tracks.shape[:-1] + # Check the dims and coords + assert all([i in ds.dims for i in ds.dim_names]) + for d, dim in enumerate(ds.dim_names[1:]): + assert ds.dims[dim] == ds.pose_tracks.shape[d + 1] + assert all([isinstance(s, str) for s in ds.coords[dim].values]) + assert all([i in ds.coords["space"] for i in ["x", "y"]]) + # Check the metadata attributes + assert ds.source_software == abbrev_expand[file_type.split("_")[0]] + assert ds.source_file == file_path.as_posix() + assert ds.fps is None + def test_load_from_dlc_file_csv_or_h5_file_returns_same( self, dlc_file_h5_single, dlc_file_csv_single ): From d210229d621eab65c824e55e2f5633f10d9e0448 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 21 Jul 2023 17:57:35 +0100 Subject: [PATCH 25/79] remove unnecessary variable assignments --- tests/test_unit/test_io.py | 6 +++--- 1 
file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index d6dd8812..97debed0 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -196,7 +196,7 @@ def test_init_scores(self, scores_array): assert np.allclose(ds.confidence_scores.data, scores_array) else: with pytest.raises(ValueError): - ds = PoseTracks(tracks, scores_array=scores_array) + PoseTracks(tracks, scores_array=scores_array) @pytest.mark.parametrize( "individual_names", @@ -223,7 +223,7 @@ def test_init_individual_names(self, individual_names): ) else: with pytest.raises(ValueError): - ds = PoseTracks(tracks, individual_names=individual_names) + PoseTracks(tracks, individual_names=individual_names) @pytest.mark.parametrize( "keypoint_names", [None, ["kp_1", "kp_2"], ["kp_1", "kp_2", "kp_3"]] @@ -244,4 +244,4 @@ def test_init_keypoint_names(self, keypoint_names): assert all([n in ds.coords["keypoints"] for n in keypoint_names]) else: with pytest.raises(ValueError): - ds = PoseTracks(tracks, keypoint_names=keypoint_names) + PoseTracks(tracks, keypoint_names=keypoint_names) From 2d4965711de28400db354ffa4ddaa0790ac1512c Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 21 Jul 2023 18:06:48 +0100 Subject: [PATCH 26/79] use typing.List in type hints to make py3.8 happy --- movement/io/pose_tracks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index 40046046..261939cb 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import ClassVar, Optional, Union +from typing import ClassVar, List, Optional, Union import h5py import numpy as np @@ -50,8 +50,8 @@ def __init__( self, tracks_array: np.ndarray, scores_array: Optional[np.ndarray] = None, - individual_names: Optional[list[str]] = None, - keypoint_names: Optional[list[str]] = None, + individual_names: Optional[List[str]] = None, + keypoint_names: Optional[List[str]] = None, fps: Optional[float] = None, ): """Create a `PoseTracks` dataset containing pose tracks and From 719742eb2aa5002b382b05a08e0e13d63f61e6fe Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 21 Jul 2023 18:28:40 +0100 Subject: [PATCH 27/79] make mypy happy --- .pre-commit-config.yaml | 1 + movement/io/pose_tracks.py | 9 +++++---- pyproject.toml | 9 +++++++++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6d994f29..a5371873 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,6 +26,7 @@ repos: - id: mypy additional_dependencies: - types-setuptools + - pandas-stubs - repo: https://github.com/mgedmin/check-manifest rev: "0.49" hooks: diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index 261939cb..704d4c2a 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -284,9 +284,7 @@ def from_dlc_file( if file_suffix == ".csv": df = cls._parse_dlc_csv_to_dataframe(file_path) else: # file can only be .h5 at this point - df = pd.read_hdf(file_path) - # above line does not necessarily return a DataFrame - df = pd.DataFrame(df) + df = pd.DataFrame(pd.read_hdf(file_path)) except (OSError, TypeError, ValueError) as e: error_msg = ( f"Could not load poses from {file_path}. 
" @@ -431,7 +429,10 @@ def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: # Import the DLC poses as a DataFrame df = pd.read_csv( - file_path, skiprows=len(header_lines), index_col=0, names=columns + file_path, + skiprows=len(header_lines), + index_col=0, + names=np.array(columns), ) df.columns.rename(level_names, inplace=True) return df diff --git a/pyproject.toml b/pyproject.toml index 4874af0a..338cc4bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dev = [ "pre-commit", "ruff", "setuptools_scm", + "pandas-stubs", ] [build-system] @@ -91,6 +92,14 @@ ignore = [ "docs/source/", ] +[[tool.mypy.overrides]] +module = [ + "pooch.*", + "h5py.*", + "sleap_io.*", +] +ignore_missing_imports = true + [tool.ruff] line-length = 79 From fa571207ac0b9e8badea6b36894a95abf623ed80 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 31 Jul 2023 09:30:24 +0100 Subject: [PATCH 28/79] added file validators using pydantic 2.0 --- movement/io/file_validators.py | 200 +++++++++++++++++++++++++++++++++ movement/io/pose_tracks.py | 173 +++++++++++++++++++--------- movement/io/validators.py | 39 ------- 3 files changed, 318 insertions(+), 94 deletions(-) create mode 100644 movement/io/file_validators.py delete mode 100644 movement/io/validators.py diff --git a/movement/io/file_validators.py b/movement/io/file_validators.py new file mode 100644 index 00000000..eb3804fc --- /dev/null +++ b/movement/io/file_validators.py @@ -0,0 +1,200 @@ +import os +from pathlib import Path +from typing import Literal, Optional + +import h5py +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, +) + + +class ValidFile(BaseModel): + """Pydantic class for validating file paths. + + It ensures that: + - the path can be converted to a pathlib.Path object. + - the path does not point to a directory. + - the file has the expected access permission(s): 'r', 'w', or 'rw'. + - the file exists `expected_permission` is 'r' or 'rw'. + - the file does not exist when `expected_permission` is 'w'. + - the file has one of the expected suffixes, if specified. + + Parameters + ---------- + path : str or pathlib.Path + Path to the file. + expected_permission : {'r', 'w', 'rw'} + Expected access permission(s) for the file. If 'r', the file is + expected to be readable. If 'w', the file is expected to be writable. + If 'rw', the file is expected to be both readable and writable. + Default: 'r'. + expected_suffix : list of str or None + Expected suffix(es) for the file. If None (default), this check is + skipped. + + """ + + path: Path + expected_permission: Literal["r", "w", "rw"] = Field(default="r") + expected_suffix: Optional[list[str]] = Field(default=None) + + @field_validator("path", mode="before") # run before instantiation + def convert_to_path(cls, value): + if not isinstance(value, Path): + try: + value = Path(value) + except TypeError as error: + raise error + return value + + @field_validator("path") + def path_exists(cls, value): + if not value.exists(): + raise FileNotFoundError(f"File not found: {value}.") + return value + + @field_validator("path") + def path_is_not_dir(cls, value): + if value.is_dir(): + raise ValueError( + f"Expected a file but got a directory: {value}. " + "Please specify a file path." 
+ ) + return value + + @model_validator(mode="after") + def file_has_expected_permission(self) -> "ValidFile": + """Ensure that the file has the expected permission.""" + is_readable = os.access(self.path, os.R_OK) + is_writeable = os.access(self.path.parent, os.W_OK) + + if self.expected_permission == "r": + if not is_readable: + raise PermissionError( + f"Unable to read file: {self.path}. " + "Make sure that you have read permissions for it." + ) + elif self.expected_permission == "w": + if not is_writeable: + raise PermissionError( + f"Unable to write to file: {self.path}. " + "Make sure that you have write permissions for it." + ) + elif self.expected_permission == "rw": + if not (is_readable and is_writeable): + raise PermissionError( + f"Unable to read and/or write to file: {self.path}. Make" + "sure that you have read and write permissions for it." + ) + return self + + @model_validator(mode="after") + def file_exists_when_expected(self) -> "ValidFile": + """Ensure that the file exists when expected (matches the expected + permission, i.e. the intended use of the file).""" + if self.expected_permission in ["r", "rw"]: + if not self.path.exists(): + raise FileNotFoundError( + f"Expected file {self.path} does not exist." + ) + else: + if self.path.exists(): + raise FileExistsError( + f"Expected file {self.path} already exists." + ) + return self + + @model_validator(mode="after") + def file_has_expected_suffix(self) -> "ValidFile": + """Ensure that the file has the expected suffix.""" + if self.expected_suffix is not None: + if self.path.suffix.lower() not in self.expected_suffix: + raise ValueError( + f"Expected file extension(s) {self.expected_suffix} " + f"but got {self.path.suffix} for file: {self.path}." + ) + return self + + +class ValidHDF5(BaseModel): + """Pydantic class for validating HDF5 files. This class ensures that the + file is a properly formatted and contains the expected datasets + (if specified). + + Parameters + ---------- + file : movement.io.validators.ValidFile + Validated path to the HDF5 file. + expected_datasets : list of str or None + List of names of the expected datasets in the HDF5 file. If None + (default), this check is skipped. + """ + + file: ValidFile + expected_datasets: Optional[list[str]] = Field(default=None) + + @field_validator("file") + def file_is_h5(cls, value): + """Ensure that the file is indeed in HDF5 format.""" + try: + with h5py.File(value.path, "r") as f: + assert isinstance( + f, h5py.File + ), f"Expected an HDF5 file but got {type(f)}: {value.path}. " + except OSError as error: + raise error + return value + + @model_validator(mode="after") + def h5_file_contains_expected_datasets(self) -> "ValidHDF5": + """Ensure that the HDF5 file contains the expected datasets.""" + if self.expected_datasets is not None: + with h5py.File(self.file.path, "r") as f: + diff = set(self.expected_datasets).difference(set(f.keys())) + print(diff) + if len(diff) > 0: + raise ValueError( + f"Could not find the expected dataset(s) {diff} " + f"in file: {self.file.path}. " + ) + return self + + +class ValidPosesCSV(BaseModel): + """Pydantic class for validating CSV files that contain pose estimation + outputs in DeepLabCut format. This class ensures that the CSV file contains + the expected index column levels among its top rows. + + Parameters + ---------- + file : movement.io.validators.ValidFile + Validated path to the CSV file. + multianimal : bool + Whether to ensure that the CSV file contains pose estimation outputs + for multiple animals. 
Default: False. + """ + + file: ValidFile + multianimal: bool = Field(default=False) + + @model_validator(mode="after") + def csv_file_contains_expected_levels(self) -> "ValidPosesCSV": + expected_levels = ["scorer", "bodyparts", "coords"] + if self.multianimal: + expected_levels.insert(1, "individuals") + + with open(self.file.path, "r") as f: + header_rows_start = [f.readline().split(",")[0] for _ in range(4)] + level_in_header_row_starts = [ + level in header_rows_start for level in expected_levels + ] + if not all(level_in_header_row_starts): + raise ValueError( + f"The header rows of the CSV file {self.file.path} do not " + "contain all expected index column levels " + f"{expected_levels}." + ) + return self diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index 704d4c2a..14c1ec10 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -6,8 +6,11 @@ import numpy as np import pandas as pd import xarray as xr +from pydantic import ValidationError from sleap_io.io.slp import read_labels +from movement.io.file_validators import ValidFile, ValidHDF5, ValidPosesCSV + # get logger logger = logging.getLogger(__name__) @@ -190,14 +193,16 @@ def from_sleap_file( file_path : pathlib Path or str Path to the file containing the SLEAP predictions, either in ".slp" or ".h5" (analysis) format. See Notes for more information. - The number of frames per second in the video. If None (default), + fps : float, optional + The number of frames per second in the video. If None (default), the `time` coordinates will be in frame numbers. + Notes ----- The SLEAP predictions are normally saved in a ".slp" file, e.g. "v1.predictions.slp". If this file contains both user-labeled and - predicted instances, only the predicted iones will be loaded. + predicted instances, only the predicted ones will be loaded. An analysis file, suffixed with ".h5" can be exported from the ".slp" file, using either the command line tool `sleap-convert` (with the @@ -219,26 +224,25 @@ def from_sleap_file( >>> poses = PoseTracks.from_sleap_file("path/to/file.slp", fps=30) """ - if not isinstance(file_path, Path): - file_path = Path(file_path) + # Validate the file path + try: + file = ValidFile( + path=file_path, + expected_permission="r", + expected_suffix=[".h5", ".slp"], + ) + except ValidationError as error: + logger.error(error) + raise error # Load data into a dictionary - if file_path.suffix == ".h5": - data_dict = cls._load_dict_from_sleap_analysis_file(file_path) - elif file_path.suffix == ".slp": - data_dict = cls._load_dict_from_sleap_labels_file(file_path) - else: - error_msg = ( - f"Expected file suffix to be '.h5' or '.slp', " - f"but got '{file_path.suffix}'. Make sure the file is " - "a SLEAP labels file with suffix '.slp' or SLEAP analysis " - "file with suffix '.h5'." - ) - # logger.error(error_msg) - raise ValueError(error_msg) + if file.path.suffix == ".h5": + data_dict = cls._load_dict_from_sleap_analysis_file(file) + else: # file.path.suffix == ".slp" + data_dict = cls._load_dict_from_sleap_labels_file(file) logger.debug( - f"Loaded pose tracks from {file_path.as_posix()} into a dict." + f"Loaded pose tracks from {file.path.as_posix()} into a dict." 
) # Initialize a PoseTracks dataset from the dictionary @@ -246,9 +250,9 @@ def from_sleap_file( # Add metadata as attrs ds.attrs["source_software"] = "SLEAP" - ds.attrs["source_file"] = file_path.as_posix() + ds.attrs["source_file"] = file.path.as_posix() - logger.info(f"Loaded pose tracks from {file_path}:") + logger.info(f"Loaded pose tracks from {file.path}:") logger.info(ds) return ds @@ -274,32 +278,32 @@ def from_dlc_file( >>> poses = PoseTracks.from_dlc_file("path/to/file.h5", fps=30) """ - # Validate the input file path - if not isinstance(file_path, Path): - file_path = Path(file_path) - file_suffix = file_path.suffix - - # Load the DLC poses into a DataFrame + # Validate the file path try: - if file_suffix == ".csv": - df = cls._parse_dlc_csv_to_dataframe(file_path) - else: # file can only be .h5 at this point - df = pd.DataFrame(pd.read_hdf(file_path)) - except (OSError, TypeError, ValueError) as e: - error_msg = ( - f"Could not load poses from {file_path}. " - "Please check that the file is valid and readable." + file = ValidFile( + path=file_path, + expected_permission="r", + expected_suffix=[".csv", ".h5"], ) - logger.error(error_msg) - raise OSError from e - logger.debug(f"Loaded poses from {file_path} into a DataFrame.") + except ValidationError as error: + logger.error(error) + raise error + # Load the DLC poses into a DataFrame + if file.path.suffix == ".csv": + df = cls._parse_dlc_csv_to_df(file) + else: # file.path.suffix == ".h5" + df = cls._load_df_from_dlc_h5(file) + + logger.debug( + f"Loaded poses from {file.path.as_posix()} into a DataFrame." + ) # Convert the DataFrame to a PoseTracks dataset ds = cls.from_dlc_df(df=df, fps=fps) # Add metadata as attrs ds.attrs["source_software"] = "DeepLabCut" - ds.attrs["source_file"] = file_path.as_posix() + ds.attrs["source_file"] = file.path.as_posix() logger.info(f"Loaded pose tracks from {file_path}:") logger.info(ds) @@ -353,12 +357,48 @@ def to_dlc_df(self) -> pd.DataFrame: logger.info("Converted PoseTracks dataset to DLC-style DataFrame.") return df + def to_dlc_file(self, file_path: Union[str, Path]): + """Save the dataset to a DeepLabCut-style .h5 or .csv file + + Parameters + ---------- + file_path : pathlib Path or str + Path to the file to save the DLC poses to. The file extension + must be either ".h5" (recommended) or ".csv". 
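+
+        Examples
+        --------
+        A minimal usage sketch (the file paths are illustrative):
+
+        >>> from movement.io import PoseTracks
+        >>> poses = PoseTracks.from_dlc_file("path/to/file.h5", fps=30)
+        >>> poses.to_dlc_file("path/to/file_dlc.csv")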
+ """ + + # Validate the file path + try: + file = ValidFile( + path=file_path, + expected_permission="w", + expected_suffix=[".csv", ".h5"], + ) + except ValidationError as error: + logger.error(error) + raise error + + # Convert the PoseTracks dataset to a DataFrame + df = self.to_dlc_df() + if file.path.suffix == ".csv": + df.to_csv(file.path, sep=",") + else: # file.path.suffix == ".h5" + df.to_hdf(file.path, key="df_with_missing") + logger.info(f"Saved PoseTracks dataset to {file.path.as_posix()}.") + @staticmethod - def _load_dict_from_sleap_analysis_file(file_path: Path): + def _load_dict_from_sleap_analysis_file(file: ValidFile): """Load pose tracks and confidence scores from a SLEAP analysis file into a dictionary.""" - with h5py.File(file_path, "r") as f: + # Validate the hdf5 file + try: + ValidHDF5(file=file, expected_datasets=["tracks"]) + except ValidationError as error: + logger.error(error) + raise error + + with h5py.File(file.path, "r") as f: tracks = f["tracks"][:].T n_frames, n_keypoints, n_space, n_tracks = tracks.shape tracks = tracks.reshape((n_frames, n_tracks, n_keypoints, n_space)) @@ -380,11 +420,18 @@ def _load_dict_from_sleap_analysis_file(file_path: Path): } @staticmethod - def _load_dict_from_sleap_labels_file(file_path: Path): + def _load_dict_from_sleap_labels_file(file: ValidFile): """Load pose tracks and confidence scores from a SLEAP labels file into a dictionary.""" - labels = read_labels(file_path.as_posix()) + # Validate the .slp file as an HDF5 file + try: + ValidHDF5(file=file, expected_datasets=["pred_points", "metadata"]) + except ValidationError as error: + logger.error(error) + raise error + + labels = read_labels(file.path.as_posix()) tracks_with_scores = labels.numpy(return_confidence=True) return { @@ -395,25 +442,21 @@ def _load_dict_from_sleap_labels_file(file_path: Path): } @staticmethod - def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: + def _parse_dlc_csv_to_df(file: ValidFile) -> pd.DataFrame: """If poses are loaded from a DeepLabCut.csv file, the DataFrame lacks the multi-index columns that are present in the .h5 file. This function parses the csv file to a pandas DataFrame with multi-index columns, i.e. the same format as in the .h5 file. - - Parameters - ---------- - file_path : pathlib Path - Path to the file containing the DLC poses, in .csv format. - - Returns - ------- - pandas DataFrame - DataFrame containing the DLC poses, with multi-index columns. 
""" + try: + ValidPosesCSV(file=file, multianimal=False) + except ValidationError as error: + logger.error(error) + raise error + possible_level_names = ["scorer", "individuals", "bodyparts", "coords"] - with open(file_path, "r") as f: + with open(file.path, "r") as f: # if line starts with a possible level name, split it into a list # of strings, and add it to the list of header lines header_lines = [ @@ -429,10 +472,30 @@ def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: # Import the DLC poses as a DataFrame df = pd.read_csv( - file_path, + file.path, skiprows=len(header_lines), index_col=0, names=np.array(columns), ) df.columns.rename(level_names, inplace=True) return df + + @staticmethod + def _load_df_from_dlc_h5(file: ValidFile) -> pd.DataFrame: + """Load pose tracks and likelihood scores from a DeepLabCut .h5 file + into a pandas DataFrame.""" + + try: + ValidHDF5(file=file, expected_datasets=["df_with_missing"]) + except ValidationError as error: + logger.error(error) + raise error + + try: + # pd.read_hdf does not always return a DataFrame + df = pd.DataFrame(pd.read_hdf(file.path, key="df_with_missing")) + except Exception as error: + logger.error(error) + raise error + + return df diff --git a/movement/io/validators.py b/movement/io/validators.py deleted file mode 100644 index 8984b0aa..00000000 --- a/movement/io/validators.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging -from pathlib import Path - -from pydantic import BaseModel, field_validator - -# initialize logger -logger = logging.getLogger(__name__) - - -class DeepLabCutPosesFile(BaseModel): - """Pydantic class for validating files containing - pose estimation results from DeepLabCut (DLC). - - Pydantic will enforce the input data type. - This class additionally checks that the file exists - and has a valid suffix. - """ - - file_path: Path - - @field_validator("file_path") - def file_must_exist(cls, value): - if not value.is_file(): - error_msg = f"File not found: {value}" - logger.error(error_msg) - raise FileNotFoundError(error_msg) - return value - - @field_validator("file_path") - def file_must_have_valid_suffix(cls, value): - if value.suffix not in (".h5", ".csv"): - error_msg = ( - "Expected a file with pose estimation results from " - "DeepLabCut, in one of '.h5' or '.csv' formats. " - f"Received a file with suffix '{value.suffix}' instead." 
- ) - logger.error(error_msg) - raise ValueError(error_msg) - return value From 92123c716f3f27f86a0d6d0a0a7ac3f20cd79c2b Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 7 Aug 2023 18:18:31 +0100 Subject: [PATCH 29/79] Replaced pydantic with attrs for validation --- .pre-commit-config.yaml | 2 + movement/io/file_validators.py | 240 ++++++++++++++---------------- movement/io/pose_tracks.py | 246 +++++++++++++------------------ movement/io/tracks_validators.py | 160 ++++++++++++++++++++ pyproject.toml | 4 +- tests/test_unit/test_io.py | 151 +++++++++++++------ 6 files changed, 484 insertions(+), 319 deletions(-) create mode 100644 movement/io/tracks_validators.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a5371873..f5fe6c90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,8 +25,10 @@ repos: hooks: - id: mypy additional_dependencies: + - attrs - types-setuptools - pandas-stubs + - types-attrs - repo: https://github.com/mgedmin/check-manifest rev: "0.49" hooks: diff --git a/movement/io/file_validators.py b/movement/io/file_validators.py index eb3804fc..11a53032 100644 --- a/movement/io/file_validators.py +++ b/movement/io/file_validators.py @@ -1,26 +1,18 @@ +import logging import os from pathlib import Path -from typing import Literal, Optional +from typing import List, Literal import h5py -from pydantic import ( - BaseModel, - Field, - field_validator, - model_validator, -) +from attrs import define, field, validators +# get logger +logger = logging.getLogger(__name__) -class ValidFile(BaseModel): - """Pydantic class for validating file paths. - It ensures that: - - the path can be converted to a pathlib.Path object. - - the path does not point to a directory. - - the file has the expected access permission(s): 'r', 'w', or 'rw'. - - the file exists `expected_permission` is 'r' or 'rw'. - - the file does not exist when `expected_permission` is 'w'. - - the file has one of the expected suffixes, if specified. +@define +class ValidFile: + """Class for validating file paths. Parameters ---------- @@ -31,170 +23,160 @@ class ValidFile(BaseModel): expected to be readable. If 'w', the file is expected to be writable. If 'rw', the file is expected to be both readable and writable. Default: 'r'. - expected_suffix : list of str or None - Expected suffix(es) for the file. If None (default), this check is - skipped. - + expected_suffix : list of str + Expected suffix(es) for the file. If an empty list (default), this + check is skipped. + + Raises + ------ + IsADirectoryError + If the path points to a directory. + PermissionError + If the file does not have the expected access permission(s). + FileNotFoundError + If the file does not exist when `expected_permission` is 'r' or 'rw'. + FileExistsError + If the file exists when `expected_permission` is 'w'. + ValueError + If the file does not have one of the expected suffix(es). 
""" - path: Path - expected_permission: Literal["r", "w", "rw"] = Field(default="r") - expected_suffix: Optional[list[str]] = Field(default=None) - - @field_validator("path", mode="before") # run before instantiation - def convert_to_path(cls, value): - if not isinstance(value, Path): - try: - value = Path(value) - except TypeError as error: - raise error - return value - - @field_validator("path") - def path_exists(cls, value): - if not value.exists(): - raise FileNotFoundError(f"File not found: {value}.") - return value - - @field_validator("path") - def path_is_not_dir(cls, value): + path: Path = field(converter=Path, validator=validators.instance_of(Path)) + expected_permission: Literal["r", "w", "rw"] = field( + default="r", validator=validators.in_(["r", "w", "rw"]), kw_only=True + ) + expected_suffix: List[str] = field(factory=list, kw_only=True) + + @path.validator + def path_is_not_dir(self, attribute, value): + """Ensures that the path does not point to a directory.""" if value.is_dir(): - raise ValueError( - f"Expected a file but got a directory: {value}. " - "Please specify a file path." + raise IsADirectoryError( + f"Expected a file path but got a directory: {value}." ) - return value - - @model_validator(mode="after") - def file_has_expected_permission(self) -> "ValidFile": - """Ensure that the file has the expected permission.""" - is_readable = os.access(self.path, os.R_OK) - is_writeable = os.access(self.path.parent, os.W_OK) - if self.expected_permission == "r": - if not is_readable: + @path.validator + def file_has_access_permissions(self, attribute, value): + """Ensures that the file has the expected access permission(s). + Raises a PermissionError if not.""" + if "r" in self.expected_permission: + if not os.access(value, os.R_OK): raise PermissionError( - f"Unable to read file: {self.path}. " + f"Unable to read file: {value}. " "Make sure that you have read permissions for it." ) - elif self.expected_permission == "w": - if not is_writeable: + if "w" in self.expected_permission: + if not os.access(value, os.W_OK): raise PermissionError( - f"Unable to write to file: {self.path}. " + f"Unable to write to file: {value}. " "Make sure that you have write permissions for it." ) - elif self.expected_permission == "rw": - if not (is_readable and is_writeable): - raise PermissionError( - f"Unable to read and/or write to file: {self.path}. Make" - "sure that you have read and write permissions for it." - ) - return self - - @model_validator(mode="after") - def file_exists_when_expected(self) -> "ValidFile": - """Ensure that the file exists when expected (matches the expected - permission, i.e. the intended use of the file).""" - if self.expected_permission in ["r", "rw"]: - if not self.path.exists(): - raise FileNotFoundError( - f"Expected file {self.path} does not exist." - ) - else: - if self.path.exists(): - raise FileExistsError( - f"Expected file {self.path} already exists." 
- ) - return self - @model_validator(mode="after") - def file_has_expected_suffix(self) -> "ValidFile": - """Ensure that the file has the expected suffix.""" - if self.expected_suffix is not None: - if self.path.suffix.lower() not in self.expected_suffix: + @path.validator + def file_exists_when_expected(self, attribute, value): + """Ensures that the file exists (or not) depending on the expected + usage (read and/or write).""" + if "r" in self.expected_permission: + if not value.exists(): + raise FileNotFoundError(f"File {value} does not exist.") + else: # expected_permission is 'w' + if value.exists(): + raise FileExistsError(f"File {value} already exists.") + + @path.validator + def file_has_expected_suffix(self, attribute, value): + """Ensures that the file has one of the expected suffix(es).""" + if self.expected_suffix: # list is not empty + if value.suffix not in self.expected_suffix: raise ValueError( - f"Expected file extension(s) {self.expected_suffix} " - f"but got {self.path.suffix} for file: {self.path}." + f"Expected file with suffix(es) {self.expected_suffix} " + f"but got suffix {value.suffix} instead." ) - return self -class ValidHDF5(BaseModel): - """Pydantic class for validating HDF5 files. This class ensures that the - file is a properly formatted and contains the expected datasets - (if specified). +@define +class ValidHDF5: + """Class for validating HDF5 files. Parameters ---------- - file : movement.io.validators.ValidFile - Validated path to the HDF5 file. + path : pathlib.Path + Path to the HDF5 file. expected_datasets : list of str or None - List of names of the expected datasets in the HDF5 file. If None - (default), this check is skipped. + List of names of the expected datasets in the HDF5 file. If an empty + list (default), this check is skipped. + + Raises + ------ + ValueError + If the file is not in HDF5 format or if it does not contain the + expected datasets. """ - file: ValidFile - expected_datasets: Optional[list[str]] = Field(default=None) + path: Path = field(validator=validators.instance_of(Path)) + expected_datasets: List[str] = field(factory=list, kw_only=True) - @field_validator("file") - def file_is_h5(cls, value): + @path.validator + def file_is_h5(self, attribute, value): """Ensure that the file is indeed in HDF5 format.""" - try: - with h5py.File(value.path, "r") as f: - assert isinstance( - f, h5py.File - ), f"Expected an HDF5 file but got {type(f)}: {value.path}. " - except OSError as error: - raise error - return value - - @model_validator(mode="after") - def h5_file_contains_expected_datasets(self) -> "ValidHDF5": + with h5py.File(value, "r") as f: + if not isinstance(f, h5py.File): + raise ValueError( + f"Expected an HDF5 file but got {type(f)}: {value}." + ) + + @path.validator + def file_contains_expected_datasets(self, attribute, value): """Ensure that the HDF5 file contains the expected datasets.""" - if self.expected_datasets is not None: - with h5py.File(self.file.path, "r") as f: + if self.expected_datasets: + with h5py.File(value, "r") as f: diff = set(self.expected_datasets).difference(set(f.keys())) - print(diff) if len(diff) > 0: raise ValueError( f"Could not find the expected dataset(s) {diff} " - f"in file: {self.file.path}. " + f"in file: {value}. " ) - return self -class ValidPosesCSV(BaseModel): - """Pydantic class for validating CSV files that contain pose estimation - outputs in DeepLabCut format. This class ensures that the CSV file contains - the expected index column levels among its top rows. 
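
Since IsADirectoryError, PermissionError, FileNotFoundError and FileExistsError are all subclasses of OSError, callers can catch every failure mode of these validators with a single except clause. A sketch of the intended usage, assuming the module path as of this patch (it is renamed later in the series) and a placeholder file path:

    from movement.io.file_validators import ValidFile, ValidHDF5

    try:
        file = ValidFile(
            "path/to/poses.h5",
            expected_permission="r",
            expected_suffix=[".h5"],
        )
        ValidHDF5(file.path, expected_datasets=["df_with_missing"])
    except (OSError, ValueError) as error:
        raise error
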
+@define +class ValidPosesCSV: + """Class for validating CSV files that contain pose estimation outputs. + in DeepLabCut format. Parameters ---------- - file : movement.io.validators.ValidFile - Validated path to the CSV file. + path : pathlib.Path + Path to the CSV file. multianimal : bool Whether to ensure that the CSV file contains pose estimation outputs for multiple animals. Default: False. + + Raises + ------ + ValueError + If the CSV file does not contain the expected DeepLabCut index column + levels among its top rows. """ - file: ValidFile - multianimal: bool = Field(default=False) + path: Path = field(validator=validators.instance_of(Path)) + multianimal: bool = field(default=False, kw_only=True) - @model_validator(mode="after") - def csv_file_contains_expected_levels(self) -> "ValidPosesCSV": + @path.validator + def csv_file_contains_expected_levels(self, attribute, value): + """Ensure that the CSV file contains the expected index column levels + among its top rows.""" expected_levels = ["scorer", "bodyparts", "coords"] if self.multianimal: expected_levels.insert(1, "individuals") - with open(self.file.path, "r") as f: + with open(value, "r") as f: header_rows_start = [f.readline().split(",")[0] for _ in range(4)] level_in_header_row_starts = [ level in header_rows_start for level in expected_levels ] if not all(level_in_header_row_starts): raise ValueError( - f"The header rows of the CSV file {self.file.path} do not " + f"The header rows of the CSV file {value.path} do not " "contain all expected index column levels " f"{expected_levels}." ) - return self diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py index 14c1ec10..49de5fa6 100644 --- a/movement/io/pose_tracks.py +++ b/movement/io/pose_tracks.py @@ -1,15 +1,15 @@ import logging from pathlib import Path -from typing import ClassVar, List, Optional, Union +from typing import ClassVar, Optional, Union import h5py import numpy as np import pandas as pd import xarray as xr -from pydantic import ValidationError from sleap_io.io.slp import read_labels from movement.io.file_validators import ValidFile, ValidHDF5, ValidPosesCSV +from movement.io.tracks_validators import ValidPoseTracks # get logger logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class PoseTracks(xr.Dataset): The dataset contains two data variables (`xarray.DataArray` objects): - `pose_tracks`: with shape (`time`, `individuals`, `keypoints`, `space`) - - `confidence_scores`: with shape (`time`, `individuals`, `keypoints`) + - `confidence`: with shape (`time`, `individuals`, `keypoints`) The dataset may also contain following attributes as metadata: - `fps`: the number of frames per second in the video @@ -49,84 +49,45 @@ class PoseTracks(xr.Dataset): __slots__ = ("fps", "time_unit", "source_software", "source_file") - def __init__( - self, - tracks_array: np.ndarray, - scores_array: Optional[np.ndarray] = None, - individual_names: Optional[List[str]] = None, - keypoint_names: Optional[List[str]] = None, - fps: Optional[float] = None, - ): - """Create a `PoseTracks` dataset containing pose tracks and - point-wise confidence scores. + @classmethod + def _from_valid_data(cls, data: ValidPoseTracks): + """Initialize a `PoseTracks` xarray.Dataset from already validated + data - i.e. a `ValidPoseTracks` object. Parameters ---------- - tracks_array : np.ndarray - Array of shape (n_frames, n_individuals, n_keypoints, n_space) - containing the pose tracks. It will be converted to a - `xarray.DataArray` object named "pose_tracks". 
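
For orientation, this is roughly what the resulting dataset looks like from the caller's side; a hedged sketch with a placeholder path:

    ds = PoseTracks.from_sleap_file("path/to/file.h5", fps=30)
    print(ds.pose_tracks.shape)   # (n_frames, n_individuals, n_keypoints, 2 or 3)
    print(ds.confidence.shape)    # (n_frames, n_individuals, n_keypoints)
    print(ds.attrs["time_unit"])  # "seconds", because fps was given
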
- scores_array : np.ndarray, optional - Array of shape (n_frames, n_individuals, n_keypoints) containing - the point-wise confidence scores. It will be converted to a - `xarray.DataArray` object named "confidence_scores". - If None (default), the scores will be set to an array of NaNs. - individual_names : list of str, optional - List of unique names for the individuals in the video. If None - (default), the individuals will be named "individual_0", - "individual_1", etc. - keypoint_names : list of str, optional - List of unique names for the keypoints in the skeleton. If None - (default), the keypoints will be named "keypoint_0", "keypoint_1", - etc. - fps : float, optional - The number of frames per second in the video. If None (default), - the `time` coordinates will be in frame numbers. + data : movement.io.tracks_validators.ValidPoseTracks + The validated data object. """ - n_frames, n_individuals, n_keypoints, n_space = tracks_array.shape - if scores_array is None: - scores_array = np.full( - (n_frames, n_individuals, n_keypoints), np.nan, dtype="float32" - ) - if individual_names is None: - individual_names = [ - f"individual_{i}" for i in range(n_individuals) - ] - if keypoint_names is None: - keypoint_names = [f"keypoint_{i}" for i in range(n_keypoints)] - if (fps is not None) and (fps <= 0): - logger.warning( - f"Expected fps to be a positive number, but got {fps}. " - "Setting fps to None." - ) - fps = None - - # Convert the pose tracks and confidence scores to xarray.DataArray - tracks_da = xr.DataArray(tracks_array, dims=self.dim_names) - scores_da = xr.DataArray(scores_array, dims=self.dim_names[:-1]) + n_frames = data.tracks_array.shape[0] + n_space = data.tracks_array.shape[-1] # Create the time coordinate, depending on the value of fps time_coords = np.arange(n_frames, dtype=int) time_unit = "frames" - if fps is not None: - time_coords = time_coords / fps + if data.fps is not None: + time_coords = time_coords / data.fps time_unit = "seconds" - # Combine the DataArrays into a Dataset, with common coordinates - super().__init__( + # Convert data to an xarray.Dataset + return cls( data_vars={ - "pose_tracks": tracks_da, - "confidence_scores": scores_da, + "pose_tracks": xr.DataArray( + data.tracks_array, dims=cls.dim_names + ), + "confidence": xr.DataArray( + data.scores_array, dims=cls.dim_names[:-1] + ), }, coords={ - self.dim_names[0]: time_coords, - self.dim_names[1]: individual_names, - self.dim_names[2]: keypoint_names, - self.dim_names[3]: ["x", "y", "z"][:n_space], + cls.dim_names[0]: time_coords, + cls.dim_names[1]: data.individual_names, + cls.dim_names[2]: data.keypoint_names, + cls.dim_names[3]: ["x", "y", "z"][:n_space], }, attrs={ - "fps": fps, + "fps": data.fps, "time_unit": time_unit, "source_software": None, "source_file": None, @@ -174,13 +135,19 @@ def from_dlc_df(cls, df: pd.DataFrame, fps: Optional[float] = None): (-1, len(individual_names), len(keypoint_names), 3) ) - return cls( - tracks_array=tracks_with_scores[:, :, :, :-1], - scores_array=tracks_with_scores[:, :, :, -1], - individual_names=individual_names, - keypoint_names=keypoint_names, - fps=fps, - ) + try: + valid_data = ValidPoseTracks( + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], + individual_names=individual_names, + keypoint_names=keypoint_names, + fps=fps, + ) + except ValueError as error: + logger.error(error) + raise error + else: + return cls._from_valid_data(valid_data) @classmethod def from_sleap_file( @@ -224,29 +191,25 @@ def 
from_sleap_file( >>> poses = PoseTracks.from_sleap_file("path/to/file.slp", fps=30) """ - # Validate the file path try: file = ValidFile( - path=file_path, + file_path, expected_permission="r", expected_suffix=[".h5", ".slp"], ) - except ValidationError as error: + except (OSError, ValueError) as error: logger.error(error) raise error - # Load data into a dictionary + # Load and validate data if file.path.suffix == ".h5": - data_dict = cls._load_dict_from_sleap_analysis_file(file) + valid_data = cls._load_from_sleap_analysis_file(file.path, fps=fps) else: # file.path.suffix == ".slp" - data_dict = cls._load_dict_from_sleap_labels_file(file) - - logger.debug( - f"Loaded pose tracks from {file.path.as_posix()} into a dict." - ) + valid_data = cls._load_from_sleap_labels_file(file.path, fps=fps) + logger.debug(f"Validated pose tracks from {file.path}.") # Initialize a PoseTracks dataset from the dictionary - ds = cls(**data_dict, fps=fps) + ds = cls._from_valid_data(valid_data) # Add metadata as attrs ds.attrs["source_software"] = "SLEAP" @@ -271,33 +234,29 @@ def from_dlc_file( The number of frames per second in the video. If None (default), the `time` coordinates will be in frame numbers. - Examples -------- >>> from movement.io import PoseTracks >>> poses = PoseTracks.from_dlc_file("path/to/file.h5", fps=30) """ - # Validate the file path try: file = ValidFile( - path=file_path, + file_path, expected_permission="r", expected_suffix=[".csv", ".h5"], ) - except ValidationError as error: + except (OSError, ValueError) as error: logger.error(error) raise error # Load the DLC poses into a DataFrame if file.path.suffix == ".csv": - df = cls._parse_dlc_csv_to_df(file) + df = cls._parse_dlc_csv_to_df(file.path) else: # file.path.suffix == ".h5" - df = cls._load_df_from_dlc_h5(file) + df = cls._load_df_from_dlc_h5(file.path) - logger.debug( - f"Loaded poses from {file.path.as_posix()} into a DataFrame." - ) + logger.debug(f"Loaded poses from {file.path} into a DataFrame.") # Convert the DataFrame to a PoseTracks dataset ds = cls.from_dlc_df(df=df, fps=fps) @@ -305,7 +264,7 @@ def from_dlc_file( ds.attrs["source_software"] = "DeepLabCut" ds.attrs["source_file"] = file.path.as_posix() - logger.info(f"Loaded pose tracks from {file_path}:") + logger.info(f"Loaded pose tracks from {file.path}:") logger.info(ds) return ds @@ -331,7 +290,7 @@ def to_dlc_df(self) -> pd.DataFrame: tracks_with_scores = np.concatenate( ( self.pose_tracks.data, - self.confidence_scores.data[..., np.newaxis], + self.confidence.data[..., np.newaxis], ), axis=-1, ) @@ -367,14 +326,13 @@ def to_dlc_file(self, file_path: Union[str, Path]): must be either ".h5" (recommended) or ".csv". 
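
Taken together with the loaders above, this enables a SLEAP-to-DLC round trip; a sketch with placeholder paths:

    poses = PoseTracks.from_sleap_file("path/to/predictions.slp", fps=30)
    poses.to_dlc_file("exported_poses.h5")   # DLC-style HDF5 (recommended)
    poses.to_dlc_file("exported_poses.csv")  # DLC-style CSV
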
""" - # Validate the file path try: file = ValidFile( - path=file_path, + file_path, expected_permission="w", expected_suffix=[".csv", ".h5"], ) - except ValidationError as error: + except (OSError, ValueError) as error: logger.error(error) raise error @@ -384,19 +342,16 @@ def to_dlc_file(self, file_path: Union[str, Path]): df.to_csv(file.path, sep=",") else: # file.path.suffix == ".h5" df.to_hdf(file.path, key="df_with_missing") - logger.info(f"Saved PoseTracks dataset to {file.path.as_posix()}.") + logger.info(f"Saved PoseTracks dataset to {file.path}.") @staticmethod - def _load_dict_from_sleap_analysis_file(file: ValidFile): - """Load pose tracks and confidence scores from a SLEAP analysis - file into a dictionary.""" + def _load_from_sleap_analysis_file( + file_path: Path, fps: Optional[float] + ) -> ValidPoseTracks: + """Load and validate pose tracks and confidence scores from a SLEAP + analysis file""" - # Validate the hdf5 file - try: - ValidHDF5(file=file, expected_datasets=["tracks"]) - except ValidationError as error: - logger.error(error) - raise error + file = ValidHDF5(file_path, expected_datasets=["tracks"]) with h5py.File(file.path, "r") as f: tracks = f["tracks"][:].T @@ -412,48 +367,57 @@ def _load_dict_from_sleap_analysis_file(file: ValidFile): (n_frames, n_tracks, n_keypoints) ) - return { - "tracks_array": tracks, - "scores_array": scores, - "individual_names": [n.decode() for n in f["track_names"][:]], - "keypoint_names": [n.decode() for n in f["node_names"][:]], - } + try: + valid_data = ValidPoseTracks( + tracks_array=tracks, + scores_array=scores, + individual_names=[n.decode() for n in f["track_names"][:]], + keypoint_names=[n.decode() for n in f["node_names"][:]], + fps=fps, + ) + except ValueError as error: + logger.error(error) + raise error + else: + return valid_data @staticmethod - def _load_dict_from_sleap_labels_file(file: ValidFile): - """Load pose tracks and confidence scores from a SLEAP labels file - into a dictionary.""" - - # Validate the .slp file as an HDF5 file - try: - ValidHDF5(file=file, expected_datasets=["pred_points", "metadata"]) - except ValidationError as error: - logger.error(error) - raise error + def _load_from_sleap_labels_file( + file_path: Path, fps: Optional[float] + ) -> ValidPoseTracks: + """Load and validate pose tracks and confidence scores from a SLEAP + labels file.""" + + file = ValidHDF5( + file_path, expected_datasets=["pred_points", "metadata"] + ) labels = read_labels(file.path.as_posix()) tracks_with_scores = labels.numpy(return_confidence=True) - return { - "tracks_array": tracks_with_scores[:, :, :, :-1], - "scores_array": tracks_with_scores[:, :, :, -1], - "individual_names": [track.name for track in labels.tracks], - "keypoint_names": [kp.name for kp in labels.skeletons[0].nodes], - } + try: + valid_data = ValidPoseTracks( + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], + individual_names=[track.name for track in labels.tracks], + keypoint_names=[kp.name for kp in labels.skeletons[0].nodes], + fps=fps, + ) + except ValueError as error: + logger.error(error) + raise error + else: + return valid_data @staticmethod - def _parse_dlc_csv_to_df(file: ValidFile) -> pd.DataFrame: + def _parse_dlc_csv_to_df(file_path: Path) -> pd.DataFrame: """If poses are loaded from a DeepLabCut.csv file, the DataFrame lacks the multi-index columns that are present in the .h5 file. This function parses the csv file to a pandas DataFrame with multi-index columns, i.e. 
the same format as in the .h5 file. """ - try: - ValidPosesCSV(file=file, multianimal=False) - except ValidationError as error: - logger.error(error) - raise error + file = ValidPosesCSV(file_path, multianimal=False) possible_level_names = ["scorer", "individuals", "bodyparts", "coords"] with open(file.path, "r") as f: @@ -481,15 +445,11 @@ def _parse_dlc_csv_to_df(file: ValidFile) -> pd.DataFrame: return df @staticmethod - def _load_df_from_dlc_h5(file: ValidFile) -> pd.DataFrame: + def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame: """Load pose tracks and likelihood scores from a DeepLabCut .h5 file into a pandas DataFrame.""" - try: - ValidHDF5(file=file, expected_datasets=["df_with_missing"]) - except ValidationError as error: - logger.error(error) - raise error + file = ValidHDF5(file_path, expected_datasets=["df_with_missing"]) try: # pd.read_hdf does not always return a DataFrame @@ -497,5 +457,5 @@ def _load_df_from_dlc_h5(file: ValidFile) -> pd.DataFrame: except Exception as error: logger.error(error) raise error - - return df + else: + return df diff --git a/movement/io/tracks_validators.py b/movement/io/tracks_validators.py new file mode 100644 index 00000000..4b578ad6 --- /dev/null +++ b/movement/io/tracks_validators.py @@ -0,0 +1,160 @@ +import logging +from collections.abc import Iterable +from typing import Any, List, Optional, Union + +import numpy as np +from attrs import converters, define, field + +# get logger +logger = logging.getLogger(__name__) + + +def _list_of_str(value: Union[str, Iterable[Any]]) -> list[str]: + """Try to coerce the value into a list of strings. + Otherwise, raise a ValueError.""" + if type(value) is str: + warning_msg = ( + f"Invalid value ({value}). Expected a list of strings. " + "Converting to a list of length 1." + ) + logger.warning(warning_msg) + return [value] + elif isinstance(value, Iterable): + return [str(item) for item in value] + else: + error_msg = f"Invalid value ({value}). Expected a list of strings." + logger.error(error_msg) + raise ValueError(error_msg) + + +def _ensure_type_ndarray(value: Any): + """Raise ValueError the value is a not numpy array.""" + if type(value) is not np.ndarray: + raise ValueError(f"Expected a numpy array, but got {type(value)}.") + + +def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]: + """Set fps to None if a non-positive float is passed.""" + if fps is not None and fps <= 0: + logger.warning( + f"Invalid fps value ({fps}). Expected a positive number. " + "Setting fps to None." + ) + return None + return fps + + +@define(kw_only=True) +class ValidPoseTracks: + """Class for validating pose tracking data imported from files, before + they are converted to a `PoseTracks` object. + + Attributes + ---------- + tracks_array : np.ndarray + Array of shape (n_frames, n_individuals, n_keypoints, n_space) + containing the pose tracks. It will be converted to a + `xarray.DataArray` object named "pose_tracks". + scores_array : np.ndarray, optional + Array of shape (n_frames, n_individuals, n_keypoints) containing + the point-wise confidence scores. It will be converted to a + `xarray.DataArray` object named "confidence". + If None (default), the scores will be set to an array of NaNs. + individual_names : list of str, optional + List of unique names for the individuals in the video. If None + (default), the individuals will be named "individual_0", + "individual_1", etc. + keypoint_names : list of str, optional + List of unique names for the keypoints in the skeleton. 
If None + (default), the keypoints will be named "keypoint_0", "keypoint_1", + etc. + fps : float, optional + Frames per second of the video. Defaults to None. + """ + + # Define class attributes + tracks_array: np.ndarray = field() + scores_array: Optional[np.ndarray] = field(default=None) + individual_names: Optional[List[str]] = field( + default=None, + converter=converters.optional(_list_of_str), + ) + keypoint_names: Optional[List[str]] = field( + default=None, + converter=converters.optional(_list_of_str), + ) + fps: Optional[float] = field( + default=None, + converter=converters.pipe( # type: ignore + converters.optional(float), _set_fps_to_none_if_invalid + ), + ) + + # Add validators + @tracks_array.validator + def _validate_tracks_array(self, attribute, value): + _ensure_type_ndarray(value) + if value.ndim != 4: + raise ValueError( + f"Expected `{attribute}` to have 4 dimensions, " + f"but got {value.ndim}." + ) + if value.shape[-1] not in [2, 3]: + raise ValueError( + f"Expected `{attribute}` to have 2 or 3 spatial dimensions, " + f"but got {value.shape[-1]}." + ) + + @scores_array.validator + def _validate_scores_array(self, attribute, value): + if value is not None: + _ensure_type_ndarray(value) + if value.shape != self.tracks_array.shape[:-1]: + raise ValueError( + f"Expected `{attribute}` to have shape " + f"{self.tracks_array.shape[:-1]}, but got {value.shape}." + ) + + @individual_names.validator + def _validate_individual_names(self, attribute, value): + if value is not None: + if len(value) != self.tracks_array.shape[1]: + raise ValueError( + f"Expected {self.tracks_array.shape[1]} `{attribute}`, " + f"but got {len(value)}." + ) + + @keypoint_names.validator + def _validate_keypoint_names(self, attribute, value): + if value is not None: + if len(value) != self.tracks_array.shape[2]: + raise ValueError( + f"Expected {self.tracks_array.shape[2]} `{attribute}`, " + f"but got {len(value)}." + ) + + def __attrs_post_init__(self): + """Assign default values to optional attributes (if None)""" + if self.scores_array is None: + self.scores_array = np.full( + (self.tracks_array.shape[:-1]), np.nan, dtype="float32" + ) + logger.warning( + "Scores array was not provided. Setting to an array of NaNs." + ) + if self.individual_names is None: + self.individual_names = [ + f"individual_{i}" for i in range(self.tracks_array.shape[1]) + ] + logger.warning( + "Individual names were not provided. " + f"Setting to {self.individual_names}." + ) + if self.keypoint_names is None: + self.keypoint_names = [ + f"keypoint_{i}" for i in range(self.tracks_array.shape[2]) + ] + logger.warning( + "Keypoint names were not provided. " + f"Setting to {self.keypoint_names}." 
+ ) diff --git a/pyproject.toml b/pyproject.toml index 338cc4bd..dfdb2e82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "numpy", "pandas", "h5py", - "pydantic>=2.0", + "attrs", "pooch", "tqdm", "sleap-io", @@ -53,6 +53,7 @@ dev = [ "ruff", "setuptools_scm", "pandas-stubs", + "types-attrs", ] [build-system] @@ -100,7 +101,6 @@ module = [ ] ignore_missing_imports = true - [tool.ruff] line-length = 79 exclude = ["__init__.py","build",".eggs"] diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index 97debed0..ab089b20 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -9,11 +9,17 @@ from movement.datasets import fetch_pose_data_path from movement.io import PoseTracks +from movement.io.tracks_validators import ValidPoseTracks class TestPoseTracksIO: """Test the IO functionalities of the PoseTracks class.""" + @pytest.fixture + def valid_tracks_array(self): + """Return a valid tracks array.""" + return np.zeros((10, 2, 2, 2)) + @pytest.fixture def dlc_file_h5_single(self): """Return the path to a valid DLC h5 file containing pose data @@ -126,11 +132,11 @@ def test_load_from_valid_files(self, valid_files): assert isinstance(ds, PoseTracks) # Expected variables are present and of right shape/type - for var in ["pose_tracks", "confidence_scores"]: + for var in ["pose_tracks", "confidence"]: assert var in ds.data_vars assert isinstance(ds[var], DataArray) assert ds.pose_tracks.ndim == 4 - assert ds.confidence_scores.shape == ds.pose_tracks.shape[:-1] + assert ds.confidence.shape == ds.pose_tracks.shape[:-1] # Check the dims and coords assert all([i in ds.dims for i in ds.dim_names]) for d, dim in enumerate(ds.dim_names[1:]): @@ -181,67 +187,122 @@ def test_load_from_str_path(self, sleap_file_h5_single): ) @pytest.mark.parametrize( - "scores_array", [None, np.zeros((10, 2, 2)), np.zeros((10, 2, 3))] + "tracks_array", + [ + None, # invalid, argument is non-optional + [1, 2, 3], # not an ndarray + np.zeros((10, 2, 3)), # not 4d + np.zeros((10, 2, 3, 4)), # last dim not 2 or 3 + ], ) - def test_init_scores(self, scores_array): - """Test that confidence scores are correctly initialized.""" - tracks = np.random.rand(10, 2, 2, 2) + def test_tracks_array_validation(self, tracks_array): + """Test that invalid tracks arrays raise the appropriate errors.""" + with pytest.raises(ValueError): + ValidPoseTracks(tracks_array=tracks_array) + @pytest.mark.parametrize( + "scores_array", + [ + None, # valid, should default to array of NaNs + np.ones((10, 3, 2)), # will not match tracks_array shape + [1, 2, 3], # not an ndarray, should raise ValueError + ], + ) + def test_scores_array_validation(self, valid_tracks_array, scores_array): + """Test that invalid scores arrays raise the appropriate errors.""" if scores_array is None: - ds = PoseTracks(tracks, scores_array=scores_array) - assert ds.confidence_scores.shape == (10, 2, 2) - assert np.all(np.isnan(ds.confidence_scores.data)) - elif scores_array.shape == (10, 2, 2): - ds = PoseTracks(tracks, scores_array=scores_array) - assert np.allclose(ds.confidence_scores.data, scores_array) + poses = ValidPoseTracks(tracks_array=valid_tracks_array) + assert np.all(np.isnan(poses.scores_array)) else: with pytest.raises(ValueError): - PoseTracks(tracks, scores_array=scores_array) + ValidPoseTracks( + tracks_array=valid_tracks_array, scores_array=scores_array + ) @pytest.mark.parametrize( "individual_names", - [None, ["animal_1", "animal_2"], ["animal_1", "animal_2", "animal_3"]], + [ + None, # 
generate default names + ["ind1", "ind2"], # valid input + ("ind1", "ind2"), # valid input + [1, 2], # will be converted to ["1", "2"] + "ind1", # will be converted to ["ind1"] + 5, # invalid, should raise ValueError + ], ) - def test_init_individual_names(self, individual_names): - """Test that individual names are correctly initialized.""" - tracks = np.random.rand(10, 2, 2, 2) - + def test_individual_names_validation( + self, valid_tracks_array, individual_names + ): if individual_names is None: - ds = PoseTracks(tracks, individual_names=individual_names) - assert ds.dims["individuals"] == 2 - assert all( - [ - f"individual_{i}" in ds.coords["individuals"] - for i in range(2) - ] + poses = ValidPoseTracks( + tracks_array=valid_tracks_array, + individual_names=individual_names, + ) + assert poses.individual_names == ["individual_0", "individual_1"] + elif type(individual_names) in (list, tuple): + poses = ValidPoseTracks( + tracks_array=valid_tracks_array, + individual_names=individual_names, ) - elif len(individual_names) == 2: - ds = PoseTracks(tracks, individual_names=individual_names) - assert ds.dims["individuals"] == 2 - assert all( - [n in ds.coords["individuals"] for n in individual_names] + assert poses.individual_names == [str(i) for i in individual_names] + elif type(individual_names) == str: + poses = ValidPoseTracks( + tracks_array=np.zeros((10, 1, 2, 2)), + individual_names=individual_names, ) + assert poses.individual_names == [individual_names] + # raises error if not 1 individual + with pytest.raises(ValueError): + ValidPoseTracks( + tracks_array=valid_tracks_array, + individual_names=individual_names, + ) else: with pytest.raises(ValueError): - PoseTracks(tracks, individual_names=individual_names) + ValidPoseTracks( + tracks_array=valid_tracks_array, + individual_names=individual_names, + ) @pytest.mark.parametrize( - "keypoint_names", [None, ["kp_1", "kp_2"], ["kp_1", "kp_2", "kp_3"]] + "keypoint_names", + [ + None, # generate default names + ["key1", "key2"], # valid input + ("key", "key2"), # valid input + [1, 2], # will be converted to ["1", "2"] + "key1", # will be converted to ["ind1"] + 5, # invalid, should raise ValueError + ], ) - def test_init_keypoint_names(self, keypoint_names): - """Test that keypoint names are correctly initialized.""" - tracks = np.random.rand(10, 2, 2, 2) - + def test_keypoint_names_validation( + self, valid_tracks_array, keypoint_names + ): if keypoint_names is None: - ds = PoseTracks(tracks, keypoint_names=keypoint_names) - assert ds.dims["keypoints"] == 2 - assert all( - [f"keypoint_{i}" in ds.coords["keypoints"] for i in range(2)] + poses = ValidPoseTracks( + tracks_array=valid_tracks_array, keypoint_names=keypoint_names ) - elif len(keypoint_names) == 2: - ds = PoseTracks(tracks, keypoint_names=keypoint_names) - assert ds.dims["keypoints"] == 2 - assert all([n in ds.coords["keypoints"] for n in keypoint_names]) + assert poses.keypoint_names == ["keypoint_0", "keypoint_1"] + elif type(keypoint_names) in (list, tuple): + poses = ValidPoseTracks( + tracks_array=valid_tracks_array, keypoint_names=keypoint_names + ) + assert poses.keypoint_names == [str(i) for i in keypoint_names] + elif type(keypoint_names) == str: + poses = ValidPoseTracks( + tracks_array=np.zeros((10, 2, 1, 2)), + keypoint_names=keypoint_names, + ) + assert poses.keypoint_names == [keypoint_names] + # raises error if not 1 keypoint + with pytest.raises(ValueError): + ValidPoseTracks( + tracks_array=valid_tracks_array, + keypoint_names=keypoint_names, + ) else: with 
pytest.raises(ValueError):
-                PoseTracks(tracks, keypoint_names=keypoint_names)
+                ValidPoseTracks(
+                    tracks_array=valid_tracks_array,
+                    keypoint_names=keypoint_names,
+                )

From 1136a3bc3f329e960d5ba2e760232e45048c1a55 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Mon, 7 Aug 2023 19:18:28 +0100
Subject: [PATCH 30/79] expanded tests for file validators

---
 movement/io/file_validators.py | 38 ++++++++++----------
 tests/test_unit/test_io.py     | 63 ++++++++++++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 18 deletions(-)

diff --git a/movement/io/file_validators.py b/movement/io/file_validators.py
index 11a53032..2e153873 100644
--- a/movement/io/file_validators.py
+++ b/movement/io/file_validators.py
@@ -55,6 +55,17 @@ def path_is_not_dir(self, attribute, value):
             f"Expected a file path but got a directory: {value}."
         )

+    @path.validator
+    def file_exists_when_expected(self, attribute, value):
+        """Ensures that the file exists (or not) depending on the expected
+        usage (read and/or write)."""
+        if "r" in self.expected_permission:
+            if not value.exists():
+                raise FileNotFoundError(f"File {value} does not exist.")
+        else:  # expected_permission is 'w'
+            if value.exists():
+                raise FileExistsError(f"File {value} already exists.")
+
     @path.validator
     def file_has_access_permissions(self, attribute, value):
         """Ensures that the file has the expected access permission(s).
@@ -66,23 +77,12 @@ def file_has_access_permissions(self, attribute, value):
                     "Make sure that you have read permissions for it."
                 )
         if "w" in self.expected_permission:
-            if not os.access(value, os.W_OK):
+            if not os.access(value.parent, os.W_OK):
                 raise PermissionError(
                     f"Unable to write to file: {value}. "
                     "Make sure that you have write permissions for it."
                 )

-    @path.validator
-    def file_exists_when_expected(self, attribute, value):
-        """Ensures that the file exists (or not) depending on the expected
-        usage (read and/or write)."""
-        if "r" in self.expected_permission:
-            if not value.exists():
-                raise FileNotFoundError(f"File {value} does not exist.")
-        else:  # expected_permission is 'w'
-            if value.exists():
-                raise FileExistsError(f"File {value} already exists.")
-
     @path.validator
     def file_has_expected_suffix(self, attribute, value):
@@ -119,11 +119,13 @@ class ValidHDF5:
     @path.validator
     def file_is_h5(self, attribute, value):
         """Ensure that the file is indeed in HDF5 format."""
-        with h5py.File(value, "r") as f:
-            if not isinstance(f, h5py.File):
-                raise ValueError(
-                    f"Expected an HDF5 file but got {type(f)}: {value}."
-                )
+        try:
+            with h5py.File(value, "r") as f:
+                f.close()
+        except Exception as e:
+            raise ValueError(
+                f"File {value} does not seem to be in valid " "HDF5 format."
+            ) from e

     @path.validator
     def file_contains_expected_datasets(self, attribute, value):
         """Ensure that the HDF5 file contains the expected datasets."""
@@ -176,7 +178,7 @@ def csv_file_contains_expected_levels(self, attribute, value):
             ]
             if not all(level_in_header_row_starts):
                 raise ValueError(
-                    f"The header rows of the CSV file {value.path} do not "
+                    f"The header rows of the CSV file {value} do not "
                     "contain all expected index column levels "
                     f"{expected_levels}."
) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index ab089b20..6f56b12a 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -9,6 +9,7 @@ from movement.datasets import fetch_pose_data_path from movement.io import PoseTracks +from movement.io.file_validators import ValidFile, ValidHDF5, ValidPosesCSV from movement.io.tracks_validators import ValidPoseTracks @@ -106,11 +107,26 @@ def invalid_files(self, tmp_path): nonexistent_file = tmp_path / "nonexistent.h5" + directory = tmp_path / "directory" + directory.mkdir() + + fake_h5_file = tmp_path / "fake.h5" + with open(fake_h5_file, "w") as f: + f.write("") + + fake_csv_file = tmp_path / "fake.csv" + with open(fake_csv_file, "w") as f: + f.write("some,columns\n") + f.write("1,2") + return { "unreadable": unreadable_file, "wrong_ext": wrong_ext_file, "no_dataframe": h5_file_no_dataframe, "nonexistent": nonexistent_file, + "directory": directory, + "fake_h5": fake_h5_file, + "fake_csv": fake_csv_file, } @pytest.fixture @@ -148,6 +164,53 @@ def test_load_from_valid_files(self, valid_files): assert ds.source_file == file_path.as_posix() assert ds.fps is None + def test_file_validation(self, invalid_files): + """Test that loading from invalid files path raises the + appropriate errors.""" + for file_type, file_path in invalid_files.items(): + if file_type == "unreadable": + with pytest.raises(PermissionError): + ValidFile(path=file_path, expected_permission="r") + elif file_type == "wrong_ext": + with pytest.raises(ValueError): + ValidFile( + path=file_path, + expected_permission="r", + expected_suffix=["h5", "csv"], + ) + elif file_type == "nonexistent": + with pytest.raises(FileNotFoundError): + ValidFile(path=file_path, expected_permission="r") + elif file_type == "directory": + with pytest.raises(IsADirectoryError): + ValidFile(path=file_path, expected_permission="r") + elif file_type in ["fake_h5", "no_dataframe"]: + with pytest.raises(ValueError): + ValidHDF5(path=file_path, expected_datasets=["dataframe"]) + elif file_type == "fake_csv": + with pytest.raises(ValueError): + ValidPosesCSV(path=file_path) + + def test_write_to_dlc_file( + self, sleap_file_h5_multi, invalid_files, tmp_path + ): + """Test that writing pose tracks to DLC .h5 and .csv files and then + reading them back in returns the same Dataset.""" + ds = PoseTracks.from_sleap_file(sleap_file_h5_multi) + ds.to_dlc_file(tmp_path / "dlc.h5") + ds.to_dlc_file(tmp_path / "dlc.csv") + ds_from_h5 = PoseTracks.from_dlc_file(tmp_path / "dlc.h5") + ds_from_csv = PoseTracks.from_dlc_file(tmp_path / "dlc.csv") + assert_allclose(ds_from_h5, ds) + assert_allclose(ds_from_csv, ds) + + with pytest.raises(FileExistsError): + ds.to_dlc_file(invalid_files["fake_h5"]) + with pytest.raises(ValueError): + ds.to_dlc_file(tmp_path / "dlc.txt") + with pytest.raises(IsADirectoryError): + ds.to_dlc_file(invalid_files["directory"]) + def test_load_from_dlc_file_csv_or_h5_file_returns_same( self, dlc_file_h5_single, dlc_file_csv_single ): From 3c07c3a70a2890f70a1ce2e699cc31bc8ea30e77 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 7 Aug 2023 19:31:37 +0100 Subject: [PATCH 31/79] fix code smells --- movement/io/file_validators.py | 24 ++++++++++++------------ movement/io/tracks_validators.py | 22 ++++++++++------------ 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/movement/io/file_validators.py b/movement/io/file_validators.py index 2e153873..c9d278ce 100644 --- a/movement/io/file_validators.py +++ b/movement/io/file_validators.py @@ -70,18 
+70,18 @@ def file_exists_when_expected(self, attribute, value): def file_has_access_permissions(self, attribute, value): """Ensures that the file has the expected access permission(s). Raises a PermissionError if not.""" - if "r" in self.expected_permission: - if not os.access(value, os.R_OK): - raise PermissionError( - f"Unable to read file: {value}. " - "Make sure that you have read permissions for it." - ) - if "w" in self.expected_permission: - if not os.access(value.parent, os.W_OK): - raise PermissionError( - f"Unable to write to file: {value}. " - "Make sure that you have write permissions for it." - ) + file_is_readable = os.access(value, os.R_OK) + parent_is_writeable = os.access(value.parent, os.W_OK) + if ("r" in self.expected_permission) and (not file_is_readable): + raise PermissionError( + f"Unable to read file: {value}. " + "Make sure that you have read permissions for it." + ) + if ("w" in self.expected_permission) and (not parent_is_writeable): + raise PermissionError( + f"Unable to write to file: {value}. " + "Make sure that you have write permissions." + ) @path.validator def file_has_expected_suffix(self, attribute, value): diff --git a/movement/io/tracks_validators.py b/movement/io/tracks_validators.py index 4b578ad6..9f933289 100644 --- a/movement/io/tracks_validators.py +++ b/movement/io/tracks_validators.py @@ -117,21 +117,19 @@ def _validate_scores_array(self, attribute, value): @individual_names.validator def _validate_individual_names(self, attribute, value): - if value is not None: - if len(value) != self.tracks_array.shape[1]: - raise ValueError( - f"Expected {self.tracks_array.shape[1]} `{attribute}`, " - f"but got {len(value)}." - ) + if (value is not None) and (len(value) != self.tracks_array.shape[1]): + raise ValueError( + f"Expected {self.tracks_array.shape[1]} `{attribute}`, " + f"but got {len(value)}." + ) @keypoint_names.validator def _validate_keypoint_names(self, attribute, value): - if value is not None: - if len(value) != self.tracks_array.shape[2]: - raise ValueError( - f"Expected {self.tracks_array.shape[2]} `{attribute}`, " - f"but got {len(value)}." - ) + if (value is not None) and (len(value) != self.tracks_array.shape[2]): + raise ValueError( + f"Expected {self.tracks_array.shape[2]} `{attribute}`, " + f"but got {len(value)}." 
+            )

From 25732ee21fc7a908e8eaa21f6c943291904cb7cd Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Mon, 7 Aug 2023 19:38:04 +0100
Subject: [PATCH 32/79] use Iterable from typing

---
 movement/io/tracks_validators.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/movement/io/tracks_validators.py b/movement/io/tracks_validators.py
index 9f933289..b772c46f 100644
--- a/movement/io/tracks_validators.py
+++ b/movement/io/tracks_validators.py
@@ -1,6 +1,5 @@
 import logging
-from collections.abc import Iterable
-from typing import Any, List, Optional, Union
+from typing import Any, Iterable, List, Optional, Union

 import numpy as np
 from attrs import converters, define, field

From 95ab24f0a59aec5c6128e6f9ba0c4a2c5189e429 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 8 Aug 2023 14:47:49 +0100
Subject: [PATCH 33/79] increase test coverage

---
 tests/test_unit/test_io.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py
index 6f56b12a..2ba5f9c8 100644
--- a/tests/test_unit/test_io.py
+++ b/tests/test_unit/test_io.py
@@ -164,9 +164,25 @@ def test_load_from_valid_files(self, valid_files):
             assert ds.source_file == file_path.as_posix()
             assert ds.fps is None

-    def test_file_validation(self, invalid_files):
-        """Test that loading from invalid files path raises the
-        appropriate errors."""
+    def test_load_from_invalid_files(self, invalid_files):
+        """Test that loading pose tracks from a wide variety of invalid files
+        raises the appropriate errors."""
+        for file_path in invalid_files.values():
+            with pytest.raises((OSError, ValueError)):
+                PoseTracks.from_dlc_file(file_path)
+            with pytest.raises((OSError, ValueError)):
+                PoseTracks.from_sleap_file(file_path)
+
+    @pytest.mark.parametrize("file_path", [1, 1.0, True, None, [], {}])
+    def test_load_with_incorrect_file_path_types(self, file_path):
+        """Test loading poses from a file_path with an incorrect type."""
+        with pytest.raises(TypeError):
+            PoseTracks.from_dlc_file(file_path)
+        with pytest.raises(TypeError):
+            PoseTracks.from_sleap_file(file_path)
+
+    def test_file_validator(self, invalid_files):
+        """Test that the file validator class raises the right errors."""
         for file_type, file_path in invalid_files.items():
             if file_type == "unreadable":
                 with pytest.raises(PermissionError):

From 962caa62ae9aac35b7ff2f6afbe9070aeac859d7 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 8 Aug 2023 14:50:09 +0100
Subject: [PATCH 34/79] correct type hints in function signatures

---
 movement/io/tracks_validators.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/movement/io/tracks_validators.py b/movement/io/tracks_validators.py
index b772c46f..410e3c7d 100644
--- a/movement/io/tracks_validators.py
+++ b/movement/io/tracks_validators.py
@@ -8,7 +8,7 @@
 logger = logging.getLogger(__name__)


-def _list_of_str(value: Union[str, Iterable[Any]]) -> list[str]:
+def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]:
     """Try to coerce the value into a list of strings.
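
Concretely, the coercion behaviour being typed here (as defined in the patch that introduced _list_of_str above) is:

    _list_of_str("abc")       # ["abc"], with a warning logged
    _list_of_str(["a", "b"])  # ["a", "b"]
    _list_of_str((1, 2))      # ["1", "2"], items are cast to str
    _list_of_str(5)           # raises ValueError: neither str nor Iterable
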
Otherwise, raise a ValueError.""" if type(value) is str: @@ -26,7 +26,7 @@ def _list_of_str(value: Union[str, Iterable[Any]]) -> list[str]: raise ValueError(error_msg) -def _ensure_type_ndarray(value: Any): +def _ensure_type_ndarray(value: Any) -> None: """Raise ValueError the value is a not numpy array.""" if type(value) is not np.ndarray: raise ValueError(f"Expected a numpy array, but got {type(value)}.") From 34138a68614b9eacb11c781562877a5452499075 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Tue, 8 Aug 2023 18:57:18 +0100 Subject: [PATCH 35/79] refactored io module to use xarray accessor instead of subclassing --- movement/io/__init__.py | 2 +- movement/io/load_poses.py | 418 ++++++++++++++++ movement/io/pose_tracks.py | 461 ------------------ movement/io/poses_accessor.py | 83 ++++ movement/io/save_poses.py | 102 ++++ movement/io/tracks_validators.py | 157 ------ .../io/{file_validators.py => validators.py} | 153 +++++- tests/test_unit/test_io.py | 143 ++++-- 8 files changed, 852 insertions(+), 667 deletions(-) create mode 100644 movement/io/load_poses.py delete mode 100644 movement/io/pose_tracks.py create mode 100644 movement/io/poses_accessor.py create mode 100644 movement/io/save_poses.py delete mode 100644 movement/io/tracks_validators.py rename movement/io/{file_validators.py => validators.py} (53%) diff --git a/movement/io/__init__.py b/movement/io/__init__.py index 855b4d62..c95035f8 100644 --- a/movement/io/__init__.py +++ b/movement/io/__init__.py @@ -1 +1 @@ -from .pose_tracks import PoseTracks +from .poses_accessor import PosesAccessor diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py new file mode 100644 index 00000000..1723016f --- /dev/null +++ b/movement/io/load_poses.py @@ -0,0 +1,418 @@ +import logging +from pathlib import Path +from typing import Optional, Union + +import h5py +import numpy as np +import pandas as pd +import xarray as xr +from sleap_io.io.slp import read_labels + +from movement.io.poses_accessor import PosesAccessor +from movement.io.validators import ( + ValidFile, + ValidHDF5, + ValidPosesCSV, + ValidPoseTracks, +) + +logger = logging.getLogger(__name__) + + +def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset: + """Create an xarray.Dataset from a DLC_style pandas DataFrame. + + Parameters + ---------- + df : pandas.DataFrame + DataFrame containing the pose tracks and confidence scores. Must + be formatted as in DeepLabCut output files (see Notes). + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. + + Notes + ----- + The DataFrame must have a multi-index column with the following levels: + "scorer", ("individuals"), "bodyparts", "coords". The "individuals" + level may be omitted if there is only one individual in the video. + The "coords" level contains the spatial coordinates "x", "y", + as well as "likelihood" (point-wise confidence scores). + The row index corresponds to the frame number. 
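
A minimal sketch of a DataFrame that satisfies these requirements (the scorer, individual and bodypart names are illustrative):

    import numpy as np
    import pandas as pd
    from movement.io import load_poses

    columns = pd.MultiIndex.from_product(
        [["scorer"], ["mouse_0"], ["snout", "tail"], ["x", "y", "likelihood"]],
        names=["scorer", "individuals", "bodyparts", "coords"],
    )
    df = pd.DataFrame(np.random.rand(100, 6), columns=columns)
    ds = load_poses.from_dlc_df(df, fps=25)
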
+ """ + + # read names of individuals and keypoints from the DataFrame + if "individuals" in df.columns.names: + individual_names = ( + df.columns.get_level_values("individuals").unique().to_list() + ) + else: + individual_names = ["individual_0"] + + keypoint_names = ( + df.columns.get_level_values("bodyparts").unique().to_list() + ) + + # reshape the data into (n_frames, n_individuals, n_keypoints, 3) + # where the last axis contains "x", "y", "likelihood" + tracks_with_scores = df.to_numpy().reshape( + (-1, len(individual_names), len(keypoint_names), 3) + ) + + try: + valid_data = ValidPoseTracks( + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], + individual_names=individual_names, + keypoint_names=keypoint_names, + fps=fps, + ) + except ValueError as error: + logger.error(error) + raise error + else: + return _from_valid_data(valid_data) + + +def from_sleap_file( + file_path: Union[Path, str], fps: Optional[float] = None +) -> xr.Dataset: + """Load pose tracking data from a SLEAP labels or analysis file + into an xarray Dataset. + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the file containing the SLEAP predictions, either in ".slp" + or ".h5" (analysis) format. See Notes for more information. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. + + Notes + ----- + The SLEAP predictions are normally saved in a ".slp" file, e.g. + "v1.predictions.slp". If this file contains both user-labeled and + predicted instances, only the predicted ones will be loaded. + + An analysis file, suffixed with ".h5" can be exported from the ".slp" + file, using either the command line tool `sleap-convert` (with the + "--format analysis" option enabled) or the SLEAP GUI (Choose + "Export Analysis HDF5…" from the "File" menu) [1]_. + + `movement` expects the tracks to be proofread before loading them, + meaning each track is interpreted as a single individual/animal. + Follow the SLEAP guide for tracking and proofreading [2]_. + + References + ---------- + .. [1] https://sleap.ai/tutorials/analysis.html + .. [2] https://sleap.ai/guides/proofreading.html + + Examples + -------- + >>> from movement.io import load_poses + >>> ds = load_poses.from_sleap_file("path/to/file.slp", fps=30) + """ + + try: + file = ValidFile( + file_path, + expected_permission="r", + expected_suffix=[".h5", ".slp"], + ) + except (OSError, ValueError) as error: + logger.error(error) + raise error + + # Load and validate data + if file.path.suffix == ".h5": + valid_data = _load_from_sleap_analysis_file(file.path, fps=fps) + else: # file.path.suffix == ".slp" + valid_data = _load_from_sleap_labels_file(file.path, fps=fps) + logger.debug(f"Validated pose tracks from {file.path}.") + + # Initialize an xarray dataset from the dictionary + ds = _from_valid_data(valid_data) + + # Add metadata as attrs + ds.attrs["source_software"] = "SLEAP" + ds.attrs["source_file"] = file.path.as_posix() + + logger.info(f"Loaded pose tracks from {file.path}:") + logger.info(ds) + return ds + + +def from_dlc_file( + file_path: Union[Path, str], fps: Optional[float] = None +) -> xr.Dataset: + """Load pose tracking data from a DeepLabCut (DLC) output file + into an xarray Dataset. 
+ + Parameters + ---------- + file_path : pathlib.Path or str + Path to the file containing the DLC poses, either in ".h5" + or ".csv" format. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. + + Examples + -------- + >>> from movement.io import load_poses + >>> ds = load_poses.from_dlc_file("path/to/file.h5", fps=30) + """ + + try: + file = ValidFile( + file_path, + expected_permission="r", + expected_suffix=[".csv", ".h5"], + ) + except (OSError, ValueError) as error: + logger.error(error) + raise error + + # Load the DLC poses into a DataFrame + if file.path.suffix == ".csv": + df = _parse_dlc_csv_to_df(file.path) + else: # file.path.suffix == ".h5" + df = _load_df_from_dlc_h5(file.path) + + logger.debug(f"Loaded poses from {file.path} into a DataFrame.") + # Convert the DataFrame to an xarray dataset + ds = from_dlc_df(df=df, fps=fps) + + # Add metadata as attrs + ds.attrs["source_software"] = "DeepLabCut" + ds.attrs["source_file"] = file.path.as_posix() + + logger.info(f"Loaded pose tracks from {file.path}:") + logger.info(ds) + return ds + + +def _load_from_sleap_analysis_file( + file_path: Path, fps: Optional[float] +) -> ValidPoseTracks: + """Load and validate pose tracks and confidence scores from a SLEAP + analysis file. + + Parameters + ---------- + file_path : pathlib.Path + Path to the SLEAP analysis file containing predicted pose tracks. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame units. + + Returns + ------- + movement.io.tracks_validators.ValidPoseTracks + The validated pose tracks and confidence scores. + """ + + file = ValidHDF5(file_path, expected_datasets=["tracks"]) + + with h5py.File(file.path, "r") as f: + tracks = f["tracks"][:].T + n_frames, n_keypoints, n_space, n_tracks = tracks.shape + tracks = tracks.reshape((n_frames, n_tracks, n_keypoints, n_space)) + # Create an array of NaNs for the confidence scores + scores = np.full( + (n_frames, n_tracks, n_keypoints), np.nan, dtype="float32" + ) + # If present, read the point-wise scores, and reshape them + if "point_scores" in f.keys(): + scores = f["point_scores"][:].reshape( + (n_frames, n_tracks, n_keypoints) + ) + + try: + valid_data = ValidPoseTracks( + tracks_array=tracks, + scores_array=scores, + individual_names=[n.decode() for n in f["track_names"][:]], + keypoint_names=[n.decode() for n in f["node_names"][:]], + fps=fps, + ) + except ValueError as error: + logger.error(error) + raise error + else: + return valid_data + + +def _load_from_sleap_labels_file( + file_path: Path, fps: Optional[float] +) -> ValidPoseTracks: + """Load and validate pose tracks and confidence scores from a SLEAP + labels file. + + Parameters + ---------- + file_path : pathlib.Path + Path to the SLEAP labels file containing predicted pose tracks. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame units. + + Returns + ------- + movement.io.tracks_validators.ValidPoseTracks + The validated pose tracks and confidence scores. 
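
As the function body below shows, the heavy lifting is delegated to sleap-io; the core calls are (the file name is a placeholder):

    from sleap_io.io.slp import read_labels

    labels = read_labels("path/to/predictions.slp")
    # shape (n_frames, n_tracks, n_nodes, 3); the last axis is x, y, confidence
    tracks_with_scores = labels.numpy(return_confidence=True)
    individual_names = [track.name for track in labels.tracks]
    keypoint_names = [kp.name for kp in labels.skeletons[0].nodes]
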
+ """ + + file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"]) + + labels = read_labels(file.path.as_posix()) + tracks_with_scores = labels.numpy(return_confidence=True) + + try: + valid_data = ValidPoseTracks( + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], + individual_names=[track.name for track in labels.tracks], + keypoint_names=[kp.name for kp in labels.skeletons[0].nodes], + fps=fps, + ) + except ValueError as error: + logger.error(error) + raise error + else: + return valid_data + + +def _parse_dlc_csv_to_df(file_path: Path) -> pd.DataFrame: + """If poses are loaded from a DeepLabCut .csv file, the DataFrame + lacks the multi-index columns that are present in the .h5 file. This + function parses the csv file to a pandas DataFrame with multi-index + columns, i.e. the same format as in the .h5 file. + + Parameters + ---------- + file_path : pathlib.Path + Path to the DeepLabCut-style CSV file. + + Returns + ------- + pandas.DataFrame + DeepLabCut-style DataFrame with multi-index columns. + """ + + file = ValidPosesCSV(file_path, multianimal=False) + + possible_level_names = ["scorer", "individuals", "bodyparts", "coords"] + with open(file.path, "r") as f: + # if line starts with a possible level name, split it into a list + # of strings, and add it to the list of header lines + header_lines = [ + line.strip().split(",") + for line in f.readlines() + if line.split(",")[0] in possible_level_names + ] + + # Form multi-index column names from the header lines + level_names = [line[0] for line in header_lines] + column_tuples = list(zip(*[line[1:] for line in header_lines])) + columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names) + + # Import the DLC poses as a DataFrame + df = pd.read_csv( + file.path, + skiprows=len(header_lines), + index_col=0, + names=np.array(columns), + ) + df.columns.rename(level_names, inplace=True) + return df + + +def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame: + """Load pose tracks and likelihood scores from a DeepLabCut .h5 file + into a pandas DataFrame. + + Parameters + ---------- + file_path : pathlib.Path + Path to the DeepLabCut-style HDF5 file containing pose tracks. + + Returns + ------- + pandas.DataFrame + DeepLabCut-style Dataframe. + """ + + file = ValidHDF5(file_path, expected_datasets=["df_with_missing"]) + + try: + # pd.read_hdf does not always return a DataFrame + df = pd.DataFrame(pd.read_hdf(file.path, key="df_with_missing")) + except Exception as error: + logger.error(error) + raise error + else: + return df + + +def _from_valid_data(data: ValidPoseTracks) -> xr.Dataset: + """Convert already validated pose tracking data to an xarray Dataset. + + Parameters + ---------- + data : movement.io.tracks_validators.ValidPoseTracks + The validated data object. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. 
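
As a worked example of the time-coordinate logic implemented below:

    import numpy as np

    n_frames, fps = 5, 25.0
    time_coords = np.arange(n_frames, dtype=int)  # [0, 1, 2, 3, 4], in frames
    if fps is not None:
        time_coords = time_coords / fps  # [0.0, 0.04, 0.08, 0.12, 0.16], in seconds
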
+ """ + + n_frames = data.tracks_array.shape[0] + n_space = data.tracks_array.shape[-1] + + # Create the time coordinate, depending on the value of fps + time_coords = np.arange(n_frames, dtype=int) + time_unit = "frames" + if data.fps is not None: + time_coords = time_coords / data.fps + time_unit = "seconds" + + DIM_NAMES = PosesAccessor.dim_names + # Convert data to an xarray.Dataset + return xr.Dataset( + data_vars={ + "pose_tracks": xr.DataArray(data.tracks_array, dims=DIM_NAMES), + "confidence": xr.DataArray(data.scores_array, dims=DIM_NAMES[:-1]), + }, + coords={ + DIM_NAMES[0]: time_coords, + DIM_NAMES[1]: data.individual_names, + DIM_NAMES[2]: data.keypoint_names, + DIM_NAMES[3]: ["x", "y", "z"][:n_space], + }, + attrs={ + "fps": data.fps, + "time_unit": time_unit, + "source_software": None, + "source_file": None, + }, + ) diff --git a/movement/io/pose_tracks.py b/movement/io/pose_tracks.py deleted file mode 100644 index 49de5fa6..00000000 --- a/movement/io/pose_tracks.py +++ /dev/null @@ -1,461 +0,0 @@ -import logging -from pathlib import Path -from typing import ClassVar, Optional, Union - -import h5py -import numpy as np -import pandas as pd -import xarray as xr -from sleap_io.io.slp import read_labels - -from movement.io.file_validators import ValidFile, ValidHDF5, ValidPosesCSV -from movement.io.tracks_validators import ValidPoseTracks - -# get logger -logger = logging.getLogger(__name__) - - -class PoseTracks(xr.Dataset): - """Dataset containing pose tracks and point-wise confidence scores. - - This is an `xarray.Dataset` object, with the following dimensions: - - `time`: the number of frames in the video - - `individuals`: the number of individuals in the video - - `keypoints`: the number of keypoints in the skeleton - - `space`: the number of spatial dimensions, either 2 or 3 - - Appropriate coordinate labels are assigned to each dimension: - list of unique names (str) for `individuals` and `keypoints`, - ['x','y',('z')] for `space`. The coordinates of the `time` dimension are - in seconds if `fps` is provided, otherwise they are in frame numbers. - - The dataset contains two data variables (`xarray.DataArray` objects): - - `pose_tracks`: with shape (`time`, `individuals`, `keypoints`, `space`) - - `confidence`: with shape (`time`, `individuals`, `keypoints`) - - The dataset may also contain following attributes as metadata: - - `fps`: the number of frames per second in the video - - `time_unit`: the unit of the `time` coordinates, frames or seconds - - `source_software`: the software from which the pose tracks were loaded - - `source_file`: the file from which the pose tracks were loaded - """ - - dim_names: ClassVar[tuple] = ( - "time", - "individuals", - "keypoints", - "space", - ) - - __slots__ = ("fps", "time_unit", "source_software", "source_file") - - @classmethod - def _from_valid_data(cls, data: ValidPoseTracks): - """Initialize a `PoseTracks` xarray.Dataset from already validated - data - i.e. a `ValidPoseTracks` object. - - Parameters - ---------- - data : movement.io.tracks_validators.ValidPoseTracks - The validated data object. 
- """ - - n_frames = data.tracks_array.shape[0] - n_space = data.tracks_array.shape[-1] - - # Create the time coordinate, depending on the value of fps - time_coords = np.arange(n_frames, dtype=int) - time_unit = "frames" - if data.fps is not None: - time_coords = time_coords / data.fps - time_unit = "seconds" - - # Convert data to an xarray.Dataset - return cls( - data_vars={ - "pose_tracks": xr.DataArray( - data.tracks_array, dims=cls.dim_names - ), - "confidence": xr.DataArray( - data.scores_array, dims=cls.dim_names[:-1] - ), - }, - coords={ - cls.dim_names[0]: time_coords, - cls.dim_names[1]: data.individual_names, - cls.dim_names[2]: data.keypoint_names, - cls.dim_names[3]: ["x", "y", "z"][:n_space], - }, - attrs={ - "fps": data.fps, - "time_unit": time_unit, - "source_software": None, - "source_file": None, - }, - ) - - @classmethod - def from_dlc_df(cls, df: pd.DataFrame, fps: Optional[float] = None): - """Create a `PoseTracks` dataset from a DLC_style pandas DataFrame. - - Parameters - ---------- - df : pandas DataFrame - DataFrame containing the pose tracks and confidence scores. Must - be formatted as in DeepLabCut output files (see Notes). - fps : float, optional - The number of frames per second in the video. If None (default), - the `time` coordinates will be in frame numbers. - - Notes - ----- - The DataFrame must have a multi-index column with the following levels: - "scorer", ("individuals"), "bodyparts", "coords". The "individuals" - level may be omitted if there is only one individual in the video. - The "coords" level contains the spatial coordinates "x", "y", - as well as "likelihood" (point-wise confidence scores). - The row index corresponds to the frame number. - """ - - # read names of individuals and keypoints from the DataFrame - if "individuals" in df.columns.names: - individual_names = ( - df.columns.get_level_values("individuals").unique().to_list() - ) - else: - individual_names = ["individual_0"] - - keypoint_names = ( - df.columns.get_level_values("bodyparts").unique().to_list() - ) - - # reshape the data into (n_frames, n_individuals, n_keypoints, 3) - # where the last axis contains "x", "y", "likelihood" - tracks_with_scores = df.to_numpy().reshape( - (-1, len(individual_names), len(keypoint_names), 3) - ) - - try: - valid_data = ValidPoseTracks( - tracks_array=tracks_with_scores[:, :, :, :-1], - scores_array=tracks_with_scores[:, :, :, -1], - individual_names=individual_names, - keypoint_names=keypoint_names, - fps=fps, - ) - except ValueError as error: - logger.error(error) - raise error - else: - return cls._from_valid_data(valid_data) - - @classmethod - def from_sleap_file( - cls, file_path: Union[Path, str], fps: Optional[float] = None - ): - """Load pose tracking data from a SLEAP labels or analysis file. - - Parameters - ---------- - file_path : pathlib Path or str - Path to the file containing the SLEAP predictions, either in ".slp" - or ".h5" (analysis) format. See Notes for more information. - fps : float, optional - The number of frames per second in the video. If None (default), - the `time` coordinates will be in frame numbers. - - - Notes - ----- - The SLEAP predictions are normally saved in a ".slp" file, e.g. - "v1.predictions.slp". If this file contains both user-labeled and - predicted instances, only the predicted ones will be loaded. 
- - An analysis file, suffixed with ".h5" can be exported from the ".slp" - file, using either the command line tool `sleap-convert` (with the - "--format analysis" option enabled) or the SLEAP GUI (Choose - "Export Analysis HDF5…" from the "File" menu) [1]_. - - `movement` expects the tracks to be proofread before loading them, - meaning each track is interpreted as a single individual/animal. - Follow the SLEAP guide for tracking and proofreading [2]_. - - References - ---------- - .. [1] https://sleap.ai/tutorials/analysis.html - .. [2] https://sleap.ai/guides/proofreading.html - - Examples - -------- - >>> from movement.io import PoseTracks - >>> poses = PoseTracks.from_sleap_file("path/to/file.slp", fps=30) - """ - - try: - file = ValidFile( - file_path, - expected_permission="r", - expected_suffix=[".h5", ".slp"], - ) - except (OSError, ValueError) as error: - logger.error(error) - raise error - - # Load and validate data - if file.path.suffix == ".h5": - valid_data = cls._load_from_sleap_analysis_file(file.path, fps=fps) - else: # file.path.suffix == ".slp" - valid_data = cls._load_from_sleap_labels_file(file.path, fps=fps) - logger.debug(f"Validated pose tracks from {file.path}.") - - # Initialize a PoseTracks dataset from the dictionary - ds = cls._from_valid_data(valid_data) - - # Add metadata as attrs - ds.attrs["source_software"] = "SLEAP" - ds.attrs["source_file"] = file.path.as_posix() - - logger.info(f"Loaded pose tracks from {file.path}:") - logger.info(ds) - return ds - - @classmethod - def from_dlc_file( - cls, file_path: Union[Path, str], fps: Optional[float] = None - ): - """Load pose tracking data from a DeepLabCut (DLC) output file. - - Parameters - ---------- - file_path : pathlib Path or str - Path to the file containing the DLC poses, either in ".h5" - or ".csv" format. - fps : float, optional - The number of frames per second in the video. If None (default), - the `time` coordinates will be in frame numbers. - - Examples - -------- - >>> from movement.io import PoseTracks - >>> poses = PoseTracks.from_dlc_file("path/to/file.h5", fps=30) - """ - - try: - file = ValidFile( - file_path, - expected_permission="r", - expected_suffix=[".csv", ".h5"], - ) - except (OSError, ValueError) as error: - logger.error(error) - raise error - - # Load the DLC poses into a DataFrame - if file.path.suffix == ".csv": - df = cls._parse_dlc_csv_to_df(file.path) - else: # file.path.suffix == ".h5" - df = cls._load_df_from_dlc_h5(file.path) - - logger.debug(f"Loaded poses from {file.path} into a DataFrame.") - # Convert the DataFrame to a PoseTracks dataset - ds = cls.from_dlc_df(df=df, fps=fps) - - # Add metadata as attrs - ds.attrs["source_software"] = "DeepLabCut" - ds.attrs["source_file"] = file.path.as_posix() - - logger.info(f"Loaded pose tracks from {file.path}:") - logger.info(ds) - return ds - - def to_dlc_df(self) -> pd.DataFrame: - """Convert the PoseTracks dataset to a DeepLabCut-style pandas - DataFrame with multi-index columns. - See the Notes section of the `from_dlc_df()` method for details. - - Returns - ------- - pandas DataFrame - - Notes - ----- - The DataFrame will have a multi-index column with the following levels: - "scorer", "individuals", "bodyparts", "coords" (even if there is only - one individual present). Regardless of the provenance of the - points-wise confidence scores, they will be referred to as - "likelihood", and stored in the "coords" level (as DeepLabCut expects). 
- """ - - # Concatenate the pose tracks and confidence scores into one array - tracks_with_scores = np.concatenate( - ( - self.pose_tracks.data, - self.confidence.data[..., np.newaxis], - ), - axis=-1, - ) - - # Create the DLC-style multi-index columns - # Use the DLC terminology: scorer, individuals, bodyparts, coords - scorer = ["movement"] - individuals = self.coords["individuals"].data.tolist() - bodyparts = self.coords["keypoints"].data.tolist() - # The confidence scores in DLC are referred to as "likelihood" - coords = self.coords["space"].data.tolist() + ["likelihood"] - - index_levels = ["scorer", "individuals", "bodyparts", "coords"] - columns = pd.MultiIndex.from_product( - [scorer, individuals, bodyparts, coords], names=index_levels - ) - df = pd.DataFrame( - data=tracks_with_scores.reshape(self.dims["time"], -1), - index=np.arange(self.dims["time"], dtype=int), - columns=columns, - dtype=float, - ) - logger.info("Converted PoseTracks dataset to DLC-style DataFrame.") - return df - - def to_dlc_file(self, file_path: Union[str, Path]): - """Save the dataset to a DeepLabCut-style .h5 or .csv file - - Parameters - ---------- - file_path : pathlib Path or str - Path to the file to save the DLC poses to. The file extension - must be either ".h5" (recommended) or ".csv". - """ - - try: - file = ValidFile( - file_path, - expected_permission="w", - expected_suffix=[".csv", ".h5"], - ) - except (OSError, ValueError) as error: - logger.error(error) - raise error - - # Convert the PoseTracks dataset to a DataFrame - df = self.to_dlc_df() - if file.path.suffix == ".csv": - df.to_csv(file.path, sep=",") - else: # file.path.suffix == ".h5" - df.to_hdf(file.path, key="df_with_missing") - logger.info(f"Saved PoseTracks dataset to {file.path}.") - - @staticmethod - def _load_from_sleap_analysis_file( - file_path: Path, fps: Optional[float] - ) -> ValidPoseTracks: - """Load and validate pose tracks and confidence scores from a SLEAP - analysis file""" - - file = ValidHDF5(file_path, expected_datasets=["tracks"]) - - with h5py.File(file.path, "r") as f: - tracks = f["tracks"][:].T - n_frames, n_keypoints, n_space, n_tracks = tracks.shape - tracks = tracks.reshape((n_frames, n_tracks, n_keypoints, n_space)) - # Create an array of NaNs for the confidence scores - scores = np.full( - (n_frames, n_tracks, n_keypoints), np.nan, dtype="float32" - ) - # If present, read the point-wise scores, and reshape them - if "point_scores" in f.keys(): - scores = f["point_scores"][:].reshape( - (n_frames, n_tracks, n_keypoints) - ) - - try: - valid_data = ValidPoseTracks( - tracks_array=tracks, - scores_array=scores, - individual_names=[n.decode() for n in f["track_names"][:]], - keypoint_names=[n.decode() for n in f["node_names"][:]], - fps=fps, - ) - except ValueError as error: - logger.error(error) - raise error - else: - return valid_data - - @staticmethod - def _load_from_sleap_labels_file( - file_path: Path, fps: Optional[float] - ) -> ValidPoseTracks: - """Load and validate pose tracks and confidence scores from a SLEAP - labels file.""" - - file = ValidHDF5( - file_path, expected_datasets=["pred_points", "metadata"] - ) - - labels = read_labels(file.path.as_posix()) - tracks_with_scores = labels.numpy(return_confidence=True) - - try: - valid_data = ValidPoseTracks( - tracks_array=tracks_with_scores[:, :, :, :-1], - scores_array=tracks_with_scores[:, :, :, -1], - individual_names=[track.name for track in labels.tracks], - keypoint_names=[kp.name for kp in labels.skeletons[0].nodes], - fps=fps, - ) - except 
ValueError as error:
-            logger.error(error)
-            raise error
-        else:
-            return valid_data
-
-    @staticmethod
-    def _parse_dlc_csv_to_df(file_path: Path) -> pd.DataFrame:
-        """If poses are loaded from a DeepLabCut.csv file, the DataFrame
-        lacks the multi-index columns that are present in the .h5 file. This
-        function parses the csv file to a pandas DataFrame with multi-index
-        columns, i.e. the same format as in the .h5 file.
-        """
-
-        file = ValidPosesCSV(file_path, multianimal=False)
-
-        possible_level_names = ["scorer", "individuals", "bodyparts", "coords"]
-        with open(file.path, "r") as f:
-            # if line starts with a possible level name, split it into a list
-            # of strings, and add it to the list of header lines
-            header_lines = [
-                line.strip().split(",")
-                for line in f.readlines()
-                if line.split(",")[0] in possible_level_names
-            ]
-
-        # Form multi-index column names from the header lines
-        level_names = [line[0] for line in header_lines]
-        column_tuples = list(zip(*[line[1:] for line in header_lines]))
-        columns = pd.MultiIndex.from_tuples(column_tuples, names=level_names)
-
-        # Import the DLC poses as a DataFrame
-        df = pd.read_csv(
-            file.path,
-            skiprows=len(header_lines),
-            index_col=0,
-            names=np.array(columns),
-        )
-        df.columns.rename(level_names, inplace=True)
-        return df
-
-    @staticmethod
-    def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame:
-        """Load pose tracks and likelihood scores from a DeepLabCut .h5 file
-        into a pandas DataFrame."""
-
-        file = ValidHDF5(file_path, expected_datasets=["df_with_missing"])
-
-        try:
-            # pd.read_hdf does not always return a DataFrame
-            df = pd.DataFrame(pd.read_hdf(file.path, key="df_with_missing"))
-        except Exception as error:
-            logger.error(error)
-            raise error
-        else:
-            return df
diff --git a/movement/io/poses_accessor.py b/movement/io/poses_accessor.py
new file mode 100644
index 00000000..4a92476a
--- /dev/null
+++ b/movement/io/poses_accessor.py
@@ -0,0 +1,83 @@
+import logging
+from typing import ClassVar
+
+import xarray as xr
+
+from movement.io.validators import ValidPoseTracks
+
+logger = logging.getLogger(__name__)
+
+# Preserve the attributes (metadata) of xarray objects after operations
+xr.set_options(keep_attrs=True)
+
+
+@xr.register_dataset_accessor("poses")
+class PosesAccessor:
+    """An accessor that extends an `xarray.Dataset` object.
+
+    The `xarray.Dataset` has the following dimensions:
+    - `time`: the number of frames in the video
+    - `individuals`: the number of individuals in the video
+    - `keypoints`: the number of keypoints in the skeleton
+    - `space`: the number of spatial dimensions, either 2 or 3
+
+    Appropriate coordinate labels are assigned to each dimension:
+    list of unique names (str) for `individuals` and `keypoints`,
+    ['x','y',('z')] for `space`. The coordinates of the `time` dimension are
+    in seconds if `fps` is provided, otherwise they are in frame numbers.
+
+    The dataset contains two data variables (`xarray.DataArray` objects):
+    - `pose_tracks`: with shape (`time`, `individuals`, `keypoints`, `space`)
+    - `confidence`: with shape (`time`, `individuals`, `keypoints`)
+
+    The dataset may also contain the following attributes as metadata:
+    - `fps`: the number of frames per second in the video
+    - `time_unit`: the unit of the `time` coordinates, frames or seconds
+    - `source_software`: the software from which the pose tracks were loaded
+    - `source_file`: the file from which the pose tracks were loaded
+
+    Notes
+    -----
+    Using an accessor is the recommended way to extend xarray objects.
+    See [1]_ for more details.
+
+    Methods/properties that are specific to this class can be used via
+    the `.poses` accessor, e.g. `ds.poses.to_dlc_df()`.
+
+    References
+    ----------
+    .. [1] https://docs.xarray.dev/en/stable/internals/extending-xarray.html
+    """
+
+    # Names of the expected dimensions in the dataset
+    dim_names: ClassVar[tuple] = (
+        "time",
+        "individuals",
+        "keypoints",
+        "space",
+    )
+
+    # Names of the expected data variables in the dataset
+    var_names: ClassVar[tuple] = (
+        "pose_tracks",
+        "confidence",
+    )
+
+    def __init__(self, ds: xr.Dataset):
+        self._obj = ds
+
+    def validate(self) -> None:
+        """Validate the PoseTracks dataset."""
+        fps = self._obj.attrs.get("fps", None)
+        try:
+            ValidPoseTracks(
+                tracks_array=self._obj[self.var_names[0]].values,
+                scores_array=self._obj[self.var_names[1]].values,
+                individual_names=self._obj.coords[self.dim_names[1]].values,
+                keypoint_names=self._obj.coords[self.dim_names[2]].values,
+                fps=fps,
+            )
+        except Exception as e:
+            error_msg = "The dataset does not contain valid pose tracks."
+            logger.error(error_msg)
+            raise ValueError(error_msg) from e
diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py
new file mode 100644
index 00000000..21a47cfa
--- /dev/null
+++ b/movement/io/save_poses.py
@@ -0,0 +1,102 @@
+import logging
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+from movement.io.validators import ValidFile
+
+logger = logging.getLogger(__name__)
+
+
+def to_dlc_df(ds: xr.Dataset) -> pd.DataFrame:
+    """Convert an xarray dataset containing pose tracks into a
+    DeepLabCut-style pandas DataFrame with multi-index columns.
+
+    Parameters
+    ----------
+    ds : xarray Dataset
+        Dataset containing pose tracks, confidence scores, and metadata.
+
+    Returns
+    -------
+    pandas DataFrame
+
+    Notes
+    -----
+    The DataFrame will have a multi-index column with the following levels:
+    "scorer", "individuals", "bodyparts", "coords" (even if there is only
+    one individual present). Regardless of the provenance of the
+    point-wise confidence scores, they will be referred to as
+    "likelihood", and stored in the "coords" level (as DeepLabCut expects).
+    """
+
+    if not isinstance(ds, xr.Dataset):
+        error_msg = f"Expected an xarray Dataset, but got {type(ds)}. 
" + logger.error(error_msg) + raise ValueError(error_msg) + + ds.poses.validate() # validate the dataset + + # Concatenate the pose tracks and confidence scores into one array + tracks_with_scores = np.concatenate( + ( + ds.pose_tracks.data, + ds.confidence.data[..., np.newaxis], + ), + axis=-1, + ) + + # Create the DLC-style multi-index columns + # Use the DLC terminology: scorer, individuals, bodyparts, coords + scorer = ["movement"] + individuals = ds.coords["individuals"].data.tolist() + bodyparts = ds.coords["keypoints"].data.tolist() + # The confidence scores in DLC are referred to as "likelihood" + coords = ds.coords["space"].data.tolist() + ["likelihood"] + + index_levels = ["scorer", "individuals", "bodyparts", "coords"] + columns = pd.MultiIndex.from_product( + [scorer, individuals, bodyparts, coords], names=index_levels + ) + df = pd.DataFrame( + data=tracks_with_scores.reshape(ds.dims["time"], -1), + index=np.arange(ds.dims["time"], dtype=int), + columns=columns, + dtype=float, + ) + logger.info("Converted PoseTracks dataset to DLC-style DataFrame.") + return df + + +def to_dlc_file(ds: xr.Dataset, file_path: Union[str, Path]) -> None: + """Export the xarray dataset containing pose tracks to a + DeepLabCut-style .h5 or .csv file. + + Parameters + ---------- + ds : xarray Dataset + Dataset containing pose tracks, confidence scores, and metadata. + file_path : pathlib Path or str + Path to the file to save the DLC poses to. The file extension + must be either ".h5" (recommended) or ".csv". + """ + + try: + file = ValidFile( + file_path, + expected_permission="w", + expected_suffix=[".csv", ".h5"], + ) + except (OSError, ValueError) as error: + logger.error(error) + raise error + + df = to_dlc_df(ds) # convert to pandas DataFrame + if file.path.suffix == ".csv": + df.to_csv(file.path, sep=",") + else: # file.path.suffix == ".h5" + df.to_hdf(file.path, key="df_with_missing") + logger.info(f"Saved PoseTracks dataset to {file.path}.") diff --git a/movement/io/tracks_validators.py b/movement/io/tracks_validators.py deleted file mode 100644 index 410e3c7d..00000000 --- a/movement/io/tracks_validators.py +++ /dev/null @@ -1,157 +0,0 @@ -import logging -from typing import Any, Iterable, List, Optional, Union - -import numpy as np -from attrs import converters, define, field - -# get logger -logger = logging.getLogger(__name__) - - -def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]: - """Try to coerce the value into a list of strings. - Otherwise, raise a ValueError.""" - if type(value) is str: - warning_msg = ( - f"Invalid value ({value}). Expected a list of strings. " - "Converting to a list of length 1." - ) - logger.warning(warning_msg) - return [value] - elif isinstance(value, Iterable): - return [str(item) for item in value] - else: - error_msg = f"Invalid value ({value}). Expected a list of strings." - logger.error(error_msg) - raise ValueError(error_msg) - - -def _ensure_type_ndarray(value: Any) -> None: - """Raise ValueError the value is a not numpy array.""" - if type(value) is not np.ndarray: - raise ValueError(f"Expected a numpy array, but got {type(value)}.") - - -def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]: - """Set fps to None if a non-positive float is passed.""" - if fps is not None and fps <= 0: - logger.warning( - f"Invalid fps value ({fps}). Expected a positive number. " - "Setting fps to None." 
- ) - return None - return fps - - -@define(kw_only=True) -class ValidPoseTracks: - """Class for validating pose tracking data imported from files, before - they are converted to a `PoseTracks` object. - - Attributes - ---------- - tracks_array : np.ndarray - Array of shape (n_frames, n_individuals, n_keypoints, n_space) - containing the pose tracks. It will be converted to a - `xarray.DataArray` object named "pose_tracks". - scores_array : np.ndarray, optional - Array of shape (n_frames, n_individuals, n_keypoints) containing - the point-wise confidence scores. It will be converted to a - `xarray.DataArray` object named "confidence". - If None (default), the scores will be set to an array of NaNs. - individual_names : list of str, optional - List of unique names for the individuals in the video. If None - (default), the individuals will be named "individual_0", - "individual_1", etc. - keypoint_names : list of str, optional - List of unique names for the keypoints in the skeleton. If None - (default), the keypoints will be named "keypoint_0", "keypoint_1", - etc. - fps : float, optional - Frames per second of the video. Defaults to None. - """ - - # Define class attributes - tracks_array: np.ndarray = field() - scores_array: Optional[np.ndarray] = field(default=None) - individual_names: Optional[List[str]] = field( - default=None, - converter=converters.optional(_list_of_str), - ) - keypoint_names: Optional[List[str]] = field( - default=None, - converter=converters.optional(_list_of_str), - ) - fps: Optional[float] = field( - default=None, - converter=converters.pipe( # type: ignore - converters.optional(float), _set_fps_to_none_if_invalid - ), - ) - - # Add validators - @tracks_array.validator - def _validate_tracks_array(self, attribute, value): - _ensure_type_ndarray(value) - if value.ndim != 4: - raise ValueError( - f"Expected `{attribute}` to have 4 dimensions, " - f"but got {value.ndim}." - ) - if value.shape[-1] not in [2, 3]: - raise ValueError( - f"Expected `{attribute}` to have 2 or 3 spatial dimensions, " - f"but got {value.shape[-1]}." - ) - - @scores_array.validator - def _validate_scores_array(self, attribute, value): - if value is not None: - _ensure_type_ndarray(value) - if value.shape != self.tracks_array.shape[:-1]: - raise ValueError( - f"Expected `{attribute}` to have shape " - f"{self.tracks_array.shape[:-1]}, but got {value.shape}." - ) - - @individual_names.validator - def _validate_individual_names(self, attribute, value): - if (value is not None) and (len(value) != self.tracks_array.shape[1]): - raise ValueError( - f"Expected {self.tracks_array.shape[1]} `{attribute}`, " - f"but got {len(value)}." - ) - - @keypoint_names.validator - def _validate_keypoint_names(self, attribute, value): - if (value is not None) and (len(value) != self.tracks_array.shape[2]): - raise ValueError( - f"Expected {self.tracks_array.shape[2]} `{attribute}`, " - f"but got {len(value)}." - ) - - def __attrs_post_init__(self): - """Assign default values to optional attributes (if None)""" - if self.scores_array is None: - self.scores_array = np.full( - (self.tracks_array.shape[:-1]), np.nan, dtype="float32" - ) - logger.warning( - "Scores array was not provided. Setting to an array of NaNs." - ) - if self.individual_names is None: - self.individual_names = [ - f"individual_{i}" for i in range(self.tracks_array.shape[1]) - ] - logger.warning( - "Individual names were not provided. " - f"Setting to {self.individual_names}." 
- ) - if self.keypoint_names is None: - self.keypoint_names = [ - f"keypoint_{i}" for i in range(self.tracks_array.shape[2]) - ] - logger.warning( - "Keypoint names were not provided. " - f"Setting to {self.keypoint_names}." - ) diff --git a/movement/io/file_validators.py b/movement/io/validators.py similarity index 53% rename from movement/io/file_validators.py rename to movement/io/validators.py index c9d278ce..f256ae7a 100644 --- a/movement/io/file_validators.py +++ b/movement/io/validators.py @@ -1,10 +1,11 @@ import logging import os from pathlib import Path -from typing import List, Literal +from typing import Any, Iterable, List, Literal, Optional, Union import h5py -from attrs import define, field, validators +import numpy as np +from attrs import converters, define, field, validators # get logger logger = logging.getLogger(__name__) @@ -182,3 +183,151 @@ def csv_file_contains_expected_levels(self, attribute, value): "contain all expected index column levels " f"{expected_levels}." ) + + +def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]: + """Try to coerce the value into a list of strings. + Otherwise, raise a ValueError.""" + if type(value) is str: + warning_msg = ( + f"Invalid value ({value}). Expected a list of strings. " + "Converting to a list of length 1." + ) + logger.warning(warning_msg) + return [value] + elif isinstance(value, Iterable): + return [str(item) for item in value] + else: + error_msg = f"Invalid value ({value}). Expected a list of strings." + logger.error(error_msg) + raise ValueError(error_msg) + + +def _ensure_type_ndarray(value: Any) -> None: + """Raise ValueError the value is a not numpy array.""" + if type(value) is not np.ndarray: + raise ValueError(f"Expected a numpy array, but got {type(value)}.") + + +def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]: + """Set fps to None if a non-positive float is passed.""" + if fps is not None and fps <= 0: + logger.warning( + f"Invalid fps value ({fps}). Expected a positive number. " + "Setting fps to None." + ) + return None + return fps + + +@define(kw_only=True) +class ValidPoseTracks: + """Class for validating pose tracking data imported from a file. + + Attributes + ---------- + tracks_array : np.ndarray + Array of shape (n_frames, n_individuals, n_keypoints, n_space) + containing the pose tracks. It will be converted to a + `xarray.DataArray` object named "pose_tracks". + scores_array : np.ndarray, optional + Array of shape (n_frames, n_individuals, n_keypoints) containing + the point-wise confidence scores. It will be converted to a + `xarray.DataArray` object named "confidence". + If None (default), the scores will be set to an array of NaNs. + individual_names : list of str, optional + List of unique names for the individuals in the video. If None + (default), the individuals will be named "individual_0", + "individual_1", etc. + keypoint_names : list of str, optional + List of unique names for the keypoints in the skeleton. If None + (default), the keypoints will be named "keypoint_0", "keypoint_1", + etc. + fps : float, optional + Frames per second of the video. Defaults to None. 
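+
+    Examples
+    --------
+    A minimal sketch of constructing the class directly (the shapes and
+    names below are illustrative; the loader functions normally build
+    this object from file contents):
+
+    >>> import numpy as np
+    >>> valid_tracks = ValidPoseTracks(
+    ...     tracks_array=np.zeros((100, 2, 3, 2)),
+    ...     scores_array=np.ones((100, 2, 3)),
+    ...     individual_names=["mouse_0", "mouse_1"],
+    ...     keypoint_names=["snout", "centre", "tail_base"],
+    ...     fps=30.0,
+    ... )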
+ """ + + # Define class attributes + tracks_array: np.ndarray = field() + scores_array: Optional[np.ndarray] = field(default=None) + individual_names: Optional[List[str]] = field( + default=None, + converter=converters.optional(_list_of_str), + ) + keypoint_names: Optional[List[str]] = field( + default=None, + converter=converters.optional(_list_of_str), + ) + fps: Optional[float] = field( + default=None, + converter=converters.pipe( # type: ignore + converters.optional(float), _set_fps_to_none_if_invalid + ), + ) + + # Add validators + @tracks_array.validator + def _validate_tracks_array(self, attribute, value): + _ensure_type_ndarray(value) + if value.ndim != 4: + raise ValueError( + f"Expected `{attribute}` to have 4 dimensions, " + f"but got {value.ndim}." + ) + if value.shape[-1] not in [2, 3]: + raise ValueError( + f"Expected `{attribute}` to have 2 or 3 spatial dimensions, " + f"but got {value.shape[-1]}." + ) + + @scores_array.validator + def _validate_scores_array(self, attribute, value): + if value is not None: + _ensure_type_ndarray(value) + if value.shape != self.tracks_array.shape[:-1]: + raise ValueError( + f"Expected `{attribute}` to have shape " + f"{self.tracks_array.shape[:-1]}, but got {value.shape}." + ) + + @individual_names.validator + def _validate_individual_names(self, attribute, value): + if (value is not None) and (len(value) != self.tracks_array.shape[1]): + raise ValueError( + f"Expected {self.tracks_array.shape[1]} `{attribute}`, " + f"but got {len(value)}." + ) + + @keypoint_names.validator + def _validate_keypoint_names(self, attribute, value): + if (value is not None) and (len(value) != self.tracks_array.shape[2]): + raise ValueError( + f"Expected {self.tracks_array.shape[2]} `{attribute}`, " + f"but got {len(value)}." + ) + + def __attrs_post_init__(self): + """Assign default values to optional attributes (if None)""" + if self.scores_array is None: + self.scores_array = np.full( + (self.tracks_array.shape[:-1]), np.nan, dtype="float32" + ) + logger.warning( + "Scores array was not provided. Setting to an array of NaNs." + ) + if self.individual_names is None: + self.individual_names = [ + f"individual_{i}" for i in range(self.tracks_array.shape[1]) + ] + logger.warning( + "Individual names were not provided. " + f"Setting to {self.individual_names}." + ) + if self.keypoint_names is None: + self.keypoint_names = [ + f"keypoint_{i}" for i in range(self.tracks_array.shape[2]) + ] + logger.warning( + "Keypoint names were not provided. " + f"Setting to {self.keypoint_names}." 
+ ) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index 2ba5f9c8..a9c65a46 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -4,16 +4,19 @@ import numpy as np import pandas as pd import pytest -from xarray import DataArray -from xarray.testing import assert_allclose +import xarray as xr from movement.datasets import fetch_pose_data_path -from movement.io import PoseTracks -from movement.io.file_validators import ValidFile, ValidHDF5, ValidPosesCSV -from movement.io.tracks_validators import ValidPoseTracks +from movement.io import PosesAccessor, load_poses, save_poses +from movement.io.validators import ( + ValidFile, + ValidHDF5, + ValidPosesCSV, + ValidPoseTracks, +) -class TestPoseTracksIO: +class TestPosesIO: """Test the IO functionalities of the PoseTracks class.""" @pytest.fixture @@ -21,6 +24,43 @@ def valid_tracks_array(self): """Return a valid tracks array.""" return np.zeros((10, 2, 2, 2)) + @pytest.fixture + def valid_pose_dataset(self, valid_tracks_array): + """Return a valid pose tracks dataset.""" + dim_names = PosesAccessor.dim_names + return xr.Dataset( + data_vars={ + "pose_tracks": xr.DataArray( + valid_tracks_array, dims=dim_names + ), + "confidence": xr.DataArray( + valid_tracks_array[..., 0], dims=dim_names[:-1] + ), + }, + coords={ + "time": np.arange(valid_tracks_array.shape[0]), + "individuals": ["ind1", "ind2"], + "keypoints": ["key1", "key2"], + "space": ["x", "y"], + }, + attrs={ + "fps": None, + "time_unit": "frames", + "source_software": "SLEAP", + "source_file": "test.h5", + }, + ) + + @pytest.fixture + def invalid_pose_datasets(self, valid_pose_dataset): + """Return a list of invalid pose tracks datasets.""" + return { + "not_a_dataset": [1, 2, 3], + "empty_dataset": xr.Dataset(), + "missing_var": valid_pose_dataset.drop_vars("pose_tracks"), + "missing_dim": valid_pose_dataset.drop_dims("time"), + } + @pytest.fixture def dlc_file_h5_single(self): """Return the path to a valid DLC h5 file containing pose data @@ -142,20 +182,21 @@ def test_load_from_valid_files(self, valid_files): for file_type, file_path in valid_files.items(): if file_type.startswith("DLC"): - ds = PoseTracks.from_dlc_file(file_path) + ds = load_poses.from_dlc_file(file_path) elif file_type.startswith("SLEAP"): - ds = PoseTracks.from_sleap_file(file_path) + ds = load_poses.from_sleap_file(file_path) - assert isinstance(ds, PoseTracks) + assert isinstance(ds, xr.Dataset) # Expected variables are present and of right shape/type for var in ["pose_tracks", "confidence"]: assert var in ds.data_vars - assert isinstance(ds[var], DataArray) + assert isinstance(ds[var], xr.DataArray) assert ds.pose_tracks.ndim == 4 assert ds.confidence.shape == ds.pose_tracks.shape[:-1] # Check the dims and coords - assert all([i in ds.dims for i in ds.dim_names]) - for d, dim in enumerate(ds.dim_names[1:]): + DIM_NAMES = PosesAccessor.dim_names + assert all([i in ds.dims for i in DIM_NAMES]) + for d, dim in enumerate(DIM_NAMES[1:]): assert ds.dims[dim] == ds.pose_tracks.shape[d + 1] assert all([isinstance(s, str) for s in ds.coords[dim].values]) assert all([i in ds.coords["space"] for i in ["x", "y"]]) @@ -169,20 +210,20 @@ def test_load_from_invalid_files(self, invalid_files): raises the appropriate errors.""" for file_path in invalid_files.values(): with pytest.raises((OSError, ValueError)): - PoseTracks.from_dlc_file(file_path) + load_poses.from_dlc_file(file_path) with pytest.raises((OSError, ValueError)): - PoseTracks.from_sleap_file(file_path) + 
load_poses.from_sleap_file(file_path) @pytest.mark.parametrize("file_path", [1, 1.0, True, None, [], {}]) def test_load_with_incorrect_file_path_types(self, file_path): """Test loading poses from a file_path with an incorrect type.""" with pytest.raises(TypeError): - PoseTracks.from_dlc_file(file_path) + load_poses.from_dlc_file(file_path) with pytest.raises(TypeError): - PoseTracks.from_sleap_file(file_path) + load_poses.from_sleap_file(file_path) def test_file_validator(self, invalid_files): - """Test that the file validator class raoses the right errors.""" + """Test that the file validator class raises the right errors.""" for file_type, file_path in invalid_files.items(): if file_type == "unreadable": with pytest.raises(PermissionError): @@ -207,39 +248,50 @@ def test_file_validator(self, invalid_files): with pytest.raises(ValueError): ValidPosesCSV(path=file_path) - def test_write_to_dlc_file( - self, sleap_file_h5_multi, invalid_files, tmp_path - ): - """Test that writing pose tracks to DLC .h5 and .csv files and then - reading them back in returns the same Dataset.""" - ds = PoseTracks.from_sleap_file(sleap_file_h5_multi) - ds.to_dlc_file(tmp_path / "dlc.h5") - ds.to_dlc_file(tmp_path / "dlc.csv") - ds_from_h5 = PoseTracks.from_dlc_file(tmp_path / "dlc.h5") - ds_from_csv = PoseTracks.from_dlc_file(tmp_path / "dlc.csv") - assert_allclose(ds_from_h5, ds) - assert_allclose(ds_from_csv, ds) + def test_load_and_save_to_dlc_df(self, dlc_style_df): + """Test that loading pose tracks from a DLC-style DataFrame and + converting back to a DataFrame returns the same data values.""" + ds = load_poses.from_dlc_df(dlc_style_df) + df = save_poses.to_dlc_df(ds) + assert np.allclose(df.values, dlc_style_df.values) + def test_save_and_load_dlc_file(self, valid_pose_dataset, tmp_path): + """Test that saving pose tracks to DLC .h5 and .csv files and then + loading them back in returns the same Dataset.""" + save_poses.to_dlc_file(valid_pose_dataset, tmp_path / "dlc.h5") + save_poses.to_dlc_file(valid_pose_dataset, tmp_path / "dlc.csv") + ds_from_h5 = load_poses.from_dlc_file(tmp_path / "dlc.h5") + ds_from_csv = load_poses.from_dlc_file(tmp_path / "dlc.csv") + xr.testing.assert_allclose(ds_from_h5, valid_pose_dataset) + xr.testing.assert_allclose(ds_from_csv, valid_pose_dataset) + + def test_save_valid_dataset_to_invalid_file_paths( + self, valid_pose_dataset, invalid_files, tmp_path + ): with pytest.raises(FileExistsError): - ds.to_dlc_file(invalid_files["fake_h5"]) + save_poses.to_dlc_file( + valid_pose_dataset, invalid_files["fake_h5"] + ) with pytest.raises(ValueError): - ds.to_dlc_file(tmp_path / "dlc.txt") + save_poses.to_dlc_file(valid_pose_dataset, tmp_path / "dlc.txt") with pytest.raises(IsADirectoryError): - ds.to_dlc_file(invalid_files["directory"]) + save_poses.to_dlc_file( + valid_pose_dataset, invalid_files["directory"] + ) def test_load_from_dlc_file_csv_or_h5_file_returns_same( self, dlc_file_h5_single, dlc_file_csv_single ): """Test that loading pose tracks from DLC .csv and .h5 files return the same Dataset.""" - ds_from_h5 = PoseTracks.from_dlc_file(dlc_file_h5_single) - ds_from_csv = PoseTracks.from_dlc_file(dlc_file_csv_single) - assert_allclose(ds_from_h5, ds_from_csv) + ds_from_h5 = load_poses.from_dlc_file(dlc_file_h5_single) + ds_from_csv = load_poses.from_dlc_file(dlc_file_csv_single) + xr.testing.assert_allclose(ds_from_h5, ds_from_csv) @pytest.mark.parametrize("fps", [None, -5, 0, 30, 60.0]) def test_fps_and_time_coords(self, sleap_file_h5_multi, fps): """Test that time 
coordinates are set according to the fps.""" - ds = PoseTracks.from_sleap_file(sleap_file_h5_multi, fps=fps) + ds = load_poses.from_sleap_file(sleap_file_h5_multi, fps=fps) if (fps is None) or (fps <= 0): assert ds.fps is None assert ds.time_unit == "frames" @@ -251,20 +303,19 @@ def test_fps_and_time_coords(self, sleap_file_h5_multi, fps): np.arange(ds.dims["time"], dtype=int) / ds.attrs["fps"], ) - def test_from_and_to_dlc_df(self, dlc_style_df): - """Test that loading pose tracks from a DLC-style DataFrame and - converting back to a DataFrame returns the same data values.""" - ds = PoseTracks.from_dlc_df(dlc_style_df) - df = ds.to_dlc_df() - assert np.allclose(df.values, dlc_style_df.values) - def test_load_from_str_path(self, sleap_file_h5_single): """Test that file paths provided as strings are accepted as input.""" - assert_allclose( - PoseTracks.from_sleap_file(sleap_file_h5_single), - PoseTracks.from_sleap_file(sleap_file_h5_single.as_posix()), + xr.testing.assert_allclose( + load_poses.from_sleap_file(sleap_file_h5_single), + load_poses.from_sleap_file(sleap_file_h5_single.as_posix()), ) + def test_save_invalid_pose_datasets(self, invalid_pose_datasets, tmp_path): + """Test that saving invalid pose datasets raises ValueError.""" + for ds in invalid_pose_datasets.values(): + with pytest.raises(ValueError): + save_poses.to_dlc_file(ds, tmp_path / "test.h5") + @pytest.mark.parametrize( "tracks_array", [ From 31a1ed2defb2fa968a85b540c5371e342654598d Mon Sep 17 00:00:00 2001 From: niksirbi Date: Tue, 8 Aug 2023 19:11:35 +0100 Subject: [PATCH 36/79] use isinstance instead of type --- tests/test_unit/test_io.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index a9c65a46..e4e2ce19 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -369,13 +369,13 @@ def test_individual_names_validation( individual_names=individual_names, ) assert poses.individual_names == ["individual_0", "individual_1"] - elif type(individual_names) in (list, tuple): + elif isinstance(individual_names, (list, tuple)): poses = ValidPoseTracks( tracks_array=valid_tracks_array, individual_names=individual_names, ) assert poses.individual_names == [str(i) for i in individual_names] - elif type(individual_names) == str: + elif isinstance(individual_names, str): poses = ValidPoseTracks( tracks_array=np.zeros((10, 1, 2, 2)), individual_names=individual_names, @@ -413,12 +413,12 @@ def test_keypoint_names_validation( tracks_array=valid_tracks_array, keypoint_names=keypoint_names ) assert poses.keypoint_names == ["keypoint_0", "keypoint_1"] - elif type(keypoint_names) in (list, tuple): + elif isinstance(keypoint_names, (list, tuple)): poses = ValidPoseTracks( tracks_array=valid_tracks_array, keypoint_names=keypoint_names ) assert poses.keypoint_names == [str(i) for i in keypoint_names] - elif type(keypoint_names) == str: + elif isinstance(keypoint_names, str): poses = ValidPoseTracks( tracks_array=np.zeros((10, 2, 1, 2)), keypoint_names=keypoint_names, From 01fe34c6a47996d0400ca785154534bd83ee2617 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 11 Aug 2023 17:02:31 +0100 Subject: [PATCH 37/79] test saving without write permissions --- tests/test_unit/test_io.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index e4e2ce19..38c0d9eb 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -132,11 +132,17 @@ def valid_files( 
@pytest.fixture def invalid_files(self, tmp_path): + """Return a dictionary containing paths to invalid files.""" unreadable_file = tmp_path / "unreadable.h5" with open(unreadable_file, "w") as f: f.write("unreadable data") os.chmod(f.name, 0o000) + unwriteable_dir = tmp_path / "no_write" + unwriteable_dir.mkdir() + os.chmod(unwriteable_dir, 0o555) + unwritable_file = unwriteable_dir / "unwritable.h5" + wrong_ext_file = tmp_path / "wrong_extension.txt" with open(wrong_ext_file, "w") as f: f.write("") @@ -161,6 +167,7 @@ def invalid_files(self, tmp_path): return { "unreadable": unreadable_file, + "unwritable": unwritable_file, "wrong_ext": wrong_ext_file, "no_dataframe": h5_file_no_dataframe, "nonexistent": nonexistent_file, @@ -228,6 +235,9 @@ def test_file_validator(self, invalid_files): if file_type == "unreadable": with pytest.raises(PermissionError): ValidFile(path=file_path, expected_permission="r") + elif file_type == "unwritable": + with pytest.raises(PermissionError): + ValidFile(path=file_path, expected_permission="w") elif file_type == "wrong_ext": with pytest.raises(ValueError): ValidFile( From d5f5a80bc201909820eef1291e1829359acc4267 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 11 Aug 2023 17:13:49 +0100 Subject: [PATCH 38/79] test dlc csv validator --- tests/test_unit/test_io.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index 38c0d9eb..1a7b10ae 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -258,6 +258,12 @@ def test_file_validator(self, invalid_files): with pytest.raises(ValueError): ValidPosesCSV(path=file_path) + def test_dlc_poses_csv_validator(self, dlc_file_csv_single): + """Test that the validator for DLC .csv files raises error when + multianimal=True and the 'individuals' level is missing.""" + with pytest.raises(ValueError): + ValidPosesCSV(path=dlc_file_csv_single, multianimal=True) + def test_load_and_save_to_dlc_df(self, dlc_style_df): """Test that loading pose tracks from a DLC-style DataFrame and converting back to a DataFrame returns the same data values.""" From 094ebbadb5c8ff49213d2667e0c9bae6fe06d825 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 11 Aug 2023 18:24:58 +0100 Subject: [PATCH 39/79] refactored logging and add log utils --- movement/__init__.py | 2 +- movement/io/load_poses.py | 98 ++++++++++---------------- movement/io/validators.py | 93 ++++++++++++++---------- movement/{log_config.py => logging.py} | 31 ++++++++ tests/conftest.py | 2 +- tests/test_unit/test_logging.py | 19 +++++ 6 files changed, 143 insertions(+), 102 deletions(-) rename movement/{log_config.py => logging.py} (71%) diff --git a/movement/__init__.py b/movement/__init__.py index 4e64c32d..283ebd28 100644 --- a/movement/__init__.py +++ b/movement/__init__.py @@ -1,5 +1,5 @@ from importlib.metadata import PackageNotFoundError, version -from movement.log_config import configure_logging +from movement.logging import configure_logging try: __version__ = version("movement") diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 1723016f..7df3497a 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -15,6 +15,7 @@ ValidPosesCSV, ValidPoseTracks, ) +from movement.logging import log_and_raise_error logger = logging.getLogger(__name__) @@ -64,19 +65,14 @@ def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset: (-1, len(individual_names), len(keypoint_names), 3) ) - try: - valid_data = ValidPoseTracks( - 
tracks_array=tracks_with_scores[:, :, :, :-1], - scores_array=tracks_with_scores[:, :, :, -1], - individual_names=individual_names, - keypoint_names=keypoint_names, - fps=fps, - ) - except ValueError as error: - logger.error(error) - raise error - else: - return _from_valid_data(valid_data) + valid_data = ValidPoseTracks( + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], + individual_names=individual_names, + keypoint_names=keypoint_names, + fps=fps, + ) + return _from_valid_data(valid_data) def from_sleap_file( @@ -125,15 +121,11 @@ def from_sleap_file( >>> ds = load_poses.from_sleap_file("path/to/file.slp", fps=30) """ - try: - file = ValidFile( - file_path, - expected_permission="r", - expected_suffix=[".h5", ".slp"], - ) - except (OSError, ValueError) as error: - logger.error(error) - raise error + file = ValidFile( + file_path, + expected_permission="r", + expected_suffix=[".h5", ".slp"], + ) # Load and validate data if file.path.suffix == ".h5": @@ -180,15 +172,11 @@ def from_dlc_file( >>> ds = load_poses.from_dlc_file("path/to/file.h5", fps=30) """ - try: - file = ValidFile( - file_path, - expected_permission="r", - expected_suffix=[".csv", ".h5"], - ) - except (OSError, ValueError) as error: - logger.error(error) - raise error + file = ValidFile( + file_path, + expected_permission="r", + expected_suffix=[".csv", ".h5"], + ) # Load the DLC poses into a DataFrame if file.path.suffix == ".csv": @@ -245,19 +233,13 @@ def _load_from_sleap_analysis_file( (n_frames, n_tracks, n_keypoints) ) - try: - valid_data = ValidPoseTracks( - tracks_array=tracks, - scores_array=scores, - individual_names=[n.decode() for n in f["track_names"][:]], - keypoint_names=[n.decode() for n in f["node_names"][:]], - fps=fps, - ) - except ValueError as error: - logger.error(error) - raise error - else: - return valid_data + return ValidPoseTracks( + tracks_array=tracks, + scores_array=scores, + individual_names=[n.decode() for n in f["track_names"][:]], + keypoint_names=[n.decode() for n in f["node_names"][:]], + fps=fps, + ) def _load_from_sleap_labels_file( @@ -285,19 +267,13 @@ def _load_from_sleap_labels_file( labels = read_labels(file.path.as_posix()) tracks_with_scores = labels.numpy(return_confidence=True) - try: - valid_data = ValidPoseTracks( - tracks_array=tracks_with_scores[:, :, :, :-1], - scores_array=tracks_with_scores[:, :, :, -1], - individual_names=[track.name for track in labels.tracks], - keypoint_names=[kp.name for kp in labels.skeletons[0].nodes], - fps=fps, - ) - except ValueError as error: - logger.error(error) - raise error - else: - return valid_data + return ValidPoseTracks( + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], + individual_names=[track.name for track in labels.tracks], + keypoint_names=[kp.name for kp in labels.skeletons[0].nodes], + fps=fps, + ) def _parse_dlc_csv_to_df(file_path: Path) -> pd.DataFrame: @@ -366,10 +342,10 @@ def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame: # pd.read_hdf does not always return a DataFrame df = pd.DataFrame(pd.read_hdf(file.path, key="df_with_missing")) except Exception as error: - logger.error(error) - raise error - else: - return df + log_and_raise_error( + error, f"Could not load a dataframe from {file.path}." 
+ ) + return df def _from_valid_data(data: ValidPoseTracks) -> xr.Dataset: diff --git a/movement/io/validators.py b/movement/io/validators.py index f256ae7a..51f7e001 100644 --- a/movement/io/validators.py +++ b/movement/io/validators.py @@ -1,4 +1,3 @@ -import logging import os from pathlib import Path from typing import Any, Iterable, List, Literal, Optional, Union @@ -7,8 +6,7 @@ import numpy as np from attrs import converters, define, field, validators -# get logger -logger = logging.getLogger(__name__) +from movement.logging import log_and_raise_error, log_warning @define @@ -52,8 +50,9 @@ class ValidFile: def path_is_not_dir(self, attribute, value): """Ensures that the path does not point to a directory.""" if value.is_dir(): - raise IsADirectoryError( - f"Expected a file path but got a directory: {value}." + log_and_raise_error( + IsADirectoryError, + f"Expected a file path but got a directory: {value}.", ) @path.validator @@ -62,10 +61,14 @@ def file_exists_when_expected(self, attribute, value): usage (read and/or write).""" if "r" in self.expected_permission: if not value.exists(): - raise FileNotFoundError(f"File {value} does not exist.") + raise log_and_raise_error( + FileNotFoundError, f"File {value} does not exist." + ) else: # expected_permission is 'w' if value.exists(): - raise FileExistsError(f"File {value} already exists.") + raise log_and_raise_error( + FileExistsError, f"File {value} already exists." + ) @path.validator def file_has_access_permissions(self, attribute, value): @@ -74,14 +77,16 @@ def file_has_access_permissions(self, attribute, value): file_is_readable = os.access(value, os.R_OK) parent_is_writeable = os.access(value.parent, os.W_OK) if ("r" in self.expected_permission) and (not file_is_readable): - raise PermissionError( + raise log_and_raise_error( + PermissionError, f"Unable to read file: {value}. " - "Make sure that you have read permissions for it." + "Make sure that you have read permissions for it.", ) if ("w" in self.expected_permission) and (not parent_is_writeable): - raise PermissionError( + raise log_and_raise_error( + PermissionError, f"Unable to write to file: {value}. " - "Make sure that you have write permissions." + "Make sure that you have write permissions.", ) @path.validator @@ -89,9 +94,10 @@ def file_has_expected_suffix(self, attribute, value): """Ensures that the file has one of the expected suffix(es).""" if self.expected_suffix: # list is not empty if value.suffix not in self.expected_suffix: - raise ValueError( + raise log_and_raise_error( + ValueError, f"Expected file with suffix(es) {self.expected_suffix} " - f"but got suffix {value.suffix} instead." + f"but got suffix {value.suffix} instead.", ) @@ -124,8 +130,9 @@ def file_is_h5(self, attribute, value): with h5py.File(value, "r") as f: f.close() except Exception as e: - raise ValueError( - f"File {value} does not seem to be in valid" "HDF5 format." + raise log_and_raise_error( + ValueError, + f"File {value} does not seem to be in valid" "HDF5 format.", ) from e @path.validator @@ -135,9 +142,10 @@ def file_contains_expected_datasets(self, attribute, value): with h5py.File(value, "r") as f: diff = set(self.expected_datasets).difference(set(f.keys())) if len(diff) > 0: - raise ValueError( + raise log_and_raise_error( + ValueError, f"Could not find the expected dataset(s) {diff} " - f"in file: {value}. " + f"in file: {value}. 
", ) @@ -178,10 +186,11 @@ def csv_file_contains_expected_levels(self, attribute, value): level in header_rows_start for level in expected_levels ] if not all(level_in_header_row_starts): - raise ValueError( + raise log_and_raise_error( + ValueError, f"The header rows of the CSV file {value} do not " "contain all expected index column levels " - f"{expected_levels}." + f"{expected_levels}.", ) @@ -189,30 +198,31 @@ def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]: """Try to coerce the value into a list of strings. Otherwise, raise a ValueError.""" if type(value) is str: - warning_msg = ( + log_warning( f"Invalid value ({value}). Expected a list of strings. " "Converting to a list of length 1." ) - logger.warning(warning_msg) return [value] elif isinstance(value, Iterable): return [str(item) for item in value] else: - error_msg = f"Invalid value ({value}). Expected a list of strings." - logger.error(error_msg) - raise ValueError(error_msg) + log_and_raise_error( + ValueError, f"Invalid value ({value}). Expected a list of strings." + ) def _ensure_type_ndarray(value: Any) -> None: """Raise ValueError the value is a not numpy array.""" if type(value) is not np.ndarray: - raise ValueError(f"Expected a numpy array, but got {type(value)}.") + raise log_and_raise_error( + ValueError, f"Expected a numpy array, but got {type(value)}." + ) def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]: """Set fps to None if a non-positive float is passed.""" if fps is not None and fps <= 0: - logger.warning( + log_warning( f"Invalid fps value ({fps}). Expected a positive number. " "Setting fps to None." ) @@ -270,14 +280,16 @@ class ValidPoseTracks: def _validate_tracks_array(self, attribute, value): _ensure_type_ndarray(value) if value.ndim != 4: - raise ValueError( + log_and_raise_error( + ValueError, f"Expected `{attribute}` to have 4 dimensions, " - f"but got {value.ndim}." + f"but got {value.ndim}.", ) if value.shape[-1] not in [2, 3]: - raise ValueError( + log_and_raise_error( + ValueError, f"Expected `{attribute}` to have 2 or 3 spatial dimensions, " - f"but got {value.shape[-1]}." + f"but got {value.shape[-1]}.", ) @scores_array.validator @@ -285,25 +297,28 @@ def _validate_scores_array(self, attribute, value): if value is not None: _ensure_type_ndarray(value) if value.shape != self.tracks_array.shape[:-1]: - raise ValueError( + log_and_raise_error( + ValueError, f"Expected `{attribute}` to have shape " - f"{self.tracks_array.shape[:-1]}, but got {value.shape}." + f"{self.tracks_array.shape[:-1]}, but got {value.shape}.", ) @individual_names.validator def _validate_individual_names(self, attribute, value): if (value is not None) and (len(value) != self.tracks_array.shape[1]): - raise ValueError( + log_and_raise_error( + ValueError, f"Expected {self.tracks_array.shape[1]} `{attribute}`, " - f"but got {len(value)}." + f"but got {len(value)}.", ) @keypoint_names.validator def _validate_keypoint_names(self, attribute, value): if (value is not None) and (len(value) != self.tracks_array.shape[2]): - raise ValueError( + log_and_raise_error( + ValueError, f"Expected {self.tracks_array.shape[2]} `{attribute}`, " - f"but got {len(value)}." + f"but got {len(value)}.", ) def __attrs_post_init__(self): @@ -312,14 +327,14 @@ def __attrs_post_init__(self): self.scores_array = np.full( (self.tracks_array.shape[:-1]), np.nan, dtype="float32" ) - logger.warning( + log_warning( "Scores array was not provided. Setting to an array of NaNs." 
) if self.individual_names is None: self.individual_names = [ f"individual_{i}" for i in range(self.tracks_array.shape[1]) ] - logger.warning( + log_warning( "Individual names were not provided. " f"Setting to {self.individual_names}." ) @@ -327,7 +342,7 @@ def __attrs_post_init__(self): self.keypoint_names = [ f"keypoint_{i}" for i in range(self.tracks_array.shape[2]) ] - logger.warning( + log_warning( "Keypoint names were not provided. " f"Setting to {self.keypoint_names}." ) diff --git a/movement/log_config.py b/movement/logging.py similarity index 71% rename from movement/log_config.py rename to movement/logging.py index b788fa65..c86d3def 100644 --- a/movement/log_config.py +++ b/movement/logging.py @@ -59,3 +59,34 @@ def configure_logging( # Add the handler to the logger logger.addHandler(handler) + + +def log_and_raise_error(error, message: str, logger_name: str = "movement"): + """Log an error message and raise a ValueError. + + Parameters + ---------- + error : Exception + The error to log and raise. + message : str + The error message. + logger_name : str, optional + The name of the logger to use. Defaults to 'movement'. + """ + logger = logging.getLogger(logger_name) + logger.error(message) + raise error(message) + + +def log_warning(message: str, logger_name: str = "movement"): + """Log a warning message. + + Parameters + ---------- + message : str + The warning message. + logger_name : str, optional + The name of the logger to use. Defaults to 'movement'. + """ + logger = logging.getLogger("movement") + logger.warning(message) diff --git a/tests/conftest.py b/tests/conftest.py index 2b56fbae..1949167a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import pytest -from movement.log_config import configure_logging +from movement.logging import configure_logging @pytest.fixture(autouse=True) diff --git a/tests/test_unit/test_logging.py b/tests/test_unit/test_logging.py index 6dd4b9ae..5574c420 100644 --- a/tests/test_unit/test_logging.py +++ b/tests/test_unit/test_logging.py @@ -2,6 +2,8 @@ import pytest +from movement.logging import log_and_raise_error, log_warning + log_messages = { "DEBUG": "This is a debug message", "INFO": "This is an info message", @@ -21,3 +23,20 @@ def test_logfile_contains_message(level, message): last_line = f.readlines()[-1] assert level in last_line assert message in last_line + + +def test_log_and_raise_error(caplog): + """Check if the log_and_raise_error function + logs the error message and raises a ValueError.""" + with pytest.raises(ValueError): + log_and_raise_error(ValueError, "This is a test error") + assert caplog.records[0].message == "This is a test error" + assert caplog.records[0].levelname == "ERROR" + + +def test_log_warning(caplog): + """Check if the log_warning function + logs the warning message.""" + log_warning("This is a test warning") + assert caplog.records[0].message == "This is a test warning" + assert caplog.records[0].levelname == "WARNING" From 28060fa1997d59e0dd9188e9e3cd645ea3e14125 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Fri, 11 Aug 2023 18:39:25 +0100 Subject: [PATCH 40/79] use stat to change permissions --- movement/logging.py | 2 +- tests/test_unit/test_io.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/movement/logging.py b/movement/logging.py index c86d3def..d202e5fb 100644 --- a/movement/logging.py +++ b/movement/logging.py @@ -88,5 +88,5 @@ def log_warning(message: str, logger_name: str = "movement"): logger_name : str, optional The name of the logger to use. 
Defaults to 'movement'.
     """
-    logger = logging.getLogger("movement")
+    logger = logging.getLogger(logger_name)
     logger.warning(message)
diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py
index 1a7b10ae..a3803cc4 100644
--- a/tests/test_unit/test_io.py
+++ b/tests/test_unit/test_io.py
@@ -1,4 +1,5 @@
 import os
+import stat
 
 import h5py
 import numpy as np
@@ -136,11 +137,11 @@ def invalid_files(self, tmp_path):
         unreadable_file = tmp_path / "unreadable.h5"
         with open(unreadable_file, "w") as f:
             f.write("unreadable data")
-        os.chmod(f.name, 0o000)
+        os.chmod(f.name, not stat.S_IRUSR)
 
         unwriteable_dir = tmp_path / "no_write"
         unwriteable_dir.mkdir()
-        os.chmod(unwriteable_dir, 0o555)
+        os.chmod(unwriteable_dir, not stat.S_IWUSR)
         unwritable_file = unwriteable_dir / "unwritable.h5"
 
         wrong_ext_file = tmp_path / "wrong_extension.txt"
From f374e4ef2432ac332afd1f39a9de7c97a603b057 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Fri, 11 Aug 2023 20:02:10 +0100
Subject: [PATCH 41/79] added sphinx design to docs requirements

---
 docs/requirements.txt | 1 +
 docs/source/conf.py   | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index b5a754c8..c1e1d87e 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -5,3 +5,4 @@ pydata-sphinx-theme
 setuptools-scm
 sphinx
 sphinx-autodoc-typehints
+sphinx-design
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 9b7efb35..255bd724 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -43,6 +43,7 @@
     "sphinx.ext.intersphinx",
     "myst_parser",
     "nbsphinx",
+    "sphinx_design",
 ]

 # Configure the myst parser to enable cool markdown features
@@ -112,7 +113,7 @@
 # The default is the URL of the GitHub pages
 # https://www.sphinx-doc.org/en/master/usage/extensions/githubpages.html
 github_user = "neuroinformatics-unit"
-html_baseurl = f"https://neuroinformatics-unit.github.io/movement/"
+html_baseurl = "https://neuroinformatics-unit.github.io/movement/"

 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
From f92b48fabf995e1cdd72ce64c4ad42d6027d5d85 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Fri, 11 Aug 2023 20:02:59 +0100
Subject: [PATCH 42/79] docs add home page and getting started guide

---
 docs/source/getting_started.md | 102 +++++++++++++++++++++++++++++++--
 docs/source/index.md           |  29 ++++++++++
 docs/source/index.rst          |  13 -----
 3 files changed, 127 insertions(+), 17 deletions(-)
 create mode 100644 docs/source/index.md
 delete mode 100644 docs/source/index.rst

diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md
index 169a31d6..b4b96d5b 100644
--- a/docs/source/getting_started.md
+++ b/docs/source/getting_started.md
@@ -1,11 +1,105 @@
-# Getting started
+# Getting Started

-Here you may demonstrate the basic functionalities your package.
+## Installation

-You can include code snippets using the usual Markdown syntax:
+We recommend you install `movement` inside a [conda](https://docs.conda.io/en/latest/)
+or [mamba](https://mamba.readthedocs.io/en/latest/index.html) environment.
+In the following we assume you have `conda` installed,
+but the same commands will also work with `mamba`/`micromamba`.
+
+
+First, create and activate an environment.
+You can call your environment whatever you like, we've used `movement-env`.
+
+```sh
+conda create -n movement-env -c conda-forge python=3.10 pytables
+conda activate movement-env
+```
+
+Next install the `movement` package:
+
+::::{tab-set}
+
+:::{tab-item} Users
+To get the latest release from PyPI:
+
+```sh
+pip install movement
+```
+If you have an older version of `movement` installed in the same environment,
+you can update to the latest version with:
+
+```sh
+pip install --upgrade movement
+```
+:::
+
+:::{tab-item} Developers
+To get the latest development version, clone the
+[GitHub repository](https://github.com/neuroinformatics-unit/movement)
+and then run from inside the repository:
+
+```sh
+pip install -e .[dev]  # works on most shells
+pip install -e '.[dev]'  # works on zsh (the default shell on macOS)
+```
+
+This will install the package in editable mode, including all `dev` dependencies.
+:::
+
+::::
+
+
+## Usage
+
+### Loading data
+You can load predicted pose tracks for the pose estimation software packages
+[DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/).
+
+First import the `load_poses` function from the `movement.io` module:
 
 ```python
 from movement.io import load_poses
+```
+
+Then, use the `from_dlc_file` or `from_sleap_file` functions to load the data.
+
+::::{tab-set}
+
+:::{tab-item} SLEAP
+
+Load from [SLEAP analysis files](https://sleap.ai/tutorials/analysis.html) (`.h5`):
+```python
+ds = load_poses.from_sleap_file("/path/to/file.analysis.h5", fps=30)
+```
+
+Alternatively, you can also directly load from `.slp` files,
+assuming they contain predicted poses.
+```python
+ds = load_poses.from_sleap_file("/path/to/file.predictions.slp", fps=30)
+```
+:::
 
-df = load_poses.from_dlc('path/to/file.h5')
+:::{tab-item} DeepLabCut
+
+Load pose estimation outputs from `.h5` files:
+```python
+ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30)
+```
+
+You may also load `.csv` files (assuming they are formatted as DeepLabCut expects them):
+```python
+ds = load_poses.from_dlc_file("/path/to/file.csv", fps=30)
 ```
+
+If you have already imported the data into a pandas DataFrame, you can
+convert it to a `movement` dataset with:
+```python
+import pandas as pd
+
+df = pd.read_hdf("/path/to/file.h5")
+ds = load_poses.from_dlc_df(df, fps=30)
+```
+:::
+
+::::
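Taken together, the loading calls documented in the guide above boil down to a very short round trip — a minimal sketch, with a hypothetical file path and the API as added in this patch series:

```python
from movement.io import load_poses

# hypothetical path; any DeepLabCut .h5 predictions file would do
ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30)

print(ds)  # an xarray.Dataset; omit fps to keep time coordinates in frames
```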
diff --git a/docs/source/index.md b/docs/source/index.md
new file mode 100644
index 00000000..58dc7cf8
--- /dev/null
+++ b/docs/source/index.md
@@ -0,0 +1,29 @@
+# movement
+
+Kinematic analysis of animal 🐝 🦀 🐀 🐒 body movements for neuroscience and ethology research.
+
+:::{warning}
+- 🏗️ The package is currently in early development. Stay tuned ⌛
+- It is not sufficiently tested to be used for scientific analysis
+- The interface is subject changes. [Open an issue](https://github.com/neuroinformatics-unit/movement/issues) if you have suggestions.
+:::
+
+## Aims
+* Load pose tracks from pose estimation software packages (e.g. [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/))
+* Evaluate the quality of the tracks and perform data cleaning operations
+* Calculate kinematic variables (e.g. speed, acceleration, joint angles, etc.)
+* Produce reports and visualise the results
+
+## Related projects
+The following projects cover related needs and served as inspiration for this project:
+* [DLC2Kinematics](https://github.com/AdaptiveMotorControlLab/DLC2Kinematics)
+* [PyRat](https://github.com/pyratlib/pyrat)
+* [Kino](https://github.com/BrancoLab/Kino)
+* [WAZP](https://github.com/SainsburyWellcomeCentre/WAZP)
+
+
+```{toctree}
+:maxdepth: 2
+
+getting_started
+```
diff --git a/docs/source/index.rst b/docs/source/index.rst
deleted file mode 100644
index f9776ca5..00000000
--- a/docs/source/index.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-Welcome to movement's documentation!
-=========================================================
-
-.. toctree::
-   :maxdepth: 2
-   :caption: Contents:
-
-   getting_started
-
-Index & Search
---------------
-* :ref:`genindex`
-* :ref:`search`

From 93d1e9f5ab358ea006a614e4d42ef2ff81bb1469 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 10:24:07 +0100
Subject: [PATCH 43/79] remove git hash from version shown in docs

---
 docs/source/conf.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 255bd724..d1266e03 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -25,6 +25,7 @@
 author = "Niko Sirmpilatze"
 try:
     release = setuptools_scm.get_version(root="../..", relative_to=__file__)
+    release = release.split("+")[0]  # remove git hash
 except LookupError:
     # if git is not initialised, still allow local build
     # with a dummy version

From bc0293797235b3e66fb4f4bd74fd8082618ef308 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 10:32:57 +0100
Subject: [PATCH 44/79] temporarily allow doc deployment from pose-tracks-io
 branch

---
 .github/workflows/docs_build_and_deploy.yml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/docs_build_and_deploy.yml b/.github/workflows/docs_build_and_deploy.yml
index a3a4995f..68bbabc2 100644
--- a/.github/workflows/docs_build_and_deploy.yml
+++ b/.github/workflows/docs_build_and_deploy.yml
@@ -9,6 +9,7 @@ on:
   push:
     branches:
       - main
+      - pose-tracks-io
     tags:
       - '*'
   pull_request:
@@ -26,7 +27,7 @@ jobs:
     needs: build_sphinx_docs
     permissions:
       contents: write
-    if: github.event_name == 'push' && github.ref_type == 'tag'
+    if: github.event_name == 'push' && (github.ref_type == 'tag' || github.ref == 'refs/heads/pose-tracks-io'))
    runs-on: ubuntu-latest
    steps:
      - uses: neuroinformatics-unit/actions/deploy_sphinx_docs@v2

From 68de09f23bd3fcee768ca7b51c8525cee2a216c9 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 10:40:21 +0100
Subject: [PATCH 45/79] fix syntax error in workflow file

---
 .github/workflows/docs_build_and_deploy.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/docs_build_and_deploy.yml b/.github/workflows/docs_build_and_deploy.yml
index 68bbabc2..8768f335 100644
--- a/.github/workflows/docs_build_and_deploy.yml
+++ b/.github/workflows/docs_build_and_deploy.yml
@@ -27,7 +27,7 @@ jobs:
     needs: build_sphinx_docs
     permissions:
       contents: write
-    if: github.event_name == 'push' && (github.ref_type == 'tag' || github.ref == 'refs/heads/pose-tracks-io'))
+    if: github.event_name == 'push' && github.ref == 'refs/heads/pose-tracks-io'
     runs-on: ubuntu-latest
     steps:
       - uses: neuroinformatics-unit/actions/deploy_sphinx_docs@v2

From ed00ff7034dd8100369d64e3cc1a9e5425ff7b7d Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 11:57:19 +0100
Subject: [PATCH 46/79] added sphinx example gallery

---
 .gitignore              |  3 
++- MANIFEST.in | 1 + docs/requirements.txt | 1 + docs/source/conf.py | 7 +++++++ docs/source/index.md | 1 + examples/README.rst | 4 ++++ examples/load_xarray.py | 29 +++++++++++++++++++++++++++++ 7 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 examples/README.rst create mode 100644 examples/load_xarray.py diff --git a/.gitignore b/.gitignore index 8046550a..ca2e1839 100644 --- a/.gitignore +++ b/.gitignore @@ -57,7 +57,8 @@ local_settings.py instance/ # Sphinx documentation -docs/_build/ +docs/build/ +docs/source/auto_examples/ # MkDocs documentation /site/ diff --git a/MANIFEST.in b/MANIFEST.in index 77097270..94e5a60d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,4 +6,5 @@ exclude .cruft.json recursive-exclude * __pycache__ recursive-exclude * *.py[co] recursive-exclude docs * +recursive-exclude examples * recursive-exclude tests * diff --git a/docs/requirements.txt b/docs/requirements.txt index c1e1d87e..446e3571 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,3 +6,4 @@ setuptools-scm sphinx sphinx-autodoc-typehints sphinx-design +sphinx-gallery diff --git a/docs/source/conf.py b/docs/source/conf.py index d1266e03..c42ded02 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -45,6 +45,7 @@ "myst_parser", "nbsphinx", "sphinx_design", + "sphinx_gallery.gen_gallery", ] # Configure the myst parser to enable cool markdown features @@ -84,6 +85,12 @@ "**/includes/**", ] +# Configure Sphinx gallery +sphinx_gallery_conf = { + "examples_dirs": ["../../examples"], + "filename_pattern": "/*.py", # which files to execute before inclusion +} + # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = "pydata_sphinx_theme" diff --git a/docs/source/index.md b/docs/source/index.md index 58dc7cf8..7a3ebfd5 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -26,4 +26,5 @@ The following projects cover related needs and served as inspiration for this pr :maxdepth: 2 getting_started +auto_examples/index ``` diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 00000000..311ded86 --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,4 @@ +Examples +======== + +Below is a gallery of examples using `movement`. diff --git a/examples/load_xarray.py b/examples/load_xarray.py new file mode 100644 index 00000000..1164bb27 --- /dev/null +++ b/examples/load_xarray.py @@ -0,0 +1,29 @@ +""" +Load pose tracks +================ + +Load and explore example dataset of pose tracks. +""" + +# %% +# Imports +# ------- +from movement import datasets +from movement.io import load_poses + +# %% +# Fetch an example dataset +# ------------------------ +# Feel free to replace this with the path to your own dataset. 
+# e.g., `h5_path = "/path/to/my/data.h5"`
+
+h5_path = datasets.fetch_pose_data_path(
+    "SLEAP_two-mice_social-interaction.analysis.h5"
+)
+
+# %%
+# Load the dataset
+# ----------------
+
+ds = load_poses.from_sleap_file(h5_path, fps=40)
+ds

From e48b21fa764dee23cc8681c687bedb7aeb1cd3d9 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 11:59:13 +0100
Subject: [PATCH 47/79] modify docs workflow file to allow deployment from
 this PR

---
 .github/workflows/docs_build_and_deploy.yml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/docs_build_and_deploy.yml b/.github/workflows/docs_build_and_deploy.yml
index 8768f335..8da2c152 100644
--- a/.github/workflows/docs_build_and_deploy.yml
+++ b/.github/workflows/docs_build_and_deploy.yml
@@ -13,6 +13,9 @@ on:
     tags:
       - '*'
   pull_request:
+    branches:
+      - main
+      - pose-tracks-io
   workflow_dispatch:
 
 jobs:
@@ -27,7 +30,6 @@ jobs:
     needs: build_sphinx_docs
     permissions:
       contents: write
-    if: github.event_name == 'push' && github.ref == 'refs/heads/pose-tracks-io'
     runs-on: ubuntu-latest
     steps:
       - uses: neuroinformatics-unit/actions/deploy_sphinx_docs@v2

From 04ca5a5c8f1efe3f03bc49e8dd5c3bb791c042e8 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 13:01:38 +0100
Subject: [PATCH 48/79] added matplotlib dependency to docs

---
 docs/requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 446e3571..9baf425f 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,5 @@
 linkify-it-py
+matplotlib
 myst-parser
 nbsphinx
 pydata-sphinx-theme

From e77fe17709ad1c03c2d21ab2fbdb2e4881aced7b Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 14:37:01 +0100
Subject: [PATCH 49/79] add function for listing available sample data

---
 movement/datasets.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/movement/datasets.py b/movement/datasets.py
index ae0b5dc8..48512a30 100644
--- a/movement/datasets.py
+++ b/movement/datasets.py
@@ -6,6 +6,7 @@
 """
 
 from pathlib import Path
+from typing import List
 
 import pooch
 
@@ -40,6 +41,11 @@
 )
 
 
+def find_pose_data() -> List[str]:
+    """Find available sample pose data."""
+    return list(POSE_DATA.registry.keys())
+
+
 def fetch_pose_data_path(filename: str) -> Path:
     """Fetch sample pose data from the remote repository.
 

From b0ff61352a142d620873f7940d96e3fd4aa1a950 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 14:37:45 +0100
Subject: [PATCH 50/79] renamed and expanded gallery example

---
 examples/load_and_explore_poses.py | 85 ++++++++++++++++++++++++++++++
 examples/load_xarray.py            | 29 ----------
 2 files changed, 85 insertions(+), 29 deletions(-)
 create mode 100644 examples/load_and_explore_poses.py
 delete mode 100644 examples/load_xarray.py

diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py
new file mode 100644
index 00000000..5a425f8c
--- /dev/null
+++ b/examples/load_and_explore_poses.py
@@ -0,0 +1,85 @@
+"""
+Load and explore pose tracks
+============================
+
+Load and explore an example dataset of pose tracks.
+"""
+
+# %%
+# Imports
+# -------
+from matplotlib import pyplot as plt
+
+from movement import datasets
+from movement.io import load_poses
+
+# %%
+# Fetch an example dataset
+# ------------------------
+# Print a list of available datasets:
+
+print(datasets.find_pose_data())
+
+# %%
+# Fetch the path to an example dataset
+# (Feel free to replace this with the path to your own dataset.
+# e.g., `file_path = "/path/to/my/data.h5"`)
+file_path = datasets.fetch_pose_data_path(
+    "SLEAP_three-mice_Aeon_proofread.analysis.h5"
+)
+
+# %%
+# Load the dataset
+# ----------------
+
+ds = load_poses.from_sleap_file(file_path, fps=60)
+ds
+
+# %%
+# The loaded dataset contains two data variables:
+# `pose_tracks` and `confidence`
+# To get the pose tracks:
+pose_tracks = ds["pose_tracks"]
+
+# %%
+# Slect and plot data with ``xarray``
+# -----------------------------------
+# You can use the ``sel`` method to index into ``xarray`` objects.
+# For example, we can get a `DataArray` containing only data
+# for the "centroid" keypoint of the first individual:
+
+da = pose_tracks.sel(individuals="AEON3B_NTP", keypoints="centroid")
+
+# %%
+# We could plot the x,y coordinates of this keypoint over time,
+# using ``xarray``'s built-in plotting methods:
+da.plot.line(x="time", row="space", aspect=2, size=2.5)
+
+# %%
+# Similarly we could plot the same keypoint's x,y coordinates
+# for all individuals:
+
+pose_tracks.sel(keypoints="centroid").plot.line(
+    x="time", row="individuals", aspect=2, size=2.5
+)
+
+# %%
+# Trajectory plots
+# ----------------
+# We are not limited to ``xarray``'s built-in plots.
+# For example, we can use ``matplotlib`` to plot trajectories
+# (using scatter plots):
+
+individuals = pose_tracks.individuals.values
+for i, ind in enumerate(individuals):
+    da_ind = pose_tracks.sel(individuals=ind, keypoints="centroid")
+    plt.scatter(
+        da_ind.sel(space="x"),
+        da_ind.sel(space="y"),
+        s=2,
+        color=plt.cm.tab10(i),
+        label=ind,
+    )
+    plt.xlabel("x")
+    plt.ylabel("y")
+    plt.legend()
diff --git a/examples/load_xarray.py b/examples/load_xarray.py
deleted file mode 100644
index 1164bb27..00000000
--- a/examples/load_xarray.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""
-Load pose tracks
-================
-
-Load and explore example dataset of pose tracks.
-"""
-
-# %%
-# Imports
-# -------
-from movement import datasets
-from movement.io import load_poses
-
-# %%
-# Fetch an example dataset
-# ------------------------
-# Feel free to replace this with the path to your own dataset.
-# e.g., `h5_path = "/path/to/my/data.h5"`
-
-h5_path = datasets.fetch_pose_data_path(
-    "SLEAP_two-mice_social-interaction.analysis.h5"
-)
-
-# %%
-# Load the dataset
-# ----------------
-
-ds = load_poses.from_sleap_file(h5_path, fps=40)
-ds

From 743480670cf280abba844a32104b4d1bd3084bfc Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 14:40:11 +0100
Subject: [PATCH 51/79] added movement to docs dependencies

---
 docs/requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 9baf425f..72f6dbfc 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,5 +1,6 @@
 linkify-it-py
 matplotlib
+movement
 myst-parser
 nbsphinx
 pydata-sphinx-theme

From 254b9b025147dfedc3270b75b6184589e0cb5e48 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 14:58:49 +0100
Subject: [PATCH 52/79] make movement docs requirement a local editable install

---
 docs/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index 72f6dbfc..ae1bd4b5 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
+-e . 
linkify-it-py matplotlib -movement myst-parser nbsphinx pydata-sphinx-theme From 210280a10433d18c1009a46d284c255137d726cf Mon Sep 17 00:00:00 2001 From: niksirbi Date: Tue, 15 Aug 2023 17:07:48 +0100 Subject: [PATCH 53/79] fix matrix transpose when importing from sleap analysis files --- movement/io/load_poses.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 7df3497a..03e1ad8a 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -84,8 +84,9 @@ def from_sleap_file( Parameters ---------- file_path : pathlib.Path or str - Path to the file containing the SLEAP predictions, either in ".slp" - or ".h5" (analysis) format. See Notes for more information. + Path to the file containing the SLEAP predictions, either in ".h5" + (analysis) or in ".slp" format. The analysis file is preferred. + See Notes for more information. fps : float, optional The number of frames per second in the video. If None (default), the `time` coordinates will be in frame numbers. @@ -100,11 +101,15 @@ def from_sleap_file( The SLEAP predictions are normally saved in a ".slp" file, e.g. "v1.predictions.slp". If this file contains both user-labeled and predicted instances, only the predicted ones will be loaded. + Loading from such files is intended to function primarily for single-video + prediction results. If there are multiple videos in the file, only the + first one will be used. An analysis file, suffixed with ".h5" can be exported from the ".slp" file, using either the command line tool `sleap-convert` (with the "--format analysis" option enabled) or the SLEAP GUI (Choose - "Export Analysis HDF5…" from the "File" menu) [1]_. + "Export Analysis HDF5…" from the "File" menu) [1]_. This is the + preferred format for loading pose tracks from SLEAP into `movement`. `movement` expects the tracks to be proofread before loading them, meaning each track is interpreted as a single individual/animal. 
@@ -220,18 +225,14 @@ def _load_from_sleap_analysis_file(
     file = ValidHDF5(file_path, expected_datasets=["tracks"])
 
     with h5py.File(file.path, "r") as f:
-        tracks = f["tracks"][:].T
-        n_frames, n_keypoints, n_space, n_tracks = tracks.shape
-        tracks = tracks.reshape((n_frames, n_tracks, n_keypoints, n_space))
+        # transpose to shape: (n_frames, n_tracks, n_keypoints, n_space)
+        tracks = f["tracks"][:].transpose((3, 0, 2, 1))
         # Create an array of NaNs for the confidence scores
-        scores = np.full(
-            (n_frames, n_tracks, n_keypoints), np.nan, dtype="float32"
-        )
-        # If present, read the point-wise scores, and reshape them
+        scores = np.full(tracks.shape[:-1], np.nan, dtype="float32")
+        # If present, read the point-wise scores,
+        # and transpose to shape: (n_frames, n_tracks, n_keypoints)
         if "point_scores" in f.keys():
-            scores = f["point_scores"][:].reshape(
-                (n_frames, n_tracks, n_keypoints)
-            )
+            scores = f["point_scores"][:].transpose((2, 0, 1))
 
     return ValidPoseTracks(
         tracks_array=tracks,
@@ -265,7 +266,7 @@ def _load_from_sleap_labels_file(
     file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"])
 
     labels = read_labels(file.path.as_posix())
-    tracks_with_scores = labels.numpy(return_confidence=True)
+    tracks_with_scores = labels.numpy(untracked=False, return_confidence=True)
 
     return ValidPoseTracks(
         tracks_array=tracks_with_scores[:, :, :, :-1],

From 4285ef7e5ea1ffc3261a54ce39b2c4d41f9330c4 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 17:10:24 +0100
Subject: [PATCH 54/79] set xarray.keep_attrs True globally

---
 movement/__init__.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/movement/__init__.py b/movement/__init__.py
index 283ebd28..9ac02704 100644
--- a/movement/__init__.py
+++ b/movement/__init__.py
@@ -7,6 +7,10 @@
     # package is not installed
     pass
 
+# set xarray global options
+import xarray as xr
+
+xr.set_options(keep_attrs=True, display_expand_data=False)
 
 # initialize logger upon import
 configure_logging()
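To see why the explicit `transpose` in the patch above (patch 53) is needed, note that the old `.T`-then-`reshape` approach produces the right shape but scrambles which value belongs to which track and keypoint. A small self-contained sketch with a toy array laid out like a SLEAP analysis-file `tracks` dataset, i.e. `(n_tracks, n_space, n_keypoints, n_frames)` — toy shapes only, not movement code:

```python
import numpy as np

# 2 tracks, 2 spatial dims, 3 keypoints, 4 frames
raw = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4)

# New code: explicit axis permutation to (n_frames, n_tracks, n_keypoints, n_space);
# every value stays attached to its original (frame, track, keypoint, space) labels
fixed = raw.transpose((3, 0, 2, 1))
assert fixed.shape == (4, 2, 3, 2)

# Old code: full transpose, then reshape into the target shape;
# the shape matches, but the flattened values are regrouped incorrectly
old = raw.T.reshape(4, 2, 3, 2)
assert old.shape == fixed.shape
print(np.array_equal(old, fixed))  # False — reshape reordered the data
```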
From 705e7d5b4f702155f82d1eea4d22d0af60c91f5d Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 17:56:16 +0100
Subject: [PATCH 55/79] emphasise analysis files as the primary choice when
 loading from SLEAP

---
 docs/source/getting_started.md |  6 ------
 movement/io/load_poses.py      | 32 ++++++++++++++++----------------
 2 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md
index b4b96d5b..321512bc 100644
--- a/docs/source/getting_started.md
+++ b/docs/source/getting_started.md
@@ -72,12 +72,6 @@ Load from [SLEAP analysis files](https://sleap.ai/tutorials/analysis.html) (`.h5`):
 ```python
 ds = load_poses.from_sleap_file("/path/to/file.analysis.h5", fps=30)
 ```
-
-Alternatively, you can also directly load from `.slp` files,
-assuming they contain predicted poses.
-```python
-ds = load_poses.from_sleap_file("/path/to/file.predictions.slp", fps=30)
-```
 :::
 
 :::{tab-item} DeepLabCut
diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index 03e1ad8a..a414fbe6 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -84,9 +84,9 @@ def from_sleap_file(
     Parameters
     ----------
     file_path : pathlib.Path or str
-        Path to the file containing the SLEAP predictions, either in ".h5"
-        (analysis) or in ".slp" format. The analysis file is preferred.
-        See Notes for more information.
+        Path to the file containing the SLEAP predictions in ".h5"
+        (analysis) format. Alternatively, an ".slp" (labels) file can
+        also be supplied (but this feature is experimental, see Notes).
     fps : float, optional
         The number of frames per second in the video. If None (default),
         the `time` coordinates will be in frame numbers.
@@ -98,21 +98,21 @@ def from_sleap_file(
 
     Notes
     -----
-    The SLEAP predictions are normally saved in a ".slp" file, e.g.
-    "v1.predictions.slp". If this file contains both user-labeled and
-    predicted instances, only the predicted ones will be loaded.
-    Loading from such files is intended to function primarily for single-video
-    prediction results. If there are multiple videos in the file, only the
-    first one will be used.
-
-    An analysis file, suffixed with ".h5" can be exported from the ".slp"
-    file, using either the command line tool `sleap-convert` (with the
-    "--format analysis" option enabled) or the SLEAP GUI (Choose
-    "Export Analysis HDF5…" from the "File" menu) [1]_. This is the
-    preferred format for loading pose tracks from SLEAP into `movement`.
+    The SLEAP predictions are normally saved in ".slp" files, e.g.
+    "v1.predictions.slp". An analysis file, suffixed with ".h5" can be exported
+    from the ".slp" file, using either the command line tool `sleap-convert`
+    (with the "--format analysis" option enabled) or the SLEAP GUI (Choose
+    "Export Analysis HDF5…" from the "File" menu) [1]_. This is the
+    preferred format for loading pose tracks from SLEAP into `movement`.
 
-    `movement` expects the tracks to be proofread before loading them,
-    meaning each track is interpreted as a single individual/animal.
+    You can also try directly loading the ".slp" file, but this feature is
+    experimental and does not work in all cases. If the ".slp" file contains
+    both user-labeled and predicted instances, only the predicted ones will be
+    loaded. If there are multiple videos in the file, only the first one will
+    be used.
+
+    `movement` expects the tracks to be assigned and proofread before loading
+    them, meaning each track is interpreted as a single individual/animal.
     Follow the SLEAP guide for tracking and proofreading [2]_.
 
     References
@@ -123,7 +123,7 @@ def from_sleap_file(
     Examples
     --------
     >>> from movement.io import load_poses
-    >>> ds = load_poses.from_sleap_file("path/to/file.slp", fps=30)
+    >>> ds = load_poses.from_sleap_file("path/to/file.analysis.h5", fps=30)
     """
 
     file = ValidFile(

From 4b50c4e7bd48171a47bdf1690a697828fa6f2 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Tue, 15 Aug 2023 18:53:37 +0100
Subject: [PATCH 56/79] update sphinx example

---
 examples/load_and_explore_poses.py | 57 +++++++++++++++---------------
 1 file changed, 29 insertions(+), 28 deletions(-)

diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py
index 5a425f8c..ef31b15e 100644
--- a/examples/load_and_explore_poses.py
+++ b/examples/load_and_explore_poses.py
@@ -18,12 +18,13 @@
 # ------------------------
 # Print a list of available datasets:
 
-print(datasets.find_pose_data())
+for file_name in datasets.find_pose_data():
+    print(file_name)
 
 # %%
-# Fetch the path to an example dataset
-# (Feel free to replace this with the path to your own dataset.
-# e.g., `file_path = "/path/to/my/data.h5"`)
+# Fetch the path to an example dataset.
+# Feel free to replace this with the path to your own dataset.
+# e.g., ``file_path = "/path/to/my/data.h5"``)
 file_path = datasets.fetch_pose_data_path(
     "SLEAP_three-mice_Aeon_proofread.analysis.h5"
 )
@@ -33,35 +34,35 @@
 # ----------------
 
 ds = load_poses.from_sleap_file(file_path, fps=60)
-ds
+print(ds)
 
 # %%
 # The loaded dataset contains two data variables:
-# `pose_tracks` and `confidence`
+# ``pose_tracks`` and ``confidence``
 # To get the pose tracks:
 pose_tracks = ds["pose_tracks"]
 
 # %%
-# Slect and plot data with ``xarray``
-# -----------------------------------
+# Select and plot data with xarray
+# --------------------------------
 # You can use the ``sel`` method to index into ``xarray`` objects.
-# For example, we can get a `DataArray` containing only data
-# for the "centroid" keypoint of the first individual:
+# For example, we can get a ``DataArray`` containing only data
+# for a single keypoint of the first individual:
 
 da = pose_tracks.sel(individuals="AEON3B_NTP", keypoints="centroid")
+print(da)
 
 # %%
-# We could plot the x,y coordinates of this keypoint over time,
+# We could plot the x, y coordinates of this keypoint over time,
 # using ``xarray``'s built-in plotting methods:
 da.plot.line(x="time", row="space", aspect=2, size=2.5)
 
 # %%
-# Similarly we could plot the same keypoint's x,y coordinates
+# Similarly we could plot the same keypoint's x, y coordinates
 # for all individuals:
 
-pose_tracks.sel(keypoints="centroid").plot.line(
-    x="time", row="individuals", aspect=2, size=2.5
-)
+da = pose_tracks.sel(keypoints="centroid")
+da.plot.line(x="time", row="individuals", aspect=2, size=2.5)
 
 # %%
 # Trajectory plots
 # ----------------
 # We are not limited to ``xarray``'s built-in plots.
 # For example, we can use ``matplotlib`` to plot trajectories
 # (using scatter plots):
 
-individuals = pose_tracks.individuals.values
-for i, ind in enumerate(individuals):
-    da_ind = pose_tracks.sel(individuals=ind, keypoints="centroid")
-    plt.scatter(
-        da_ind.sel(space="x"),
-        da_ind.sel(space="y"),
-        s=2,
-        color=plt.cm.tab10(i),
-        label=ind,
-    )
-    plt.xlabel("x")
-    plt.ylabel("y")
-    plt.legend()
+mouse_name = "AEON3B_TP1"
+
+plt.scatter(
+    da.sel(individuals=mouse_name, space="x"),
+    da.sel(individuals=mouse_name, space="y"),
+    s=2,
+    c=da.time,
+    cmap="viridis",
+)
+plt.title(f"Trajectory of {mouse_name}")
+plt.xlabel("x")
+plt.ylabel("y")
+plt.colorbar(label="time (sec)")

From 1f77af32006c58d84b77a4d619575baedfd2cd8d Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Wed, 16 Aug 2023 11:29:47 +0100
Subject: [PATCH 57/79] Added API reference to docs

---
 .gitignore                |  1 +
 docs/source/api_index.rst | 47 +++++++++++++++++++++++++++++++++
 docs/source/index.md      | 28 +++++++++++++++++++++
 3 files changed, 76 insertions(+)
 create mode 100644 docs/source/api_index.rst

diff --git a/.gitignore b/.gitignore
index ca2e1839..085834ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -59,6 +59,7 @@ instance/
 # Sphinx documentation
 docs/build/
 docs/source/auto_examples/
+docs/source/auto_api/
 
 # MkDocs documentation
 /site/
diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst
new file mode 100644
index 00000000..17fc54d1
--- /dev/null
+++ b/docs/source/api_index.rst
@@ -0,0 +1,47 @@
+API Reference
+=============
+
+
+Input/Output
+------------
+.. currentmodule:: movement.io.load_poses
+.. autosummary::
+    :toctree: auto_api
+
+    from_sleap_file
+    from_dlc_file
+    from_dlc_df
+
+.. currentmodule:: movement.io.save_poses
+.. autosummary::
+    :toctree: auto_api
+
+    to_dlc_file
+    to_dlc_df
+
+.. currentmodule:: movement.io.validators
+.. 
autosummary:: + :toctree: auto_api + + ValidFile + ValidHDF5 + ValidPosesCSV + +Datasets +-------- +.. currentmodule:: movement.datasets +.. autosummary:: + :toctree: auto_api + + find_pose_data + fetch_pose_data_path + +Logging +------- +.. currentmodule:: movement.logging +.. autosummary:: + :toctree: auto_api + + configure_logging + log_and_raise_error + log_warning diff --git a/docs/source/index.md b/docs/source/index.md index 7a3ebfd5..896b0699 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -8,6 +8,32 @@ Kinematic analysis of animal 🐝 🦀 🐀 🐒 body movements for neuroscience - The interface is subject changes. [Open an issue](https://github.com/neuroinformatics-unit/movement/issues) if you have suggestions. ::: +::::{grid} 1 2 2 3 +:gutter: 3 + +:::{grid-item-card} {fas}`rocket;sd-text-primary` Getting Started +:link: getting_started +:link-type: doc + +Install and try it out. +::: + +:::{grid-item-card} {fas}`chalkboard-user;sd-text-primary` Examples +:link: auto_examples/index +:link-type: doc + +Example use cases. +::: + +:::{grid-item-card} {fas}`code;sd-text-primary` API Reference +:link: api_index +:link-type: doc + +Index of all functions, classes, and methods. +::: +:::: + + ## Aims * Load pose tracks from pose estimation software packages (e.g. [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/)) * Evaluate the quality of the tracks and perform data cleaning operations @@ -24,7 +50,9 @@ The following projects cover related needs and served as inspiration for this pr ```{toctree} :maxdepth: 2 +:hidden: getting_started auto_examples/index +api_index ``` From 3925f84384fd97a4217c2d8e2ca3de73f356cf15 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 16 Aug 2023 11:38:36 +0100 Subject: [PATCH 58/79] harmonised docs homepage and repo README --- README.md | 24 ++++++++++++++---------- docs/source/index.md | 13 +++++++------ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 3d7905c1..46111de3 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,17 @@ Kinematic analysis of animal 🐝 🦀 🐀 🐒 body movements for neuroscience and ethology research 🔬. +Read the [documentation](https://neuroinformatics-unit.github.io/movement/) for more information. + ## Status -The package is currently in early development 🏗️ and is not yet ready for use. Stay tuned ⌛ +> **Warning** +> - 🏗️ The package is currently in early development. Stay tuned ⌛ +> - It is not sufficiently tested to be used for scientific analysis +> - The interface is subject to changes. [Open an issue](https://github.com/neuroinformatics-unit/movement/issues) if you have suggestions. ## Aims -* Load keypoint tracks from pose estimation software (e.g. [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/)) -* Evaluate the quality of the tracks and perform data cleaning +* Load pose tracks from pose estimation software packages (e.g. [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/)) +* Evaluate the quality of the tracks and perform data cleaning operations * Calculate kinematic variables (e.g. speed, acceleration, joint angles, etc.) 
* Produce reports and visualise the results @@ -25,6 +30,12 @@ The following projects cover related needs and served as inspiration for this pr * [Kino](https://github.com/BrancoLab/Kino) * [WAZP](https://github.com/SainsburyWellcomeCentre/WAZP) +## License +⚖️ [BSD 3-Clause](./LICENSE) + +## Template +This package layout and configuration (including pre-commit hooks and GitHub actions) have been copied from the [python-cookiecutter](https://github.com/SainsburyWellcomeCentre/python-cookiecutter) template. + ## How to contribute ### Setup * We recommend you install `movement` inside a [conda](https://docs.conda.io/en/latest/) environment. @@ -72,10 +83,3 @@ git commit -m "Add new changes" git tag -a v1.0.0 -m "Bump to version 1.0.0" git push --follow-tags ``` - -## License - -⚖️ [BSD 3-Clause](./LICENSE) - -## Template -This package layout and configuration (including pre-commit hooks and GitHub actions) have been copied from the [python-cookiecutter](https://github.com/SainsburyWellcomeCentre/python-cookiecutter) template. diff --git a/docs/source/index.md b/docs/source/index.md index 896b0699..b429b607 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -2,12 +2,6 @@ Kinematic analysis of animal 🐝 🦀 🐀 🐒 body movements for neuroscience and ethology research. -:::{warning} -- 🏗️ The package is currently in early development. Stay tuned ⌛ -- It is not sufficiently tested to be used for scientific analysis -- The interface is subject changes. [Open an issue](https://github.com/neuroinformatics-unit/movement/issues) if you have suggestions. -::: - ::::{grid} 1 2 2 3 :gutter: 3 @@ -33,6 +27,13 @@ Index of all functions, classes, and methods. ::: :::: +## Status +:::{warning} +- 🏗️ The package is currently in early development. Stay tuned ⌛ +- It is not sufficiently tested to be used for scientific analysis +- The interface is subject to changes. [Open an issue](https://github.com/neuroinformatics-unit/movement/issues) if you have suggestions. +::: + ## Aims * Load pose tracks from pose estimation software packages (e.g. [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/)) From 2fb1a59eefcd85107a1a0913a98d03d9951c3d34 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 16 Aug 2023 14:53:08 +0100 Subject: [PATCH 59/79] added contributing guide --- CONTRIBUTING.md | 276 +++++++++++++++++++++++++++++++++ MANIFEST.in | 2 +- README.md | 53 +------ docs/source/contributing.rst | 2 + docs/source/getting_started.md | 1 + docs/source/index.md | 1 + 6 files changed, 284 insertions(+), 51 deletions(-) create mode 100644 CONTRIBUTING.md create mode 100644 docs/source/contributing.rst diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..9cc0031a --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,276 @@ +# How to Contribute + +## Introduction + +**Contributors to movement are absolutely encouraged**, whether to fix a bug, +develop a new feature, or improve the documentation. +If you're unsure about any part of the contributing process, please get in touch. +It's best to reach out in public, e.g. by [opening an issue](https://github.com/neuroinformatics-unit/movement/issues) +so that others can benefit from the discussion. + +## Contributing code + +### Creating a development environment + +It is recommended to use [conda](https://docs.conda.io/en/latest/) +or [mamba](https://mamba.readthedocs.io/en/latest/index.html) to create a +development environment for movement. 
In the following we assume you have
+`conda` installed, but the same commands will also work with `mamba`/`micromamba`.
+
+First, create and activate a `conda` environment with some pre-requisites:
+
+```sh
+conda create -n movement-dev -c conda-forge python=3.10 pytables
+conda activate movement-dev
+```
+
+The above method ensures that you will get packages that often can't be
+installed via `pip`, including [hdf5](https://www.hdfgroup.org/solutions/hdf5/).
+
+To install movement for development, clone the GitHub repository,
+and then run from inside the repository:
+
+```sh
+pip install -e .[dev]  # works on most shells
+pip install -e '.[dev]'  # works on zsh (the default shell on macOS)
+```
+
+This will install the package in editable mode, including all dependencies
+required for development.
+
+Finally, initialise the [pre-commit hooks](#formatting-and-pre-commit-hooks):
+
+```bash
+pre-commit install
+```
+
+### Pull requests
+
+In all cases, please submit code to the main repository via a pull request (PR).
+We recommend, and adhere, to the following conventions:
+
+- Please submit _draft_ PRs as early as possible to allow for discussion.
+- The PR title should be descriptive e.g. "Add new function to do X" or "Fix bug in Y".
+- The PR description should be used to provide context and motivation for the changes.
+- One approval of a PR (by a repo owner) is enough for it to be merged.
+- Unless someone approves the PR with optional comments, the PR is immediately merged by the approving reviewer.
+- Ask for a review from someone specific if you think they would be a particularly suited reviewer.
+- PRs are preferentially merged via the ["squash and merge"](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-commits) option, to keep a clean commit history on the _main_ branch.
+
+A typical PR workflow would be:
+* Create a new branch, make your changes, and stage them.
+* When you try to commit, the [pre-commit hooks](#formatting-and-pre-commit-hooks) will be triggered.
+* Stage any changes made by the hooks, and commit.
+* You may also run the pre-commit hooks manually, at any time, with `pre-commit run -a`.
+* Make sure to write tests for any new features or bug fixes. See [testing](#testing) below.
+* Don't forget to update the documentation, if necessary. See [contributing documentation](#contributing-documentation) below.
+* Push your changes to GitHub and open a draft pull request, with a meaningful title and a thorough description of the changes.
+* If all checks (e.g. linting, type checking, testing) run successfully, you may mark the pull request as ready for review.
+* Respond to review comments and implement any requested changes.
+* Success 🎉 !! Your PR will be (squash-)merged into the _main_ branch.
+
+## Development guidelines
+
+### Formatting and pre-commit hooks
+
+Running `pre-commit install` will set up [pre-commit hooks](https://pre-commit.com/) to ensure a consistent formatting style. Currently, these include:
+* [ruff](https://github.com/charliermarsh/ruff) does a number of jobs, including enforcing PEP8 and sorting imports
+* [black](https://black.readthedocs.io/en/stable/) for auto-formatting
+* [mypy](https://mypy.readthedocs.io/en/stable/index.html) as a static type checker
+* [check-manifest](https://github.com/mgedmin/check-manifest) to ensure that the right files are included in the pip package.
+ +These will prevent code from being committed if any of these hooks fail. To run them individually (from the root of the repository), you can use: + +```sh +ruff . +black ./ +mypy -p movement +check-manifest +``` + +To run all the hooks before committing: + +```sh +pre-commit run # for staged files +pre-commit run -a # for all files in the repository +``` + +For docstrings, we adhere to the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) style. + +### Testing + +We use [pytest](https://docs.pytest.org/en/latest/) for testing and aim for +~100% test coverage (as far as is reasonable). +All new features should be tested. +Write your test methods and classes in the _tests_ folder. + +For some tests, you will need to use real experimental data. +Do not include these data in the repository, especially if they are large. +We store several sample datasets in an external data repository. +See [sample data](#sample-data) for more information. + + +### Continuous integration +All pushes and pull requests will be built by [GitHub actions](https://docs.github.com/en/actions). +This will usually include linting, testing and deployment. + +A GitHub actions workflow (`.github/workflows/test_and_deploy.yml`) has been set up to run (on each push/PR): +* Linting checks (pre-commit). +* Testing (only if linting checks pass) +* Release to PyPI (only if a git tag is present and if tests pass). + +### Versioning and releases +We use [semantic versioning](https://semver.org/), which includes `MAJOR`.`MINOR`.`PATCH` version numbers: + +* PATCH = small bugfix +* MINOR = new feature +* MAJOR = breaking change + +We use [setuptools_scm](https://github.com/pypa/setuptools_scm) to automatically version movement. +It has been pre-configured in the `pyproject.toml` file. +`setuptools_scm` will automatically [infer the version using git](https://github.com/pypa/setuptools_scm#default-versioning-scheme). +To manually set a new semantic version, create a tag and make sure the tag is pushed to GitHub. +Make sure you commit any changes you wish to be included in this version. E.g. to bump the version to `1.0.0`: + +```sh +git add . +git commit -m "Add new changes" +git tag -a v1.0.0 -m "Bump to version 1.0.0" +git push --follow-tags +``` +Alternatively, you can also use the GitHub web interface to create a new release and tag. + +The addition of a GitHub tag triggers the package's deployment to PyPI. +The version number is automatically determined from the latest tag on the _main_ branch. + +## Contributing documentation + +The documentation is hosted via [GitHub pages](https://pages.github.com/) at +[neuroinformatics-unit.github.io/movement](https://neuroinformatics-unit.github.io/movement/). +Its source files are located in the `docs` folder of this repository. +They are written in either [reStructuredText](https://docutils.sourceforge.io/rst.html) or +[markdown](https://myst-parser.readthedocs.io/en/stable/syntax/typography.html). +The `index.md` file corresponds to the homepage of the documentation website. +Other `.rst` or `.md` files are linked to the homepage via the `toctree` directive. + +We use [Sphinx](https://www.sphinx-doc.org/en/master/) and the +[PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html) +to build the source files into html output. +This is handled by a GitHub actions workflow (`.github/workflows/docs_build_and_deploy.yml`). +The build job is triggered on each PR, ensuring that the documentation build is not broken by new changes. 
+The deployment job is only triggered whenever a tag is pushed to the _main_ branch,
+ensuring that the documentation is published in sync with each PyPI release.
+
+### Editing the documentation
+
+To edit the documentation, first clone the repository, and install movement in a
+[development environment](#creating-a-development-environment).
+
+Now open a new branch, edit the documentation source files (`.md` or `.rst` in the `docs` folder),
+and commit your changes. Submit your documentation changes via a pull request,
+following the [same guidelines as for code changes](#pull-requests).
+Make sure that the header levels in your `.md` or `.rst` files are incremented
+consistently (H1 > H2 > H3, etc.) without skipping any levels.
+
+If you create a new documentation source file (e.g. `my_new_file.md` or `my_new_file.rst`),
+you will need to add it to the `toctree` directive in `index.md`
+for it to be included in the documentation website:
+
+```rst
+:maxdepth: 2
+:hidden:
+
+existing_file
+my_new_file
+```
+
+### Updating the API reference
+If your PR introduces new public-facing functions, classes, or methods,
+make sure to add them to the `docs/source/api_index.rst` page, so that they are
+included in the [API reference](https://neuroinformatics-unit.github.io/movement/api_index.html),
+e.g.:
+
+```rst
+My new module
+--------------
+.. currentmodule:: movement.new_module
+.. autosummary::
+    :toctree: auto_api
+
+    new_function
+    NewClass
+```
+
+For this to work, your functions/classes/methods will need to have docstrings
+that follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) style.
+
+### Updating the examples
+We use [sphinx-gallery](https://sphinx-gallery.github.io/stable/index.html)
+to create the [examples](https://neuroinformatics-unit.github.io/movement/auto_examples/index.html).
+To add new examples, you will need to create a new `.py` file in `examples/`.
+The file should be structured as specified in the relevant
+[sphinx-gallery documentation](https://sphinx-gallery.github.io/stable/syntax.html).
+
+
+### Building the documentation locally
+We recommend that you build and view the documentation website locally, before you push it.
+To do so, first install the requirements for building the documentation:
+```sh
+pip install -r docs/requirements.txt
+```
+
+Then, from the root of the repository, run:
+```sh
+sphinx-build docs/source docs/build
+```
+
+You can view the local build by opening `docs/build/index.html` in a browser.
+To refresh the documentation, after making changes, remove the `docs/build` folder and re-run the above command:
+
+```sh
+rm -rf docs/build && sphinx-build docs/source docs/build
+```
+
+## Sample data
+
+We maintain some sample data to be used for testing, examples and tutorials on an
+[external data repository](https://gin.g-node.org/neuroinformatics/movement-test-data).
+Our hosting platform of choice is called [GIN](https://gin.g-node.org/) and is maintained
+by the [German Neuroinformatics Node](https://www.g-node.org/).
+GIN has a GitHub-like interface and git-like
+[CLI](https://gin.g-node.org/G-Node/Info/wiki/GIN+CLI+Setup#quickstart) functionalities.
+
+Currently the data repository contains sample pose estimation data files
+stored in the `poses` folder. Each file name starts with either "DLC" or "SLEAP",
+depending on the pose estimation software used to generate the data.
+
+### Fetching data
+To fetch the data from GIN, we use the [pooch](https://www.fatiando.org/pooch/latest/index.html)
+Python package, which can download data from pre-specified URLs and store them
+locally for all subsequent uses. It also provides some nice utilities,
+like verification of sha256 hashes and decompression of archives.
+
+The relevant functionality is implemented in the `movement.datasets.py` module.
+The most important parts of this module are:
+
+1. The `POSE_DATA` download manager object, which contains a list of stored files and their known hashes.
+2. The `find_pose_data()` function, which returns a list of the available files in the data repository.
+3. The `fetch_pose_data_path()` function, which downloads a file (if not already cached locally) and returns the local path to it.
+
+By default, the downloaded files are stored in the `~/.movement/data` folder.
+This can be changed by setting the `DATA_DIR` variable in the `movement.datasets.py` module.
+
+### Adding new data
+Only core movement developers may add new files to the external data repository.
+To add a new file, you will need to:
+
+1. Create a [GIN](https://gin.g-node.org/) account
+2. Ask to be added as a collaborator on the [movement data repository](https://gin.g-node.org/neuroinformatics/movement-test-data) (if not already)
+3. Download the [GIN CLI](https://gin.g-node.org/G-Node/Info/wiki/GIN+CLI+Setup#quickstart) and set it up with your GIN credentials, by running `gin login` in a terminal.
+4. Clone the movement data repository to your local machine, by running `gin get neuroinformatics/movement-test-data` in a terminal.
+5. Add your new files and commit them with `gin commit -m <message>`.
+6. Upload the committed changes to the GIN repository, by running `gin upload`. Latest changes to the repository can be pulled via `gin download`. `gin sync` will synchronise the latest changes bidirectionally.
+7. Determine the sha256 checksum hash of each new file, by running `sha256sum <filename>` in a terminal. Alternatively, you can use `pooch` to do this for you: `python -c "import pooch; pooch.file_hash('/path/to/file')"`. If you wish to generate a text file containing the hashes of all the files in a given folder, you can use `python -c "import pooch; pooch.make_registry('/path/to/folder', 'sha256_registry.txt')"`.
+8. Update the `movement.datasets.py` module on the [movement GitHub repository](https://github.com/SainsburyWellcomeCentre/movement) by adding the new files to the `POSE_DATA` registry. Make sure to include the correct sha256 hash, as determined in the previous step. Follow all the usual [guidelines for contributing code](#contributing-code). Make sure to test whether the new files can be fetched successfully (see [fetching data](#fetching-data) above) before submitting your pull request.
+
+You can also perform steps 3-6 via the GIN web interface, if you prefer to avoid using the CLI.
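To make the fetching machinery described in the contributing guide above more concrete, here is a rough sketch of what a pooch-based registry like `POSE_DATA` looks like. All values are illustrative: the real registry and hashes live in `movement/datasets.py`, the hash below is a placeholder, and the exact base URL layout on GIN is an assumption.

```python
from pathlib import Path

import pooch

POSE_DATA = pooch.create(
    path=Path.home() / ".movement" / "data",  # local cache location
    # assumed raw-file URL layout for the GIN repository
    base_url="https://gin.g-node.org/neuroinformatics/movement-test-data/raw/master/",
    registry={
        # file name -> known sha256 hash (placeholder value shown here)
        "poses/DLC_single-wasp.predictions.h5": "sha256:0000000000000000",
    },
)

# Downloads on first use, verifies the hash, then serves the cached copy
local_path = POSE_DATA.fetch("poses/DLC_single-wasp.predictions.h5")
```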
diff --git a/MANIFEST.in b/MANIFEST.in
index 94e5a60d..cad8162a 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,5 @@
 include LICENSE
-include README.md
+include *.md
 exclude .pre-commit-config.yaml
 exclude .cruft.json
diff --git a/README.md b/README.md
index 46111de3..d824db1b 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,8 @@
 Kinematic analysis of animal 🐝 🦀 🐀 🐒 body movements for neuroscience and ethology
 research 🔬.
 
-Read the [documentation](https://neuroinformatics-unit.github.io/movement/) for more information.
+- Read the [documentation](https://neuroinformatics-unit.github.io/movement/) for more information.
+- If you wish to contribute, please read the [contributing guide](./CONTRIBUTING.md).
 
 ## Status
 > **Warning**
@@ -34,52 +35,4 @@ The following projects cover related needs and served as inspiration for this pr
 ⚖️ [BSD 3-Clause](./LICENSE)
 
 ## Template
-This package layout and configuration (including pre-commit hooks and GitHub actions) have been copied from the [python-cookiecutter](https://github.com/SainsburyWellcomeCentre/python-cookiecutter) template.
-
-## How to contribute
-### Setup
-* We recommend you install `movement` inside a [conda](https://docs.conda.io/en/latest/) environment.
-Assuming you have `conda` installed, the following will create and activate an environment containing Python 3 as well as the required `pytables` library. You can call your environment whatever you like, we've used `movement-env`.
-
-  ```sh
-  conda create -n movement-env -c conda-forge python=3.11 pytables
-  conda activate movement-env
-  ```
-
-* Next clone the repository and install the package in editable mode (including all `dev` dependencies):
-
-  ```bash
-  git clone https://github.com/neuroinformatics-unit/movement
-  cd movement
-  pip install -e '.[dev]'
-  ```
-* Initialize the pre-commit hooks:
-
-  ```bash
-  pre-commit install
-  ```
-
-### Workflow
-* Create a new branch, make your changes, and stage them.
-* When you try to commit, the pre-commit hooks will be triggered. These include linting with [`ruff`](https://github.com/charliermarsh/ruff) and auto-formatting with [`black`](https://github.com/psf/black). Stage any changes made by the hooks, and commit. You may also run the pre-commit hooks manually, at any time, with `pre-commit run --all-files`.
-* Push your changes to GitHub and open a draft pull request.
-* If all checks (e.g. linting, type checking, testing) run successfully, you may mark the pull request as ready for review.
-* For debugging purposes, you may also want to run the tests and the type checks locally, before pushing. This can be done with the following commands:
-  ```bash
-  cd movement
-  pytest
-  mypy -p movement
-  ```
-* When your pull request is approved, squash-merge it into the `main` branch and delete the feature branch.
-
-### Versioning and deployment
-The package is deployed to PyPI automatically when a new release is created on GitHub. We use [semantic versioning](https://semver.org/), with `MAJOR`.`MINOR`.`PATCH` version numbers.
-
-We use [`setuptools_scm`](https://github.com/pypa/setuptools_scm), which automatically [infers the version using git](https://github.com/pypa/setuptools_scm#default-versioning-scheme). To manually set a new semantic version, create an appropriate tag and push it to GitHub. Make sure to commit any changes you wish to be included in this version. E.g. to bump the version to `1.0.0`:
-
-```bash
-git add .
-git commit -m "Add new changes"
-git tag -a v1.0.0 -m "Bump to version 1.0.0"
-git push --follow-tags
-```
+This package layout and configuration (including pre-commit hooks and GitHub actions) have been copied from the [python-cookiecutter](https://github.com/neuroinformatics-unit/python-cookiecutter) template.
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
new file mode 100644
index 00000000..7a661a30
--- /dev/null
+++ b/docs/source/contributing.rst
@@ -0,0 +1,2 @@
+.. 
include:: ../../CONTRIBUTING.md
+   :parser: myst_parser.sphinx_
diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md
index 321512bc..2c7c3a17 100644
--- a/docs/source/getting_started.md
+++ b/docs/source/getting_started.md
@@ -45,6 +45,7 @@ pip install -e '.[dev]'  # works on zsh (the default shell on macOS)
 ```
 
 This will install the package in editable mode, including all `dev` dependencies.
+Please see the [contributing guide](./contributing.rst) for more information.
 :::
 
 ::::
diff --git a/docs/source/index.md b/docs/source/index.md
index b429b607..6399c5ed 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -56,4 +56,5 @@ The following projects cover related needs and served as inspiration for this pr
 getting_started
 auto_examples/index
 api_index
+contributing
 ```

From 3572b4701fb11e77c094fa00b6fe32ed393e6759 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Wed, 16 Aug 2023 16:11:24 +0100
Subject: [PATCH 60/79] expanded getting started guide

---
 docs/source/getting_started.md     | 108 ++++++++++++++++++++++++++---
 examples/load_and_explore_poses.py |   2 +-
 2 files changed, 101 insertions(+), 9 deletions(-)

diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md
index 2c7c3a17..5b3d7ae9 100644
--- a/docs/source/getting_started.md
+++ b/docs/source/getting_started.md
@@ -2,14 +2,14 @@
 
 ## Installation
 
-We recommend you use install `movement` inside a [conda](https://docs.conda.io/en/latest/)
+We recommend you install movement inside a [conda](https://docs.conda.io/en/latest/)
 or [mamba](https://mamba.readthedocs.io/en/latest/index.html) environment.
 In the following we assume you have `conda` installed,
 but the same commands will also work with `mamba`/`micromamba`.
 
 
 First, create and activate an environment.
-You can call your environment whatever you like, we've used `movement-env`.
+You can call your environment whatever you like, we've used "movement-env".
 
 ```sh
 conda create -n movement-env -c conda-forge python=3.10 pytables
 conda activate movement-env
 ```
@@ -51,13 +51,11 @@ Please see the [contributing guide](./contributing.rst) for more information.
 
 ::::
 
-
-## Usage
-
-### Loading data
-You can load predicted pose tracks for the pose estimation software packages
+## Loading data
+You can load predicted pose tracks from the pose estimation software packages
 [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/).
 
-First import the `load_poses` function from the `movement.io` module:
+First import the `movement.io.load_poses` module:
 
 ```python
 from movement.io import load_poses
 ```
+
+## Working with movement datasets
+
+Loaded pose estimation data are represented in movement as
+[`xarray.Dataset`](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html) objects.
+
+You can view information about the loaded dataset by printing it:
+```python
+ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30)
+print(ds)
+```
+If you are working in a Jupyter notebook, you can also view an interactive
+representation of the dataset by simply typing its name - e.g. `ds` - in a cell.
+
+### Dataset structure
+
+The movement `xarray.Dataset` has the following dimensions:
+- `time`: the number of frames in the video
+- `individuals`: the number of individuals in the video
+- `keypoints`: the number of keypoints in the skeleton
+- `space`: the number of spatial dimensions, either 2 or 3
+
+Appropriate coordinate labels are assigned to each dimension:
+list of unique names (str) for `individuals` and `keypoints`,
+['x','y',('z')] for `space`. The coordinates of the `time` dimension are
+in seconds if `fps` is provided, otherwise they are in frame numbers.
+
+The dataset contains two data variables stored as
+[`xarray.DataArray`](https://docs.xarray.dev/en/latest/generated/xarray.DataArray.html#xarray.DataArray) objects:
+- `pose_tracks`: with shape (`time`, `individuals`, `keypoints`, `space`)
+- `confidence`: with shape (`time`, `individuals`, `keypoints`)
+
+You can think of a `DataArray` as a `numpy.ndarray` with `pandas`-style
+indexing and labelling. To learn more about `xarray` data structures, see the
+relevant [documentation](https://docs.xarray.dev/en/latest/user-guide/data-structures.html).
+
+The dataset may also contain the following attributes as metadata:
+- `fps`: the number of frames per second in the video
+- `time_unit`: the unit of the `time` coordinates, frames or seconds
+- `source_software`: the software from which the pose tracks were loaded
+- `source_file`: the file from which the pose tracks were loaded
+
+### Indexing and selection
+You can access the data variables and attributes of the dataset as follows:
+```python
+pose_tracks = ds.pose_tracks  # ds['pose_tracks'] also works
+confidence = ds.confidence
+
+fps = ds.fps  # ds.attrs['fps'] also works
+```
+
+You can select subsets of the data using the `sel` method:
+```python
+# select the first 100 seconds of data
+ds_sel = ds.sel(time=slice(0, 100))
+
+# select specific individuals or keypoints
+ds_sel = ds.sel(individuals=["mouse1", "mouse2"])
+ds_sel = ds.sel(keypoints="snout")
+
+# combine selections
+ds_sel = ds.sel(time=slice(0, 100), individuals="mouse1", keypoints="snout")
+```
+All of the above selections can also be applied to the data variables,
+resulting in a `DataArray` rather than a `Dataset`:
+
+```python
+pose_tracks = ds.pose_tracks.sel(individuals="mouse1", keypoints="snout")
+```
+You may also use all the other powerful [indexing and selection](https://docs.xarray.dev/en/latest/user-guide/indexing.html) methods provided by `xarray`.
+
+### Plotting
+
+You can also use the built-in [`xarray` plotting methods](https://docs.xarray.dev/en/latest/user-guide/plotting.html)
+to visualise the data. Check out the [Load and explore pose tracks](./auto_examples/load_and_explore_poses.rst)
+example for inspiration.
+
+## Saving data
+You can save movement datasets to disk in a variety of formats.
+Currently, only saving to DeepLabCut-style files is supported.
+
+```python
+from movement.io import save_poses
+
+save_poses.to_dlc_file(ds, "/path/to/file.h5")  # preferred
+save_poses.to_dlc_file(ds, "/path/to/file.csv")
+```
+
+Instead of saving to file directly, you can also convert the dataset to a
+DeepLabCut-style `pandas.DataFrame` first:
+```python
+df = save_poses.to_dlc_df(ds)
+```
+and then save it to file using any `pandas` method, e.g. `to_hdf` or `to_csv`.
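The DataFrame route described at the end of the saving section above can be spelled out as a short end-to-end sketch (file paths are hypothetical):

```python
from movement.io import load_poses, save_poses

ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30)

df = save_poses.to_dlc_df(ds)   # DLC-style DataFrame with MultiIndex columns
df.to_csv("/path/to/file.csv")  # any pandas writer can take it from here
```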
diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py index ef31b15e..17fb8630 100644 --- a/examples/load_and_explore_poses.py +++ b/examples/load_and_explore_poses.py @@ -40,7 +40,7 @@ # The loaded dataset contains two data variables: # ``pose_tracks`` and ``confidence``` # To get the pose tracks: -pose_tracks = ds["pose_tracks"] +pose_tracks = ds.pose_tracks # %% # Select and plot data with xarray From e84883366c868f66aff523a760dbf5cc6263aa40 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 16 Aug 2023 17:50:51 +0100 Subject: [PATCH 61/79] added info about fetching sample data to the getting started section --- docs/source/getting_started.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index 5b3d7ae9..1d53cb25 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -97,6 +97,35 @@ ds = load_poses.from_dlc_df(df, fps=30) :::: +You can also try movement out on some sample data included in the package. +:::{dropdown} Fetching sample data + +You can view the available sample data files with: + +```python +from movement import datasets + +file_names = datasets.find_pose_data() +print(file_names) +``` +This will print a list of file names containing sample pose data. +The files are prefixed with the name of the pose estimation software package, +either "DLC" or "SLEAP". + +To get the path to one of the sample files, +you can use the `fetch_pose_data_path` function: + +```python +file_path = datasets.fetch_pose_data_path("DLC_single-wasp.predictions.h5") +``` +The first time you call this function, it will download the corresponding file +to your local machine and save it in the `~/.movement/data` directory. On +subsequent calls, it will simply return the path to that local file. + +You can feed the path to the `from_dlc_file` or `from_sleap_file` functions +and load the data, as shown above. +::: + ## Working with movement datasets Loaded pose estimation data are represented in movement as From 64b4c3de954d96c1c2a679867b5fcfc35403aa84 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Wed, 16 Aug 2023 18:12:31 +0100 Subject: [PATCH 62/79] some fancier formatting for the contributing guide in docs --- CONTRIBUTING.md | 2 -- docs/source/contributing.rst | 11 +++++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9cc0031a..1d803b5e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,7 +1,5 @@ # How to Contribute -## Introduction - **Contributors to movement are absolutely encouraged**, whether to fix a bug, develop a new feature, or improve the documentation. If you're unsure about any part of the contributing process, please get in touch. diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 7a661a30..498202d4 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -1,2 +1,13 @@ .. include:: ../../CONTRIBUTING.md :parser: myst_parser.sphinx_ + :end-before: **Contributors + +.. important:: + .. include:: ../../CONTRIBUTING.md + :parser: myst_parser.sphinx_ + :start-after: How to Contribute + :end-before: ## Contributing code + +.. include:: ../../CONTRIBUTING.md + :parser: myst_parser.sphinx_ + :start-after: from the discussion. 
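The sample-data workflow documented above (list the available file names, then fetch a local path) is driven by the `POSE_DATA` download manager, which later patches in this series describe as a registry of stored file names and known hashes, queried via `POSE_DATA.registry.keys()`. Below is a minimal sketch of that pattern, assuming it is backed by the [pooch](https://www.fatiando.org/pooch/) library; the base URL and hash are placeholders, not movement's real configuration:

```python
import pooch

# a download manager akin to movement's POSE_DATA (all values are placeholders)
POSE_DATA = pooch.create(
    path=pooch.os_cache("movement"),  # local cache directory
    base_url="https://example.com/poses/",  # placeholder remote store
    registry={
        "DLC_two-mice.predictions.csv": "sha256:" + "0" * 64,  # file -> known hash
    },
)

print(list(POSE_DATA.registry.keys()))  # what find_pose_data() lists
local_path = POSE_DATA.fetch("DLC_two-mice.predictions.csv")  # downloads once, then cached
```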
From f56b740977e3ced7e4bfb75e3461d3759833fffe Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Thu, 17 Aug 2023 12:27:15 +0100
Subject: [PATCH 63/79] add style to dropdown

---
 docs/source/getting_started.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md
index 1d53cb25..021e9627 100644
--- a/docs/source/getting_started.md
+++ b/docs/source/getting_started.md
@@ -99,6 +99,8 @@ ds = load_poses.from_dlc_df(df, fps=30)
 
 You can also try movement out on some sample data included in the package.
 :::{dropdown} Fetching sample data
+:color: primary
+:icon: unlock
 
 You can view the available sample data files with:
 

From cbade144bbc3413eac8e16308ede43a9b1ddf0f5 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Thu, 17 Aug 2023 12:55:19 +0100
Subject: [PATCH 64/79] fixed docstrings and API reference

---
 movement/datasets.py      |  9 +++++++--
 movement/io/load_poses.py | 21 ++++++++++++++-------
 movement/io/save_poses.py | 14 ++++++++++++--
 movement/logging.py       |  6 +++---
 4 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/movement/datasets.py b/movement/datasets.py
index 48512a30..8488e990 100644
--- a/movement/datasets.py
+++ b/movement/datasets.py
@@ -42,12 +42,17 @@
 
 
 def find_pose_data() -> List[str]:
-    """Find available sample pose data."""
+    """Find available sample pose data in the *movement* data repository.
+
+    Returns
+    -------
+    filenames : list of str
+        List of filenames for available pose data."""
     return list(POSE_DATA.registry.keys())
 
 
 def fetch_pose_data_path(filename: str) -> Path:
-    """Fetch sample pose data from the remote repository.
+    """Fetch sample pose data from the *movement* data repository.
 
     The data are downloaded to the user's local machine the first time they
     are used and are stored in a local cache directory. The function returns
diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index a414fbe6..20d229dc 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -21,7 +21,7 @@
 
 
 def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset:
-    """Create an xarray.Dataset from a DLC_style pandas DataFrame.
+    """Create an xarray.Dataset from a DeepLabCut-style pandas DataFrame.
 
     Parameters
     ----------
@@ -45,6 +45,10 @@ def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset:
         The "coords" level contains the spatial coordinates "x", "y",
         as well as "likelihood" (point-wise confidence scores).
         The row index corresponds to the frame number.
+
+    See Also
+    --------
+    movement.io.load_poses.from_dlc_file : Load pose tracks directly from file.
     """
 
     # read names of individuals and keypoints from the DataFrame
@@ -78,8 +82,7 @@ def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset:
 def from_sleap_file(
     file_path: Union[Path, str], fps: Optional[float] = None
 ) -> xr.Dataset:
-    """Load pose tracking data from a SLEAP labels or analysis file
-    into an xarray Dataset.
+    """Load pose tracking data from a SLEAP file into an xarray Dataset.
 
     Parameters
     ----------
@@ -103,15 +106,15 @@ def from_sleap_file(
     from the ".slp" file, using either the command line tool `sleap-convert`
     (with the "--format analysis" option enabled) or the SLEAP GUI (Choose
     "Export Analysis HDF5…" from the "File" menu) [1]_. This is the
-    preferred format for loading pose tracks from SLEAP into `movement`.
+    preferred format for loading pose tracks from SLEAP into *movement*.
-    You can also try directly loading te ".slp" file, but this feature is
+    You can also try directly loading the ".slp" file, but this feature is
     experimental and does not work in all cases. If the ".slp" file contains
     both user-labeled and predicted instances, only the predicted ones will be
     loaded. If there are multiple videos in the file, only the first one will
     be used.
 
-    `movement` expects the tracks to be assigned and proofread before loading
+    *movement* expects the tracks to be assigned and proofread before loading
     them, meaning each track is interpreted as a single individual/animal.
     Follow the SLEAP guide for tracking and proofreading [2]_.
 
@@ -160,7 +163,7 @@ def from_dlc_file(
     Parameters
     ----------
     file_path : pathlib.Path or str
-        Path to the file containing the DLC poses, either in ".h5"
+        Path to the file containing the DLC predicted poses, either in ".h5"
         or ".csv" format.
     fps : float, optional
         The number of frames per second in the video. If None (default),
@@ -171,6 +174,10 @@
     xarray.Dataset
         Dataset containing the pose tracks, confidence scores, and metadata.
 
+    See Also
+    --------
+    movement.io.load_poses.from_dlc_df : Load pose tracks from a DataFrame.
+
     Examples
     --------
     >>> from movement.io import load_poses
diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py
index 21a47cfa..e1be3f70 100644
--- a/movement/io/save_poses.py
+++ b/movement/io/save_poses.py
@@ -31,6 +31,11 @@ def to_dlc_df(ds: xr.Dataset) -> pd.DataFrame:
         one individual present). Regardless of the provenance of the
         point-wise confidence scores, they will be referred to as
         "likelihood", and stored in the "coords" level (as DeepLabCut expects).
+
+    See Also
+    --------
+    to_dlc_file : Save the xarray dataset containing pose tracks directly
+        to a DeepLabCut-style ".h5" or ".csv" file.
     """
 
     if not isinstance(ds, xr.Dataset):
@@ -72,8 +77,8 @@ def to_dlc_df(ds: xr.Dataset) -> pd.DataFrame:
 
 
 def to_dlc_file(ds: xr.Dataset, file_path: Union[str, Path]) -> None:
-    """Export the xarray dataset containing pose tracks to a
-    DeepLabCut-style .h5 or .csv file.
+    """Save the xarray dataset containing pose tracks to a
+    DeepLabCut-style ".h5" or ".csv" file.
 
     Parameters
     ----------
@@ -82,6 +87,11 @@ def to_dlc_file(ds: xr.Dataset, file_path: Union[str, Path]) -> None:
     file_path : pathlib Path or str
         Path to the file to save the DLC poses to. The file extension
         must be either ".h5" (recommended) or ".csv".
+
+    See Also
+    --------
+    to_dlc_df : Convert an xarray dataset containing pose tracks into a
+        DeepLabCut-style pandas DataFrame with multi-index columns.
     """
 
     try:
diff --git a/movement/logging.py b/movement/logging.py
index d202e5fb..cbe5593d 100644
--- a/movement/logging.py
+++ b/movement/logging.py
@@ -22,7 +22,7 @@ def configure_logging(
         The logging level to use. Defaults to logging.INFO.
     logger_name : str, optional
         The name of the logger to configure.
-        Defaults to 'movement'.
+        Defaults to "movement".
     log_directory : pathlib.Path, optional
         The directory to store the log file in. Defaults to ~/.movement.
         A different directory can be specified,
@@ -71,7 +71,7 @@ def log_and_raise_error(error, message: str, logger_name: str = "movement"):
     message : str
         The error message.
     logger_name : str, optional
-        The name of the logger to use. Defaults to 'movement'.
+        The name of the logger to use. Defaults to "movement".
""" logger = logging.getLogger(logger_name) logger.error(message) @@ -86,7 +86,7 @@ def log_warning(message: str, logger_name: str = "movement"): message : str The warning message. logger_name : str, optional - The name of the logger to use. Defaults to 'movement'. + The name of the logger to use. Defaults to "movement". """ logger = logging.getLogger(logger_name) logger.warning(message) From 5c6a8375af76e6eae9bf7bb1bb4dd21636d8c47b Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 17 Aug 2023 15:10:16 +0100 Subject: [PATCH 65/79] fixed issue with duplicate source files generated by sphinx-gallery --- docs/source/conf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index c42ded02..31c53e8a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -83,6 +83,10 @@ # to ensure that include files (partial pages) aren't built, exclude them # https://github.com/sphinx-doc/sphinx/issues/1965#issuecomment-124732907 "**/includes/**", + # exclude .py and .ipynb files in auto_examples generated by sphinx-gallery + # this is to prevent sphinx from complaining about duplicate source files + "auto_examples/*.ipynb", + "auto_examples/*.py", ] # Configure Sphinx gallery From 95827b4f03bf525c160df2d7193b0cc55b2e0356 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 17 Aug 2023 15:36:39 +0100 Subject: [PATCH 66/79] limit sphinx version to <7.2 --- docs/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index ae1bd4b5..cb6d2bac 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,7 +5,7 @@ myst-parser nbsphinx pydata-sphinx-theme setuptools-scm -sphinx +sphinx<7.2 sphinx-autodoc-typehints sphinx-design sphinx-gallery From 72608df629b1f5e100cae397c1f233e616093b08 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 17 Aug 2023 15:52:45 +0100 Subject: [PATCH 67/79] replaced type with isinstance --- movement/io/validators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/movement/io/validators.py b/movement/io/validators.py index 51f7e001..c498e10a 100644 --- a/movement/io/validators.py +++ b/movement/io/validators.py @@ -197,7 +197,7 @@ def csv_file_contains_expected_levels(self, attribute, value): def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]: """Try to coerce the value into a list of strings. Otherwise, raise a ValueError.""" - if type(value) is str: + if isinstance(value, str): log_warning( f"Invalid value ({value}). Expected a list of strings. " "Converting to a list of length 1." @@ -213,7 +213,7 @@ def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]: def _ensure_type_ndarray(value: Any) -> None: """Raise ValueError the value is a not numpy array.""" - if type(value) is not np.ndarray: + if not isinstance(value, np.ndarray): raise log_and_raise_error( ValueError, f"Expected a numpy array, but got {type(value)}." 
         )

From 647af1f352d20ec1e071f71e878978dc2c0c9b68 Mon Sep 17 00:00:00 2001
From: niksirbi
Date: Thu, 17 Aug 2023 16:55:23 +0100
Subject: [PATCH 68/79] added ValidPoseTracks to API reference

---
 docs/source/api_index.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst
index 17fc54d1..07c3d1b2 100644
--- a/docs/source/api_index.rst
+++ b/docs/source/api_index.rst
@@ -26,6 +26,7 @@ Input/Output
     ValidFile
     ValidHDF5
     ValidPosesCSV
+    ValidPoseTracks
 
 Datasets
 --------

From 0eac6584c4a9a7614d7b285be49b10f01d700e1c Mon Sep 17 00:00:00 2001
From: Niko Sirmpilatze
Date: Mon, 11 Sep 2023 15:37:39 +0100
Subject: [PATCH 69/79] Fix typos from code review

Co-authored-by: Chang Huan Lo
---
 CONTRIBUTING.md                    | 8 ++++----
 README.md                          | 2 +-
 docs/source/getting_started.md     | 6 +++---
 examples/load_and_explore_poses.py | 2 +-
 movement/io/validators.py          | 2 +-
 movement/logging.py                | 2 +-
 6 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 1d803b5e..88611193 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -53,7 +53,7 @@ We recommend, and adhere, to the following conventions:
 - One approval of a PR (by a repo owner) is enough for it to be merged.
 - Unless someone approves the PR with optional comments, the PR is immediately merged by the approving reviewer.
 - Ask for a review from someone specific if you think they would be a particularly suited reviewer.
-- PRs are preferenitally merged via the ["squash and merge"](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-commits) option, to keep a clean commit history on the _main_ branch.
+- PRs are preferably merged via the ["squash and merge"](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-commits) option, to keep a clean commit history on the _main_ branch.
 
 A typical PR workflow would be:
 * Create a new branch, make your changes, and stage them.
@@ -65,7 +65,7 @@ A typical PR workflow would be:
 * Push your changes to GitHub and open a draft pull request, with a meaningful title and a thorough description of the changes.
 * If all checks (e.g. linting, type checking, testing) run successfully, you may mark the pull request as ready for review.
 * Respond to review comments and implement any requested changes.
-* Sucess 🎉 !! Your PR will be (squash-)merged into the _main_ branch.
+* Success 🎉 !! Your PR will be (squash-)merged into the _main_ branch.
 
 ## Development guidelines
 
@@ -153,7 +153,7 @@ Other `.rst` or `.md` files are linked to the homepage via the `toctree` directive.
 We use [Sphinx](https://www.sphinx-doc.org/en/master/) and the
 [PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html)
-to build the source files into html output.
+to build the source files into HTML output.
 This is handled by a GitHub actions workflow (`.github/workflows/docs_build_and_deploy.yml`).
 The build job is triggered on each PR, ensuring that the documentation build is not broken by new changes.
 The deployment job is only triggered whenever a tag is pushed to the _main_ branch,
 ensuring that the documentation is published in sync with each PyPI release.
@@ -164,7 +164,7 @@ To edit the documentation, first clone the repository,
 and install movement in a [development environment](#creating-a-development-environment).
-Now open a new branch, edit the documentation source files (`.md` or `.rst` in the `docs` folder), +Now create a new branch, edit the documentation source files (`.md` or `.rst` in the `docs` folder), and commit your changes. Submit your documentation changes via a pull request, following the [same guidelines as for code changes](#pull-requests). Make sure that the header levels in your `.md` or `.rst` files are incremented diff --git a/README.md b/README.md index d824db1b..ec7256c0 100644 --- a/README.md +++ b/README.md @@ -35,4 +35,4 @@ The following projects cover related needs and served as inspiration for this pr ⚖️ [BSD 3-Clause](./LICENSE) ## Template -This package layout and configuration (including pre-commit hooks and GitHub actions) have been copied from the [python-cookiecutter]((https://github.com/neuroinformatics-unit/python-cookiecutter)) template. +This package layout and configuration (including pre-commit hooks and GitHub actions) have been copied from the [python-cookiecutter](https://github.com/neuroinformatics-unit/python-cookiecutter) template. diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index 021e9627..5958f0c2 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -138,7 +138,7 @@ You can view information about the loaded dataset by printing it: ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30) print(ds) ``` -If you working in a Jupyter notebook, you can also view an interactive +If you are working in a Jupyter notebook, you can also view an interactive representation of the dataset by simply typing its name - e.g. `ds` - in a cell. ### Dataset structure @@ -172,10 +172,10 @@ The dataset may also contain the following attributes as metadata: ### Indexing and selection You can access the data variables and attributes of the dataset as follows: ```python -pose_tracks = ds.pose_tracks # ds['pose_tracks'] also works +pose_tracks = ds.pose_tracks # ds["pose_tracks"] also works confidence = ds.confidence -fps = ds.fps # ds.attrs['fps'] also works +fps = ds.fps # ds.attrs["fps"] also works ``` You can select subsets of the data using the `sel` method: diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py index 17fb8630..1a47647c 100644 --- a/examples/load_and_explore_poses.py +++ b/examples/load_and_explore_poses.py @@ -64,7 +64,7 @@ da = pose_tracks.sel(keypoints="centroid") da.plot.line(x="time", row="individuals", aspect=2, size=2.5) -# %%s +# %% # Trajectory plots # ---------------- # We are not limited to ``xarray``'s built-in plots. diff --git a/movement/io/validators.py b/movement/io/validators.py index c498e10a..c07b2152 100644 --- a/movement/io/validators.py +++ b/movement/io/validators.py @@ -80,7 +80,7 @@ def file_has_access_permissions(self, attribute, value): raise log_and_raise_error( PermissionError, f"Unable to read file: {value}. " - "Make sure that you have read permissions for it.", + "Make sure that you have read permissions.", ) if ("w" in self.expected_permission) and (not parent_is_writeable): raise log_and_raise_error( diff --git a/movement/logging.py b/movement/logging.py index cbe5593d..4888322d 100644 --- a/movement/logging.py +++ b/movement/logging.py @@ -62,7 +62,7 @@ def configure_logging( def log_and_raise_error(error, message: str, logger_name: str = "movement"): - """Log an error message and raise a ValueError. + """Log an error message and raise an Error. 
Parameters ---------- From f7c2726e41433bdeb56377f2782d2f409a7e59d9 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 14 Sep 2023 11:14:04 +0100 Subject: [PATCH 70/79] added check-manifest as dev dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index dfdb2e82..6674e52b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dev = [ "setuptools_scm", "pandas-stubs", "types-attrs", + "check-manifest", ] [build-system] From 90778a0bcaffae00f948de00576ec23d649df19e Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 14 Sep 2023 11:17:14 +0100 Subject: [PATCH 71/79] removed duplicate word in manifest --- MANIFEST.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index cad8162a..ff091745 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,5 @@ include LICENSE -include include *.md +include *.md exclude .pre-commit-config.yaml exclude .cruft.json From 79acf0d2380003e117089576ceaafbed53a1fc41 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 14 Sep 2023 11:30:35 +0100 Subject: [PATCH 72/79] edit code examples in getting started guide to avoid key errors --- docs/source/getting_started.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index 5958f0c2..017ade87 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -98,6 +98,7 @@ ds = load_poses.from_dlc_df(df, fps=30) :::: You can also try movement out on some sample data included in the package. + :::{dropdown} Fetching sample data :color: primary :icon: unlock @@ -118,7 +119,7 @@ To get the path to one of the sample files, you can use the `fetch_pose_data_path` function: ```python -file_path = datasets.fetch_pose_data_path("DLC_single-wasp.predictions.h5") +file_path = datasets.fetch_pose_data_path("DLC_two-mice.predictions.csv") ``` The first time you call this function, it will download the corresponding file to your local machine and save it in the `~/.movement/data` directory. On @@ -184,17 +185,17 @@ You can select subsets of the data using the `sel` method: ds_sel = ds.sel(time=slice(0, 100)) # select specific individuals or keypoints -ds_sel = ds.sel(individuals=["mouse1", "mouse2"]) +ds_sel = ds.sel(individuals=["individual1", "individual2"]) ds_sel = ds.sel(keypoints="snout") # combine selections -ds_sel = ds.sel(time=slice(0, 100), individuals"mouse1", keypoints="snout") +ds_sel = ds.sel(time=slice(0, 100), individuals=["individual1", "individual2"], keypoints="snout") ``` All of the above selections can also be applied to the data variables, resulting in a `DataArray` rather than a `Dataset`: ```python -pose_tracks = ds.pose_tracks.sel(individuals="mouse1", keypoints="snout") +pose_tracks = ds.pose_tracks.sel(individuals="individual1", keypoints="snout") ``` You may also use all the other powerful [indexing and selection](https://docs.xarray.dev/en/latest/user-guide/indexing.html) methods provided by `xarray`. 
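The selection examples corrected in the patch above pair naturally with xarray's built-in plotting, which the guide links to. A minimal sketch, assuming a dataset `ds` loaded as shown in the getting-started guide (the individual and keypoint labels are placeholders):

```python
import matplotlib.pyplot as plt

# select one individual's snout track; the result has dims (time, space)
da = ds.pose_tracks.sel(individuals="individual1", keypoints="snout")

# draw one line per spatial dimension (x and y) against time
da.plot.line(x="time", hue="space")
plt.show()
```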
From 636f98f304bbfe10392a1d5eed68951208da6c3a Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 14 Sep 2023 11:32:47 +0100 Subject: [PATCH 73/79] fix fps value in example --- examples/load_and_explore_poses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py index 1a47647c..505ee87a 100644 --- a/examples/load_and_explore_poses.py +++ b/examples/load_and_explore_poses.py @@ -33,7 +33,7 @@ # Load the dataset # ---------------- -ds = load_poses.from_sleap_file(file_path, fps=60) +ds = load_poses.from_sleap_file(file_path, fps=50) print(ds) # %% From 49f7d33aa0d1415eed7e763dba3100de9481b561 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 14 Sep 2023 11:47:14 +0100 Subject: [PATCH 74/79] renamed find_pose_data to list_pose_data --- CONTRIBUTING.md | 2 +- docs/source/api_index.rst | 2 +- docs/source/getting_started.md | 2 +- examples/load_and_explore_poses.py | 2 +- movement/datasets.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 88611193..50581c55 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -252,7 +252,7 @@ The relevant functionality is implemented in the `movement.datasets.py` module. The most important parts of this module are: 1. The `POSE_DATA` download manager object, which contains a list of stored files and their known hashes. -2. The `find_pose_data()` function, which returns a list of the available files in the data repository. +2. The `list_pose_data()` function, which returns a list of the available files in the data repository. 3. The `fetch_pose_data_path()` function, which downloads a file (if not already cached locally) and returns the local path to it. By default, the downloaded files are stored in the `~/.movement/data` folder. diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst index 07c3d1b2..ae000ed9 100644 --- a/docs/source/api_index.rst +++ b/docs/source/api_index.rst @@ -34,7 +34,7 @@ Datasets .. autosummary:: :toctree: auto_api - find_pose_data + list_pose_data fetch_pose_data_path Logging diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index 017ade87..29bb21c0 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -108,7 +108,7 @@ You can view the available sample data files with: ```python from movement import datasets -file_names = datasets.find_pose_data() +file_names = datasets.list_pose_data() print(file_names) ``` This will print a list of file names containing sample pose data. diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py index 505ee87a..8562b532 100644 --- a/examples/load_and_explore_poses.py +++ b/examples/load_and_explore_poses.py @@ -18,7 +18,7 @@ # ------------------------ # Print a list of available datasets: -for file_name in datasets.find_pose_data(): +for file_name in datasets.list_pose_data(): print(file_name) # %% diff --git a/movement/datasets.py b/movement/datasets.py index 8488e990..30760125 100644 --- a/movement/datasets.py +++ b/movement/datasets.py @@ -41,7 +41,7 @@ ) -def find_pose_data() -> List[str]: +def list_pose_data() -> List[str]: """Find available sample pose data in the *movement* data repository. 
Returns From 35c80ffe2edc23fd459a0b594aaef8f1173a4b35 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 14 Sep 2023 13:18:35 +0100 Subject: [PATCH 75/79] use assert_allclose from numpy.testing --- tests/test_unit/test_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index a3803cc4..ea795084 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -270,7 +270,7 @@ def test_load_and_save_to_dlc_df(self, dlc_style_df): converting back to a DataFrame returns the same data values.""" ds = load_poses.from_dlc_df(dlc_style_df) df = save_poses.to_dlc_df(ds) - assert np.allclose(df.values, dlc_style_df.values) + np.testing.assert_allclose(df.values, dlc_style_df.values) def test_save_and_load_dlc_file(self, valid_pose_dataset, tmp_path): """Test that saving pose tracks to DLC .h5 and .csv files and then @@ -315,7 +315,7 @@ def test_fps_and_time_coords(self, sleap_file_h5_multi, fps): else: assert ds.fps == fps assert ds.time_unit == "seconds" - np.allclose( + np.testing.assert_allclose( ds.coords["time"].data, np.arange(ds.dims["time"], dtype=int) / ds.attrs["fps"], ) From 9d5af00d66a434bfda03a3ffcf1fcb2e76ef7070 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 14 Sep 2023 14:04:44 +0100 Subject: [PATCH 76/79] modify function fo logging errors --- docs/source/api_index.rst | 2 +- movement/io/load_poses.py | 6 ++---- movement/io/validators.py | 34 ++++++++++++++++----------------- movement/logging.py | 13 +++++++++---- tests/test_unit/test_logging.py | 14 +++++++------- 5 files changed, 36 insertions(+), 33 deletions(-) diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst index ae000ed9..32b9a7ce 100644 --- a/docs/source/api_index.rst +++ b/docs/source/api_index.rst @@ -44,5 +44,5 @@ Logging :toctree: auto_api configure_logging - log_and_raise_error + log_error log_warning diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 20d229dc..b181780a 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -15,7 +15,7 @@ ValidPosesCSV, ValidPoseTracks, ) -from movement.logging import log_and_raise_error +from movement.logging import log_error logger = logging.getLogger(__name__) @@ -350,9 +350,7 @@ def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame: # pd.read_hdf does not always return a DataFrame df = pd.DataFrame(pd.read_hdf(file.path, key="df_with_missing")) except Exception as error: - log_and_raise_error( - error, f"Could not load a dataframe from {file.path}." 
-        )
+        raise log_error(error, f"Could not load a dataframe from {file.path}.")
     return df

diff --git a/movement/io/validators.py b/movement/io/validators.py
index c07b2152..aaaf49d6 100644
--- a/movement/io/validators.py
+++ b/movement/io/validators.py
@@ -6,7 +6,7 @@
 import numpy as np
 from attrs import converters, define, field, validators
 
-from movement.logging import log_and_raise_error, log_warning
+from movement.logging import log_error, log_warning
 
 
 @define
@@ -50,7 +50,7 @@ class ValidFile:
     def path_is_not_dir(self, attribute, value):
         """Ensures that the path does not point to a directory."""
         if value.is_dir():
-            log_and_raise_error(
+            raise log_error(
                 IsADirectoryError,
                 f"Expected a file path but got a directory: {value}.",
             )
@@ -61,12 +61,12 @@ def file_exists_when_expected(self, attribute, value):
         usage (read and/or write)."""
         if "r" in self.expected_permission:
             if not value.exists():
-                raise log_and_raise_error(
+                raise log_error(
                     FileNotFoundError, f"File {value} does not exist."
                 )
         else:  # expected_permission is 'w'
             if value.exists():
-                raise log_and_raise_error(
+                raise log_error(
                     FileExistsError, f"File {value} already exists."
                 )
 
@@ -77,13 +77,13 @@ def file_has_access_permissions(self, attribute, value):
         file_is_readable = os.access(value, os.R_OK)
         parent_is_writeable = os.access(value.parent, os.W_OK)
         if ("r" in self.expected_permission) and (not file_is_readable):
-            raise log_and_raise_error(
+            raise log_error(
                 PermissionError,
                 f"Unable to read file: {value}. "
                 "Make sure that you have read permissions.",
             )
         if ("w" in self.expected_permission) and (not parent_is_writeable):
-            raise log_and_raise_error(
+            raise log_error(
                 PermissionError,
                 f"Unable to write to file: {value}. "
                 "Make sure that you have write permissions.",
             )
@@ -94,7 +94,7 @@ def file_has_expected_suffix(self, attribute, value):
         """Ensures that the file has one of the expected suffix(es)."""
         if self.expected_suffix:  # list is not empty
             if value.suffix not in self.expected_suffix:
-                raise log_and_raise_error(
+                raise log_error(
                     ValueError,
                     f"Expected file with suffix(es) {self.expected_suffix} "
                     f"but got suffix {value.suffix} instead.",
                 )
@@ -130,7 +130,7 @@ def file_is_h5(self, attribute, value):
             with h5py.File(value, "r") as f:
                 f.close()
         except Exception as e:
-            raise log_and_raise_error(
+            raise log_error(
                 ValueError,
                 f"File {value} does not seem to be in valid " "HDF5 format.",
             ) from e
@@ -142,7 +142,7 @@ def file_contains_expected_datasets(self, attribute, value):
         with h5py.File(value, "r") as f:
             diff = set(self.expected_datasets).difference(set(f.keys()))
             if len(diff) > 0:
-                raise log_and_raise_error(
+                raise log_error(
                     ValueError,
                     f"Could not find the expected dataset(s) {diff} "
                     f"in file: {value}. ",
                 )
@@ -186,7 +186,7 @@ def csv_file_contains_expected_levels(self, attribute, value):
             level in header_rows_start for level in expected_levels
         ]
         if not all(level_in_header_row_starts):
-            raise log_and_raise_error(
+            raise log_error(
                 ValueError,
                 f"The header rows of the CSV file {value} do not "
                 "contain all expected index column levels "
                 f"{expected_levels}.",
@@ -206,7 +206,7 @@ def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]:
     elif isinstance(value, Iterable):
         return [str(item) for item in value]
     else:
-        log_and_raise_error(
+        raise log_error(
             ValueError, f"Invalid value ({value}). Expected a list of strings."
         )
 
 
@@ -214,7 +214,7 @@ def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]:
 
 def _ensure_type_ndarray(value: Any) -> None:
     """Raise a ValueError if the value is not a numpy array."""
     if not isinstance(value, np.ndarray):
-        raise log_and_raise_error(
+        raise log_error(
             ValueError, f"Expected a numpy array, but got {type(value)}."
         )
 
@@ -280,13 +280,13 @@ class ValidPoseTracks:
     def _validate_tracks_array(self, attribute, value):
         _ensure_type_ndarray(value)
         if value.ndim != 4:
-            log_and_raise_error(
+            raise log_error(
                 ValueError,
                 f"Expected `{attribute}` to have 4 dimensions, "
                 f"but got {value.ndim}.",
             )
         if value.shape[-1] not in [2, 3]:
-            log_and_raise_error(
+            raise log_error(
                 ValueError,
                 f"Expected `{attribute}` to have 2 or 3 spatial dimensions, "
                 f"but got {value.shape[-1]}.",
@@ -297,7 +297,7 @@ def _validate_scores_array(self, attribute, value):
         if value is not None:
             _ensure_type_ndarray(value)
             if value.shape != self.tracks_array.shape[:-1]:
-                log_and_raise_error(
+                raise log_error(
                     ValueError,
                     f"Expected `{attribute}` to have shape "
                     f"{self.tracks_array.shape[:-1]}, but got {value.shape}.",
@@ -306,7 +306,7 @@ def _validate_individual_names(self, attribute, value):
         if (value is not None) and (len(value) != self.tracks_array.shape[1]):
-            log_and_raise_error(
+            raise log_error(
                 ValueError,
                 f"Expected {self.tracks_array.shape[1]} `{attribute}`, "
                 f"but got {len(value)}.",
@@ -315,7 +315,7 @@ def _validate_keypoint_names(self, attribute, value):
         if (value is not None) and (len(value) != self.tracks_array.shape[2]):
-            log_and_raise_error(
+            raise log_error(
                 ValueError,
                 f"Expected {self.tracks_array.shape[2]} `{attribute}`, "
                 f"but got {len(value)}.",
diff --git a/movement/logging.py b/movement/logging.py
index 4888322d..994b2d81 100644
--- a/movement/logging.py
+++ b/movement/logging.py
@@ -61,21 +61,26 @@ def configure_logging(
     logger.addHandler(handler)
 
 
-def log_and_raise_error(error, message: str, logger_name: str = "movement"):
-    """Log an error message and raise an Error.
+def log_error(error, message: str, logger_name: str = "movement"):
+    """Log an error message and return the Exception.
 
     Parameters
     ----------
     error : Exception
-        The error to log and raise.
+        The error to log and return.
     message : str
         The error message.
     logger_name : str, optional
         The name of the logger to use. Defaults to "movement".
+
+    Returns
+    -------
+    Exception
+        The error that was passed in.
""" logger = logging.getLogger(logger_name) logger.error(message) - raise error(message) + return error(message) def log_warning(message: str, logger_name: str = "movement"): diff --git a/tests/test_unit/test_logging.py b/tests/test_unit/test_logging.py index 5574c420..bd4d2f6e 100644 --- a/tests/test_unit/test_logging.py +++ b/tests/test_unit/test_logging.py @@ -2,7 +2,7 @@ import pytest -from movement.logging import log_and_raise_error, log_warning +from movement.logging import log_error, log_warning log_messages = { "DEBUG": "This is a debug message", @@ -25,13 +25,13 @@ def test_logfile_contains_message(level, message): assert message in last_line -def test_log_and_raise_error(caplog): - """Check if the log_and_raise_error function - logs the error message and raises a ValueError.""" +def test_log_error(caplog): + """Check if the log_error function + logs the error message and returns an Exception.""" with pytest.raises(ValueError): - log_and_raise_error(ValueError, "This is a test error") - assert caplog.records[0].message == "This is a test error" - assert caplog.records[0].levelname == "ERROR" + raise log_error(ValueError, "This is a test error") + assert caplog.records[0].message == "This is a test error" + assert caplog.records[0].levelname == "ERROR" def test_log_warning(caplog): From 839ec0f214f4dacc372cadc0b0c25e8709407a01 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 18 Sep 2023 11:49:02 +0200 Subject: [PATCH 77/79] improved DLC pose CSV file validator --- movement/io/load_poses.py | 2 +- movement/io/validators.py | 32 +++++++++++++++----------------- tests/test_unit/test_io.py | 6 ------ 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index b181780a..2ed8d4de 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -301,7 +301,7 @@ def _parse_dlc_csv_to_df(file_path: Path) -> pd.DataFrame: DeepLabCut-style DataFrame with multi-index columns. """ - file = ValidPosesCSV(file_path, multianimal=False) + file = ValidPosesCSV(file_path) possible_level_names = ["scorer", "individuals", "bodyparts", "coords"] with open(file.path, "r") as f: diff --git a/movement/io/validators.py b/movement/io/validators.py index aaaf49d6..da5e7bb9 100644 --- a/movement/io/validators.py +++ b/movement/io/validators.py @@ -158,9 +158,6 @@ class ValidPosesCSV: ---------- path : pathlib.Path Path to the CSV file. - multianimal : bool - Whether to ensure that the CSV file contains pose estimation outputs - for multiple animals. Default: False. 
Raises ------ @@ -170,28 +167,29 @@ class ValidPosesCSV: """ path: Path = field(validator=validators.instance_of(Path)) - multianimal: bool = field(default=False, kw_only=True) @path.validator def csv_file_contains_expected_levels(self, attribute, value): """Ensure that the CSV file contains the expected index column levels among its top rows.""" expected_levels = ["scorer", "bodyparts", "coords"] - if self.multianimal: - expected_levels.insert(1, "individuals") with open(value, "r") as f: - header_rows_start = [f.readline().split(",")[0] for _ in range(4)] - level_in_header_row_starts = [ - level in header_rows_start for level in expected_levels - ] - if not all(level_in_header_row_starts): - raise log_error( - ValueError, - f"The header rows of the CSV file {value} do not " - "contain all expected index column levels " - f"{expected_levels}.", - ) + top4_row_starts = [f.readline().split(",")[0] for _ in range(4)] + + if top4_row_starts[3].isdigit(): + # if 4th row starts with a digit, assume single-animal DLC file + expected_levels.append(top4_row_starts[3]) + else: + # otherwise, assume multi-animal DLC file + expected_levels.insert(1, "individuals") + + if top4_row_starts != expected_levels: + raise log_error( + ValueError, + "CSV header rows do not match the known format for " + "DeepLabCut pose estimation output files.", + ) def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]: diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index ea795084..bf656f5a 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -259,12 +259,6 @@ def test_file_validator(self, invalid_files): with pytest.raises(ValueError): ValidPosesCSV(path=file_path) - def test_dlc_poses_csv_validator(self, dlc_file_csv_single): - """Test that the validator for DLC .csv files raises error when - multianimal=True and the 'individuals' level is missing.""" - with pytest.raises(ValueError): - ValidPosesCSV(path=dlc_file_csv_single, multianimal=True) - def test_load_and_save_to_dlc_df(self, dlc_style_df): """Test that loading pose tracks from a DLC-style DataFrame and converting back to a DataFrame returns the same data values.""" From 3164e4462d8dfef9121fa7b138f24cb2ece06c82 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 18 Sep 2023 12:36:22 +0200 Subject: [PATCH 78/79] write resuable list length validator --- movement/io/validators.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/movement/io/validators.py b/movement/io/validators.py index da5e7bb9..c56f5bb7 100644 --- a/movement/io/validators.py +++ b/movement/io/validators.py @@ -228,6 +228,18 @@ def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]: return fps +def _validate_list_length( + attribute: str, value: Optional[List], expected_length: int +): + """Raise a ValueError if the list does not have the expected length.""" + if (value is not None) and (len(value) != expected_length): + raise log_error( + ValueError, + f"Expected `{attribute}` to have length {expected_length}, " + f"but got {len(value)}.", + ) + + @define(kw_only=True) class ValidPoseTracks: """Class for validating pose tracking data imported from a file. 
@@ -303,21 +315,11 @@ def _validate_scores_array(self, attribute, value): @individual_names.validator def _validate_individual_names(self, attribute, value): - if (value is not None) and (len(value) != self.tracks_array.shape[1]): - raise log_error( - ValueError, - f"Expected {self.tracks_array.shape[1]} `{attribute}`, " - f"but got {len(value)}.", - ) + _validate_list_length(attribute, value, self.tracks_array.shape[1]) @keypoint_names.validator def _validate_keypoint_names(self, attribute, value): - if (value is not None) and (len(value) != self.tracks_array.shape[2]): - raise log_error( - ValueError, - f"Expected {self.tracks_array.shape[2]} `{attribute}`, " - f"but got {len(value)}.", - ) + _validate_list_length(attribute, value, self.tracks_array.shape[2]) def __attrs_post_init__(self): """Assign default values to optional attributes (if None)""" From 58cfa33d6cc5a464b051ae36d0990ef82dd32a55 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 18 Sep 2023 16:29:25 +0200 Subject: [PATCH 79/79] reset docs deployment workflow --- .github/workflows/docs_build_and_deploy.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/docs_build_and_deploy.yml b/.github/workflows/docs_build_and_deploy.yml index 8da2c152..a3a4995f 100644 --- a/.github/workflows/docs_build_and_deploy.yml +++ b/.github/workflows/docs_build_and_deploy.yml @@ -9,13 +9,9 @@ on: push: branches: - main - - pose-tracks-io tags: - '*' pull_request: - branches: - - main - - pose-tracks-io workflow_dispatch: jobs: @@ -30,6 +26,7 @@ jobs: needs: build_sphinx_docs permissions: contents: write + if: github.event_name == 'push' && github.ref_type == 'tag' runs-on: ubuntu-latest steps: - uses: neuroinformatics-unit/actions/deploy_sphinx_docs@v2
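
As a closing illustration of the reusable length validator introduced in PATCH 78, here is a self-contained sketch of the attrs pattern it relies on (a simplified toy class, not the actual `ValidPoseTracks` implementation):

```python
from attrs import define, field


def _validate_list_length(attribute, value, expected_length):
    """Raise a ValueError if an optional list has the wrong length."""
    if (value is not None) and (len(value) != expected_length):
        raise ValueError(
            f"Expected `{attribute.name}` to have length {expected_length}, "
            f"but got {len(value)}."
        )


@define
class Toy:
    # stand-in for ValidPoseTracks' individual/keypoint name lists
    keypoint_names: list = field(default=None)

    @keypoint_names.validator
    def _check_keypoint_names(self, attribute, value):
        # delegate to the shared check, as the patch does
        _validate_list_length(attribute, value, expected_length=3)


Toy(keypoint_names=["snout", "centroid", "tail_base"])  # passes
# Toy(keypoint_names=["snout"])  # would raise ValueError
```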