Skip to content

Commit

Permalink
Loading function for Anipose data (#358)
Browse files Browse the repository at this point in the history
* first draft of loading function

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adapted to new dimensions order

* adapted to work with new dims arrangement

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* anipose loader test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* validator for anipose file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* anipose validator finished

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* linting fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update tests/test_unit/test_validators/test_files_validators.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* simplified validator test

* Update movement/io/load_poses.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* Update movement/validators/files.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update movement/validators/files.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* Update movement/validators/files.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* implementing fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more consistency fixes

* moved anipose loading test to load_poses

* fixed validators tests

* tests for anipose loading done properly

* docstring fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Implementing direct anipose load from from_file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ruffed

* trying to fix mypy check

* Update movement/io/load_poses.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* Update movement/io/load_poses.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* Update movement/io/load_poses.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* Update movement/io/load_poses.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* final touches to docstrings

* added entry in input_output docs

* define anipose link in conf.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Niko Sirmpilatze <[email protected]>
  • Loading branch information
3 people authored Dec 11, 2024
1 parent e7d8e47 commit d190ce5
Show file tree
Hide file tree
Showing 7 changed files with 350 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
"xarray": "https://docs.xarray.dev/en/stable/{{path}}#{{fragment}}",
"lp": "https://lightning-pose.readthedocs.io/en/stable/{{path}}#{{fragment}}",
"via": "https://www.robots.ox.ac.uk/~vgg/software/via/{{path}}#{{fragment}}",
"anipose": "https://anipose.readthedocs.io/en/latest/",
}

