Update datasets.py (#86)
* Added `fetch_pose_data()` function

* Construct download registry from newest `metadata.yaml`

* Updated `datasets.py` module name, docstrings, and functions

* Added `from_lp_file` to `fetch_sample_data`

* Renamed `sample_dataset.py` and added test for `fetch_sample_data`

* Updated docs and docstrings

* Fixed `sample_data.fetch_sample_data()` to load data with correct FPS and renamed `POSE_DATA` download manager

* Cleaned up `pyproject.toml` dependencies and improved metadata-fetching logic in `sample_data.py`

* Removed hard-coded list of sample file names in `conftest.py`

* Minor cleanup of docs and docstrings

* Clarified "Adding New Data" instructions on `CONTRIBUTING.md`

Co-authored-by: Niko Sirmpilatze <niko.sirbiladze@gmail.com>

* Small edit to `getting_started.md`

* Extended `fetch_sample_data_path()` to catch the case in which a file is not in the registry, and added a test for `list_sample_data()`

* Fetched metadata using `pooch.retrieve()`

* Fixed bug in `sample_data.py`

* Updated `fetch_metadata` function

* Refactored `test_sample_data` using a fixture

* Refactored and tested the fetching of metadata

* Mentioned sample metadata more explicitly in the contributing guide

* Renamed `POSE_DATA` to `POSE_DATA_PATHS` in the testing suite, to be more explicit

---------

Co-authored-by: Niko Sirmpilatze <niko.sirbiladze@gmail.com>
b-peri and niksirbi authored Jan 10, 2024
1 parent 5249063 commit 03ddaf1
Showing 13 changed files with 386 additions and 144 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -32,6 +32,8 @@ repos:
- types-setuptools
- pandas-stubs
- types-attrs
- types-PyYAML
- types-requests
- repo: https://github.com/mgedmin/check-manifest
rev: "0.49"
hooks:
28 changes: 14 additions & 14 deletions CONTRIBUTING.md
@@ -259,25 +259,26 @@ by the [German Neuroinformatics Node](https://www.g-node.org/).
GIN has a GitHub-like interface and git-like
[CLI](gin:G-Node/Info/wiki/GIN+CLI+Setup#quickstart) functionalities.

Currently the data repository contains sample pose estimation data files
stored in the `poses` folder. Each file name starts with either "DLC" or "SLEAP",
depending on the pose estimation software used to generate the data.
Currently, the data repository contains sample pose estimation data files
stored in the `poses` folder. Metadata for these files, including information
about their provenance, is stored in the `poses_files_metadata.yaml` file.

### Fetching data
To fetch the data from GIN, we use the [pooch](https://www.fatiando.org/pooch/latest/index.html)
Python package, which can download data from pre-specified URLs and store them
locally for all subsequent uses. It also provides some nice utilities,
like verification of sha256 hashes and decompression of archives.

The relevant functionality is implemented in the `movement.datasets.py` module.
The relevant functionality is implemented in the `movement.sample_data.py` module.
The most important parts of this module are:

1. The `POSE_DATA` download manager object, which contains a list of stored files and their known hashes.
2. The `list_pose_data()` function, which returns a list of the available files in the data repository.
3. The `fetch_pose_data_path()` function, which downloads a file (if not already cached locally) and returns the local path to it.
1. The `SAMPLE_DATA` download manager object.
2. The `list_sample_data()` function, which returns a list of the available files in the data repository.
3. The `fetch_sample_data_path()` function, which downloads a file (if not already cached locally) and returns the local path to it.
4. The `fetch_sample_data()` function, which downloads a file and loads it into movement directly, returning an `xarray.Dataset` object.

By default, the downloaded files are stored in the `~/.movement/data` folder.
This can be changed by setting the `DATA_DIR` variable in the `movement.datasets.py` module.
This can be changed by setting the `DATA_DIR` variable in the `movement.sample_data.py` module.
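For instance, the fetching workflow described above can be exercised as follows (a minimal sketch; the file name is one of the registered sample files):

```python
from movement import sample_data

# List the sample files available in the data repository
print(sample_data.list_sample_data())

# Download a file (or reuse the cached copy) and get its local path
file_path = sample_data.fetch_sample_data_path("DLC_single-wasp.predictions.h5")

# Or download and load it directly as an xarray.Dataset
ds = sample_data.fetch_sample_data("DLC_single-wasp.predictions.h5")
```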

### Adding new data
Only core movement developers may add new files to the external data repository.
@@ -287,9 +288,8 @@ To add a new file, you will need to:
2. Ask to be added as a collaborator on the [movement data repository](gin:neuroinformatics/movement-test-data) (if not already)
3. Download the [GIN CLI](gin:G-Node/Info/wiki/GIN+CLI+Setup#quickstart) and set it up with your GIN credentials, by running `gin login` in a terminal.
4. Clone the movement data repository to your local machine, by running `gin get neuroinformatics/movement-test-data` in a terminal.
5. Add your new files and commit them with `gin commit -m <message> <filename>`.
6. Upload the commited changes to the GIN repository, by running `gin upload`. Latest changes to the repository can be pulled via `gin download`. `gin sync` will synchronise the latest changes bidirectionally.
7. Determine the sha256 checksum hash of each new file, by running `sha256sum <filename>` in a terminal. Alternatively, you can use `pooch` to do this for you: `python -c "import pooch; pooch.file_hash('/path/to/file')"`. If you wish to generate a text file containing the hashes of all the files in a given folder, you can use `python -c "import pooch; pooch.make_registry('/path/to/folder', 'sha256_registry.txt')`.
8. Update the `movement.datasets.py` module on the [movement GitHub repository](movement-github:) by adding the new files to the `POSE_DATA` registry. Make sure to include the correct sha256 hash, as determined in the previous step. Follow all the usual [guidelines for contributing code](#contributing-code). Make sure to test whether the new files can be fetched successfully (see [fetching data](#fetching-data) above) before submitting your pull request.

You can also perform steps 3-6 via the GIN web interface, if you prefer to avoid using the CLI.
5. Add your new files to `/movement-test-data/poses/`.
6. Determine the sha256 checksum hash of each new file by running `sha256sum <filename>` in a terminal. Alternatively, you can use `pooch` to do this for you: `python -c "import pooch; hash = pooch.file_hash('/path/to/file'); print(hash)"`. If you wish to generate a text file containing the hashes of all the files in a given folder, you can use `python -c "import pooch; pooch.make_registry('/path/to/folder', 'sha256_registry.txt')"`.
7. Add metadata for your new files to `poses_files_metadata.yaml`, including their sha256 hashes.
8. Commit your changes using `gin commit -m <message> <filename>`.
9. Upload the committed changes to the GIN repository by running `gin upload`. Latest changes to the repository can be pulled via `gin download`. `gin sync` will synchronise the latest changes bidirectionally.
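For step 6, the same hashes can also be computed in a short Python snippet (a sketch equivalent to the one-liners above):

```python
import pooch

# sha256 hash of a single file (sha256 is pooch's default algorithm)
print(pooch.file_hash("/path/to/file"))

# Write a registry file listing the hash of every file in a folder
pooch.make_registry("/path/to/folder", "sha256_registry.txt")
```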
11 changes: 6 additions & 5 deletions docs/source/api_index.rst
@@ -33,14 +33,15 @@ Input/Output
ValidPosesCSV
ValidPoseTracks

Datasets
--------
.. currentmodule:: movement.datasets
Sample Data
-----------
.. currentmodule:: movement.sample_data
.. autosummary::
:toctree: api

list_pose_data
fetch_pose_data_path
list_sample_data
fetch_sample_data_path
fetch_sample_data

Logging
-------
25 changes: 17 additions & 8 deletions docs/source/getting_started.md
@@ -53,7 +53,7 @@ Please see the [contributing guide](target-contributing) for more information.

## Loading data
You can load predicted pose tracks from the pose estimation software packages
[DeepLabCut](dlc:) or [SLEAP](sleap:).
[DeepLabCut](dlc:), [SLEAP](sleap:), or [LightningPose](lp:).

First import the `movement.io.load_poses` module:

@@ -114,27 +114,36 @@ You can also try movement out on some sample data included in the package.
You can view the available sample data files with:

```python
from movement import datasets
from movement import sample_data

file_names = datasets.list_pose_data()
file_names = sample_data.list_sample_data()
print(file_names)
```

This will print a list of file names containing sample pose data.
The files are prefixed with the name of the pose estimation software package,
either "DLC" or "SLEAP".
Each file is prefixed with the name of the pose estimation software package
that was used to generate it - either "DLC", "SLEAP", or "LP".

To get the path to one of the sample files,
you can use the `fetch_sample_data_path` function:

```python
file_path = datasets.fetch_pose_data_path("DLC_two-mice.predictions.csv")
file_path = sample_data.fetch_sample_data_path("DLC_two-mice.predictions.csv")
```
The first time you call this function, it will download the corresponding file
to your local machine and save it in the `~/.movement/data` directory. On
subsequent calls, it will simply return the path to that local file.

You can feed the path to the `from_dlc_file` or `from_sleap_file` functions
and load the data, as shown above.
You can feed the path to the `from_dlc_file`, `from_sleap_file`, or
`from_lp_file` functions and load the data, as shown above.

Alternatively, you can skip the `fetch_sample_data_path()` step and load the
data directly using the `fetch_sample_data()` function:

```python
ds = sample_data.fetch_sample_data("DLC_two-mice.predictions.csv")
```

:::

## Working with movement datasets
6 changes: 3 additions & 3 deletions examples/load_and_explore_poses.py
@@ -10,22 +10,22 @@
# -------
from matplotlib import pyplot as plt

from movement import datasets
from movement import sample_data
from movement.io import load_poses

# %%
# Fetch an example dataset
# ------------------------
# Print a list of available datasets:

for file_name in datasets.list_pose_data():
for file_name in sample_data.list_sample_data():
print(file_name)

# %%
# Fetch the path to an example dataset.
# Feel free to replace this with the path to your own dataset.
# e.g., ``file_path = "/path/to/my/data.h5"``
file_path = datasets.fetch_pose_data_path(
file_path = sample_data.fetch_sample_data_path(
"SLEAP_three-mice_Aeon_proofread.analysis.h5"
)

74 changes: 0 additions & 74 deletions movement/datasets.py

This file was deleted.

188 changes: 188 additions & 0 deletions movement/sample_data.py
@@ -0,0 +1,188 @@
"""Module for fetching and loading sample datasets.
This module provides functions for fetching and loading sample data used in
tests, examples, and tutorials. The data are stored in a remote repository
on GIN and are downloaded to the user's local machine the first time they
are used.
"""

import logging
from pathlib import Path

import pooch
import xarray
import yaml
from requests.exceptions import RequestException

from movement.io import load_poses
from movement.logging import log_error, log_warning

logger = logging.getLogger(__name__)

# URL to the remote data repository on GIN
# noinspection PyInterpreter
DATA_URL = (
"https://gin.g-node.org/neuroinformatics/movement-test-data/raw/master"
)

# Save data in ~/.movement/data
DATA_DIR = Path("~", ".movement", "data").expanduser()
# Create the folder if it doesn't exist
DATA_DIR.mkdir(parents=True, exist_ok=True)


def _download_metadata_file(file_name: str, data_dir: Path = DATA_DIR) -> Path:
"""Download the yaml file containing sample metadata from the *movement*
data repository and save it in the specified directory with a temporary
filename - temp_{file_name} - to avoid overwriting any existing files.
Parameters
----------
file_name : str
Name of the metadata file to fetch.
data_dir : pathlib.Path, optional
Directory to store the metadata file in. Defaults to the constant
``DATA_DIR``. Can be overridden for testing purposes.
Returns
-------
path : pathlib.Path
Path to the downloaded file.
"""
local_file_path = pooch.retrieve(
url=f"{DATA_URL}/{file_name}",
known_hash=None,
path=data_dir,
fname=f"temp_{file_name}",
progressbar=False,
)
logger.debug(
f"Successfully downloaded sample metadata file {file_name} "
f"from {DATA_URL} to {data_dir}"
)
return Path(local_file_path)


def _fetch_metadata(file_name: str, data_dir: Path = DATA_DIR) -> list[dict]:
"""Download the yaml file containing metadata from the *movement* sample
data repository and load it as a list of dictionaries.
Parameters
----------
file_name : str
Name of the metadata file to fetch.
data_dir : pathlib.Path, optional
Directory to store the metadata file in. Defaults to
the constant ``DATA_DIR``. Can be overridden for testing purposes.
Returns
-------
list[dict]
A list of dictionaries containing metadata for each sample file.
"""

local_file_path = Path(data_dir / file_name)
failed_msg = "Failed to download the newest sample metadata file."

# try downloading the newest metadata file
try:
downloaded_file_path = _download_metadata_file(file_name, data_dir)
# if download succeeds, replace any existing local metadata file
downloaded_file_path.replace(local_file_path)
# if download fails, try loading an existing local metadata file,
# otherwise raise an error
except RequestException as exc_info:
if local_file_path.is_file():
log_warning(
f"{failed_msg} Will use the existing local version instead."
)
else:
raise log_error(RequestException, failed_msg) from exc_info

with open(local_file_path, "r") as metadata_file:
metadata = yaml.safe_load(metadata_file)
return metadata


metadata = _fetch_metadata("poses_files_metadata.yaml")

# Create a download manager for the pose data
SAMPLE_DATA = pooch.create(
path=DATA_DIR / "poses",
base_url=f"{DATA_URL}/poses/",
retry_if_failed=0,
registry={file["file_name"]: file["sha256sum"] for file in metadata},
)
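# For illustration only: a metadata entry is expected to look like
# {"file_name": "DLC_single-wasp.predictions.h5", "sha256sum": "<hash>",
# "source_software": "DeepLabCut", "fps": 40, ...} (the full set of fields is
# validated in tests/test_unit/test_sample_data.py), so the registry above
# maps each sample file name to its sha256 hash.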


def list_sample_data() -> list[str]:
"""Find available sample pose data in the *movement* data repository.
Returns
-------
filenames : list of str
List of filenames for available pose data."""
return list(SAMPLE_DATA.registry.keys())


def fetch_sample_data_path(filename: str) -> Path:
"""Download sample pose data from the *movement* data repository and return
its local filepath.
The data are downloaded to the user's local machine the first time they are
used and are stored in a local cache directory. The function returns the
path to the downloaded file, not the contents of the file itself.
Parameters
----------
filename : str
Name of the file to fetch.
Returns
-------
path : pathlib.Path
Path to the downloaded file.
"""
try:
return Path(SAMPLE_DATA.fetch(filename, progressbar=True))
except ValueError:
raise log_error(
ValueError,
f"File '{filename}' is not in the registry. Valid "
f"filenames are: {list_sample_data()}",
)


def fetch_sample_data(
filename: str,
) -> xarray.Dataset:
"""Download and return sample pose data from the *movement* data
repository.
The data are downloaded to the user's local machine the first time they are
used and are stored in a local cache directory. Returns sample pose data as
an xarray Dataset.
Parameters
----------
filename : str
Name of the file to fetch.
Returns
-------
ds : xarray.Dataset
Pose data contained in the fetched sample file.
"""

file_path = fetch_sample_data_path(filename)
file_metadata = next(
file for file in metadata if file["file_name"] == filename
)

if file_metadata["source_software"] == "SLEAP":
ds = load_poses.from_sleap_file(file_path, fps=file_metadata["fps"])
elif file_metadata["source_software"] == "DeepLabCut":
ds = load_poses.from_dlc_file(file_path, fps=file_metadata["fps"])
elif file_metadata["source_software"] == "LightningPose":
ds = load_poses.from_lp_file(file_path, fps=file_metadata["fps"])
else:
# Guard against metadata entries with an unrecognised source software,
# so that ``ds`` can never be returned unbound.
raise log_error(
ValueError,
f"Unexpected source software: {file_metadata['source_software']}",
)
return ds
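# Example usage (a sketch; the file name is one of the registered samples):
# >>> ds = fetch_sample_data("SLEAP_three-mice_Aeon_proofread.analysis.h5")
# >>> ds.fps  # populated from the file's metadata entry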
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"tqdm",
"sleap-io",
"xarray",
"PyYAML",
]

classifiers = [
@@ -54,6 +55,8 @@ dev = [
"pandas-stubs",
"types-attrs",
"check-manifest",
"types-PyYAML",
"types-requests",
]

[build-system]
26 changes: 8 additions & 18 deletions tests/conftest.py
@@ -8,29 +8,17 @@
import pytest
import xarray as xr

from movement.datasets import fetch_pose_data_path
from movement.io import PosesAccessor
from movement.logging import configure_logging
from movement.sample_data import fetch_sample_data_path, list_sample_data


def pytest_configure():
"""Perform initial configuration for pytest.
Fetches pose data file paths as a dictionary for tests."""
pytest.POSE_DATA = {
file_name: fetch_pose_data_path(file_name)
for file_name in [
"DLC_single-wasp.predictions.h5",
"DLC_single-wasp.predictions.csv",
"DLC_two-mice.predictions.csv",
"SLEAP_single-mouse_EPM.analysis.h5",
"SLEAP_single-mouse_EPM.predictions.slp",
"SLEAP_three-mice_Aeon_proofread.analysis.h5",
"SLEAP_three-mice_Aeon_proofread.predictions.slp",
"SLEAP_three-mice_Aeon_mixed-labels.analysis.h5",
"SLEAP_three-mice_Aeon_mixed-labels.predictions.slp",
"LP_mouse-face_AIND.predictions.csv",
"LP_mouse-twoview_AIND.predictions.csv",
]
pytest.POSE_DATA_PATHS = {
file_name: fetch_sample_data_path(file_name)
for file_name in list_sample_data()
}


@@ -186,7 +174,9 @@ def new_csv_file(tmp_path):
@pytest.fixture
def dlc_style_df():
"""Return a valid DLC-style DataFrame."""
return pd.read_hdf(pytest.POSE_DATA.get("DLC_single-wasp.predictions.h5"))
return pd.read_hdf(
pytest.POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5")
)


@pytest.fixture(
@@ -201,7 +191,7 @@ def dlc_style_df():
)
def sleap_file(request):
"""Return the file path for a SLEAP .h5 or .slp file."""
return pytest.POSE_DATA.get(request.param)
return pytest.POSE_DATA_PATHS.get(request.param)


@pytest.fixture
6 changes: 3 additions & 3 deletions tests/test_integration/test_io.py
@@ -2,7 +2,7 @@
import numpy as np
import pytest
import xarray as xr
from pytest import POSE_DATA
from pytest import POSE_DATA_PATHS

from movement.io import load_poses, save_poses

@@ -56,7 +56,7 @@ def test_to_sleap_analysis_file_returns_same_h5_file_content(
"""Test that saving pose tracks (loaded from a SLEAP analysis
file) to a SLEAP-style .h5 analysis file returns the same file
contents."""
sleap_h5_file_path = POSE_DATA.get(sleap_h5_file)
sleap_h5_file_path = POSE_DATA_PATHS.get(sleap_h5_file)
ds = load_poses.from_sleap_file(sleap_h5_file_path, fps=fps)
save_poses.to_sleap_analysis_file(ds, new_h5_file)

@@ -85,7 +85,7 @@ def test_to_sleap_analysis_file_source_file(self, file, new_h5_file):
"""Test that saving pose tracks (loaded from valid source files)
to a SLEAP-style .h5 analysis file stores the .slp labels path
only when the source file is a .slp file."""
file_path = POSE_DATA.get(file)
file_path = POSE_DATA_PATHS.get(file)
if file.startswith("DLC"):
ds = load_poses.from_dlc_file(file_path)
else:
28 changes: 15 additions & 13 deletions tests/test_unit/test_load_poses.py
@@ -2,7 +2,7 @@
import numpy as np
import pytest
import xarray as xr
from pytest import POSE_DATA
from pytest import POSE_DATA_PATHS
from sleap_io.io.slp import read_labels, write_labels
from sleap_io.model.labels import LabeledFrame, Labels

@@ -15,7 +15,9 @@ class TestLoadPoses:
@pytest.fixture
def sleap_slp_file_without_tracks(self, tmp_path):
"""Mock and return the path to a SLEAP .slp file without tracks."""
sleap_file = POSE_DATA.get("SLEAP_single-mouse_EPM.predictions.slp")
sleap_file = POSE_DATA_PATHS.get(
"SLEAP_single-mouse_EPM.predictions.slp"
)
labels = read_labels(sleap_file)
file_path = tmp_path / "track_is_none.slp"
lfs = []
@@ -43,7 +45,7 @@ def sleap_slp_file_without_tracks(self, tmp_path):
@pytest.fixture
def sleap_h5_file_without_tracks(self, tmp_path):
"""Mock and return the path to a SLEAP .h5 file without tracks."""
sleap_file = POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5")
sleap_file = POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5")
file_path = tmp_path / "track_is_none.h5"
with h5py.File(sleap_file, "r") as f1, h5py.File(file_path, "w") as f2:
for key in list(f1.keys()):
@@ -112,7 +114,7 @@ def test_load_from_sleap_file_without_tracks(
sleap_file_without_tracks
)
ds_from_tracked = load_poses.from_sleap_file(
POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5")
POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5")
)
# Check if the "individuals" coordinate matches
# the assigned default "individuals_0"
@@ -144,8 +146,8 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same(
):
"""Test that loading pose tracks from SLEAP .slp and .h5 files
return the same Dataset."""
slp_file_path = POSE_DATA.get(slp_file)
h5_file_path = POSE_DATA.get(h5_file)
slp_file_path = POSE_DATA_PATHS.get(slp_file)
h5_file_path = POSE_DATA_PATHS.get(h5_file)
ds_from_slp = load_poses.from_sleap_file(slp_file_path)
ds_from_h5 = load_poses.from_sleap_file(h5_file_path)
xr.testing.assert_allclose(ds_from_h5, ds_from_slp)
@@ -161,7 +163,7 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same(
def test_load_from_dlc_file(self, file_name):
"""Test that loading pose tracks from valid DLC files
returns a proper Dataset."""
file_path = POSE_DATA.get(file_name)
file_path = POSE_DATA_PATHS.get(file_name)
ds = load_poses.from_dlc_file(file_path)
self.assert_dataset(ds, file_path, "DeepLabCut")

@@ -174,8 +176,8 @@ def test_load_from_dlc_df(self, dlc_style_df):
def test_load_from_dlc_file_csv_or_h5_file_returns_same(self):
"""Test that loading pose tracks from DLC .csv and .h5 files
return the same Dataset."""
csv_file_path = POSE_DATA.get("DLC_single-wasp.predictions.csv")
h5_file_path = POSE_DATA.get("DLC_single-wasp.predictions.h5")
csv_file_path = POSE_DATA_PATHS.get("DLC_single-wasp.predictions.csv")
h5_file_path = POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5")
ds_from_csv = load_poses.from_dlc_file(csv_file_path)
ds_from_h5 = load_poses.from_dlc_file(h5_file_path)
xr.testing.assert_allclose(ds_from_h5, ds_from_csv)
@@ -193,7 +195,7 @@ def test_load_from_dlc_file_csv_or_h5_file_returns_same(self):
def test_fps_and_time_coords(self, fps, expected_fps, expected_time_unit):
"""Test that time coordinates are set according to the provided fps."""
ds = load_poses.from_sleap_file(
POSE_DATA.get("SLEAP_three-mice_Aeon_proofread.analysis.h5"),
POSE_DATA_PATHS.get("SLEAP_three-mice_Aeon_proofread.analysis.h5"),
fps=fps,
)
assert ds.time_unit == expected_time_unit
@@ -216,15 +218,15 @@ def test_fps_and_time_coords(self, fps, expected_fps, expected_time_unit):
def test_load_from_lp_file(self, file_name):
"""Test that loading pose tracks from valid LightningPose (LP) files
returns a proper Dataset."""
file_path = POSE_DATA.get(file_name)
file_path = POSE_DATA_PATHS.get(file_name)
ds = load_poses.from_lp_file(file_path)
self.assert_dataset(ds, file_path, "LightningPose")

def test_load_from_lp_or_dlc_file_returns_same(self):
"""Test that loading a single-animal DeepLabCut-style .csv file
using either the `from_lp_file` or `from_dlc_file` function
returns the same Dataset (except for the source_software)."""
file_path = POSE_DATA.get("LP_mouse-face_AIND.predictions.csv")
file_path = POSE_DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv")
ds_from_lp = load_poses.from_lp_file(file_path)
ds_from_dlc = load_poses.from_dlc_file(file_path)
xr.testing.assert_allclose(ds_from_dlc, ds_from_lp)
@@ -234,6 +236,6 @@ def test_load_from_lp_or_dlc_file_returns_same(self):
def test_load_multi_animal_from_lp_file_raises(self):
"""Test that loading a multi-animal .csv file using the
`from_lp_file` function raises a ValueError."""
file_path = POSE_DATA.get("DLC_two-mice.predictions.csv")
file_path = POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv")
with pytest.raises(ValueError):
load_poses.from_lp_file(file_path)
121 changes: 121 additions & 0 deletions tests/test_unit/test_sample_data.py
@@ -0,0 +1,121 @@
"""Test suite for the sample_data module."""

from unittest.mock import MagicMock, patch

import pooch
import pytest
from requests.exceptions import RequestException
from xarray import Dataset

from movement.sample_data import (
_fetch_metadata,
fetch_sample_data,
list_sample_data,
)


@pytest.fixture(scope="module")
def valid_file_names_with_fps():
"""Return a dict containing one valid file name and the corresponding fps
for each supported pose estimation tool."""
return {
"SLEAP_single-mouse_EPM.analysis.h5": 30,
"DLC_single-wasp.predictions.h5": 40,
"LP_mouse-face_AIND.predictions.csv": 60,
}


def validate_metadata(metadata: list[dict]) -> None:
"""Assert that the metadata is in the expected format."""
metadata_fields = [
"file_name",
"sha256sum",
"source_software",
"fps",
"species",
"number_of_individuals",
"shared_by",
"video_frame_file",
"note",
]
check_yaml_msg = "Check the format of the metadata yaml file."
assert isinstance(
metadata, list
), f"Expected metadata to be a list. {check_yaml_msg}"
assert all(
isinstance(file, dict) for file in metadata
), f"Expected metadata entries to be dicts. {check_yaml_msg}"
assert all(
set(file.keys()) == set(metadata_fields) for file in metadata
), f"Expected all metadata entries to have the same keys. {check_yaml_msg}"

# check that filenames are unique
file_names = [file["file_name"] for file in metadata]
assert len(file_names) == len(set(file_names))

# check that the first 3 fields are present and are strings
required_fields = metadata_fields[:3]
assert all(
(isinstance(file[field], str))
for file in metadata
for field in required_fields
)


# Mock pooch.retrieve with RequestException as side_effect
mock_retrieve = MagicMock(pooch.retrieve, side_effect=RequestException)


@pytest.mark.parametrize("download_fails", [True, False])
@pytest.mark.parametrize("local_exists", [True, False])
def test_fetch_metadata(tmp_path, caplog, download_fails, local_exists):
"""Test the fetch_metadata function with different combinations of
failed download and pre-existing local file. The expected behavior is
that the function will try to download the metadata file, and if that
fails, it will try to load an existing local file. If neither succeeds,
an error is raised."""
metadata_file_name = "poses_files_metadata.yaml"
local_file_path = tmp_path / metadata_file_name

with patch("movement.sample_data.DATA_DIR", tmp_path):
# simulate the existence of a local metadata file
if local_exists:
local_file_path.touch()

if download_fails:
# simulate a failed download
with patch("movement.sample_data.pooch.retrieve", mock_retrieve):
if local_exists:
_fetch_metadata(metadata_file_name)
# check that a warning was logged
assert (
"Will use the existing local version instead"
in caplog.records[-1].getMessage()
)
else:
with pytest.raises(
RequestException, match="Failed to download"
):
_fetch_metadata(metadata_file_name, data_dir=tmp_path)
else:
metadata = _fetch_metadata(metadata_file_name, data_dir=tmp_path)
assert local_file_path.is_file()
validate_metadata(metadata)


def test_list_sample_data(valid_file_names_with_fps):
assert isinstance(list_sample_data(), list)
assert all(
file in list_sample_data() for file in valid_file_names_with_fps
)


def test_fetch_sample_data(valid_file_names_with_fps):
# test with valid files
for file, fps in valid_file_names_with_fps.items():
ds = fetch_sample_data(file)
assert isinstance(ds, Dataset) and ds.fps == fps

# Test with an invalid file
with pytest.raises(ValueError):
fetch_sample_data("nonexistent_file")
12 changes: 6 additions & 6 deletions tests/test_unit/test_save_poses.py
@@ -5,7 +5,7 @@
import pandas as pd
import pytest
import xarray as xr
from pytest import POSE_DATA
from pytest import POSE_DATA_PATHS

from movement.io import load_poses, save_poses

@@ -65,33 +65,33 @@ def output_file_params(self, request):
(np.array([1, 2, 3]), pytest.raises(ValueError)), # incorrect type
(
load_poses.from_dlc_file(
POSE_DATA.get("DLC_single-wasp.predictions.h5")
POSE_DATA_PATHS.get("DLC_single-wasp.predictions.h5")
),
does_not_raise(),
), # valid dataset
(
load_poses.from_dlc_file(
POSE_DATA.get("DLC_two-mice.predictions.csv")
POSE_DATA_PATHS.get("DLC_two-mice.predictions.csv")
),
does_not_raise(),
), # valid dataset
(
load_poses.from_sleap_file(
POSE_DATA.get("SLEAP_single-mouse_EPM.analysis.h5")
POSE_DATA_PATHS.get("SLEAP_single-mouse_EPM.analysis.h5")
),
does_not_raise(),
), # valid dataset
(
load_poses.from_sleap_file(
POSE_DATA.get(
POSE_DATA_PATHS.get(
"SLEAP_three-mice_Aeon_proofread.predictions.slp"
)
),
does_not_raise(),
), # valid dataset
(
load_poses.from_lp_file(
POSE_DATA.get("LP_mouse-face_AIND.predictions.csv")
POSE_DATA_PATHS.get("LP_mouse-face_AIND.predictions.csv")
),
does_not_raise(),
), # valid dataset