intersphinx_mapping = {
Expand Down
17 changes: 17 additions & 0 deletions docs/source/user_guide/input_output.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ To analyse pose tracks, `movement` supports loading data from various frameworks
- [DeepLabCut](dlc:) (DLC)
- [SLEAP](sleap:) (SLEAP)
- [LightningPose](lp:) (LP)
- [Anipose](anipose:) (Anipose)

To analyse bounding boxes' tracks, `movement` currently supports the [VGG Image Annotator](via:) (VIA) format for [tracks annotation](via:docs/face_track_annotation.html).

Expand Down Expand Up @@ -84,6 +85,22 @@ ds = load_poses.from_file(
```
:::

:::{tab-item} Anipose

To load Anipose files in .csv format:
```python
ds = load_poses.from_anipose_file(
"/path/to/file.analysis.csv", fps=30, individual_name="individual_0"
) # We can optionally specify the individual name, by default it is "individual_0"

# or equivalently
ds = load_poses.from_file(
"/path/to/file.analysis.csv", source_software="Anipose", fps=30, individual_name="individual_0"
)

```
:::

:::{tab-item} From NumPy

In the example below, we create random position data for two individuals, ``Alice`` and ``Bob``,
Expand Down
133 changes: 130 additions & 3 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@

from movement.utils.logging import log_error, log_warning
from movement.validators.datasets import ValidPosesDataset
from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5
from movement.validators.files import (
ValidAniposeCSV,
ValidDeepLabCutCSV,
ValidFile,
ValidHDF5,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -91,8 +96,11 @@ def from_numpy(

def from_file(
file_path: Path | str,
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
source_software: Literal[
"DeepLabCut", "SLEAP", "LightningPose", "Anipose"
],
fps: float | None = None,
**kwargs,
) -> xr.Dataset:
"""Create a ``movement`` poses dataset from any supported file.
Expand All @@ -104,11 +112,14 @@ def from_file(
``from_slp_file()`` or ``from_lp_file()`` functions. One of these
functions will be called internally, based on
the value of ``source_software``.
source_software : "DeepLabCut", "SLEAP" or "LightningPose"
source_software : "DeepLabCut", "SLEAP", "LightningPose", or "Anipose"
The source software of the file.
fps : float, optional
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.
**kwargs : dict, optional
Additional keyword arguments to pass to the software-specific
loading functions that are listed under "See Also".
Returns
-------
Expand All @@ -121,6 +132,7 @@ def from_file(
movement.io.load_poses.from_dlc_file
movement.io.load_poses.from_sleap_file
movement.io.load_poses.from_lp_file
movement.io.load_poses.from_anipose_file
Examples
--------
Expand All @@ -136,6 +148,8 @@ def from_file(
return from_sleap_file(file_path, fps)
elif source_software == "LightningPose":
return from_lp_file(file_path, fps)
elif source_software == "Anipose":
return from_anipose_file(file_path, fps, **kwargs)
else:
raise log_error(
ValueError, f"Unsupported source software: {source_software}"
Expand Down Expand Up @@ -696,3 +710,116 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
"ds_type": "poses",
},
)


def from_anipose_style_df(
    df: pd.DataFrame,
    fps: float | None = None,
    individual_name: str = "individual_0",
) -> xr.Dataset:
    """Create a ``movement`` poses dataset from an Anipose 3D dataframe.

    Parameters
    ----------
    df : pd.DataFrame
        Anipose triangulation dataframe.
    fps : float, optional
        The number of frames per second in the video. If None (default),
        the ``time`` coordinates will be in frame units.
    individual_name : str, optional
        Name of the individual, by default "individual_0".

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    Notes
    -----
    Reshapes a dataframe with columns keypoint1_x, keypoint1_y, keypoint1_z,
    keypoint1_score, keypoint2_x, keypoint2_y, keypoint2_z,
    keypoint2_score... into an array of positions with dimensions
    time, space, keypoints, individuals, and an array of confidence
    (from scores) with dimensions time, keypoints, individuals.

    """
    # A keypoint is any column base name that appears with a spatial
    # suffix; str.endswith accepts a tuple, so one call covers all three.
    spatial_suffixes = ("_x", "_y", "_z")
    keypoint_names = sorted(
        {
            col.rsplit("_", 1)[0]
            for col in df.columns
            if col.endswith(spatial_suffixes)
        }
    )

    n_frames = len(df)
    n_keypoints = len(keypoint_names)

    # Initialize arrays and fill; the trailing axis of size 1 holds the
    # single individual.
    position_array = np.zeros((n_frames, 3, n_keypoints, 1))
    confidence_array = np.zeros((n_frames, n_keypoints, 1))
    for i, kp in enumerate(keypoint_names):
        for j, coord in enumerate(("x", "y", "z")):
            position_array[:, j, i, 0] = df[f"{kp}_{coord}"]
        confidence_array[:, i, 0] = df[f"{kp}_score"]

    return from_numpy(
        position_array=position_array,
        confidence_array=confidence_array,
        individual_names=[individual_name],
        keypoint_names=keypoint_names,
        source_software="Anipose",
        fps=fps,
    )


def from_anipose_file(
    file_path: Path | str,
    fps: float | None = None,
    individual_name: str = "individual_0",
) -> xr.Dataset:
    """Create a ``movement`` poses dataset from an Anipose 3D .csv file.

    Parameters
    ----------
    file_path : pathlib.Path or str
        Path to the Anipose triangulation .csv file.
    fps : float, optional
        The number of frames per second in the video. If None (default),
        the ``time`` coordinates will be in frame units.
    individual_name : str, optional
        Name of the individual, by default "individual_0".

    Returns
    -------
    xarray.Dataset
        ``movement`` dataset containing the pose tracks, confidence scores,
        and associated metadata.

    Notes
    -----
    We currently do not load all information, only x, y, z, and score
    (confidence) for each keypoint. Future versions will load n of cameras
    and error.

    """
    # Validate the path/permissions first, then the file's contents.
    validated_file = ValidFile(
        file_path,
        expected_permission="r",
        expected_suffix=[".csv"],
    )
    validated_csv = ValidAniposeCSV(validated_file.path)

    triangulation_df = pd.read_csv(validated_csv.path)
    return from_anipose_style_df(
        triangulation_df,
        fps=fps,
        individual_name=individual_name,
    )
88 changes: 88 additions & 0 deletions movement/validators/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,94 @@ def _file_contains_expected_levels(self, attribute, value):
)


@define
class ValidAniposeCSV:
    """Class for validating Anipose-style 3D pose .csv files.

    The validator ensures that the file contains the
    expected column names in its header (first row).

    Attributes
    ----------
    path : pathlib.Path
        Path to the .csv file.

    Raises
    ------
    ValueError
        If the .csv file does not contain the expected Anipose columns.

    """

    path: Path = field(validator=validators.instance_of(Path))

    @path.validator
    def _file_contains_expected_columns(self, attribute, value):
        """Ensure that the .csv file contains the expected columns."""
        # Each keypoint contributes one column per suffix.
        expected_column_suffixes = [
            "_x",
            "_y",
            "_z",
            "_score",
            "_error",
            "_ncams",
        ]
        # Columns that are always present, regardless of keypoints:
        # frame number, center of the 3D coordinate system, and the
        # 3x3 rotation matrix M.
        expected_non_keypoint_columns = [
            "fnum",
            "center_0",
            "center_1",
            "center_2",
            "M_00",
            "M_01",
            "M_02",
            "M_10",
            "M_11",
            "M_12",
            "M_20",
            "M_21",
            "M_22",
        ]

        # Read the first line of the CSV to get the headers
        with open(value, encoding="utf-8") as f:
            columns = f.readline().strip().split(",")

        # Check that all expected headers are present
        if not all(col in columns for col in expected_non_keypoint_columns):
            raise log_error(
                ValueError,
                # Separating space so the message doesn't read
                # "...columns.Expected:".
                "CSV file is missing some expected columns. "
                f"Expected: {expected_non_keypoint_columns}.",
            )

        # For other headers, check they have expected suffixes and base names
        other_columns = [
            col for col in columns if col not in expected_non_keypoint_columns
        ]
        for column in other_columns:
            # Check suffix
            if not any(
                column.endswith(suffix) for suffix in expected_column_suffixes
            ):
                raise log_error(
                    ValueError,
                    f"Column {column} ends with an unexpected suffix.",
                )
            # Get base name by removing suffix
            base = column.rsplit("_", 1)[0]
            # Check base name has all expected suffixes
            if not all(
                f"{base}{suffix}" in columns
                for suffix in expected_column_suffixes
            ):
                raise log_error(
                    ValueError,
                    f"Keypoint {base} is missing some expected suffixes. "
                    f"Expected: {expected_column_suffixes}; "
                    f"Got: {columns}.",
                )


@define
class ValidVIATracksCSV:
"""Class for validating VIA tracks .csv files.
Expand Down
55 changes: 55 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,61 @@ def dlc_style_df():
return pd.read_hdf(pytest.DATA_PATHS.get("DLC_single-wasp.predictions.h5"))


@pytest.fixture
def missing_keypoint_columns_anipose_csv_file(tmp_path):
    """Return the file path for a fake single-individual .csv file."""
    file_path = tmp_path / "missing_keypoint_columns.csv"
    # Mandatory non-keypoint headers: frame number, center, and the
    # 3x3 rotation matrix M_00..M_22.
    header = ["fnum", "center_0", "center_1", "center_2"]
    header += [f"M_{row}{col}" for row in range(3) for col in range(3)]
    # Here we are missing kp0_z:
    header += ["kp0_x", "kp0_y", "kp0_score", "kp0_error", "kp0_ncams"]
    data_row = ["1"] * len(header)
    file_path.write_text(",".join(header) + "\n" + ",".join(data_row))
    return file_path


@pytest.fixture
def spurious_column_anipose_csv_file(tmp_path):
    """Return the file path for a fake single-individual .csv file."""
    file_path = tmp_path / "spurious_column.csv"
    # Mandatory non-keypoint headers: frame number, center, and the
    # 3x3 rotation matrix M_00..M_22.
    header = ["fnum", "center_0", "center_1", "center_2"]
    header += [f"M_{row}{col}" for row in range(3) for col in range(3)]
    # A column that matches no expected keypoint suffix.
    header.append("funny_column")
    data_row = ["1"] * len(header)
    file_path.write_text(",".join(header) + "\n" + ",".join(data_row))
    return file_path


@pytest.fixture(
params=[
"SLEAP_single-mouse_EPM.analysis.h5",
Expand Down
Loading

0 comments on commit d190ce5

Please sign in to comment.