-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added `fetch_pose_data()` function * Construct download registry from newest `metadata.yaml` * Updated `datasets.py` module name, docstrings, and functions * Added `from_lp_file` to `fetch_sample_data` * Renamed `sample_dataset.py` and added test for `fetch_sample_data` * Updated docs and docstrings * Fixed `sample_data.fetch_sample_data()` to load data with correct FPS and renamed `POSE_DATA` download manager * Cleaned up `pyproject.toml` dependencies and improved metadata-fetching logic in `sample_data.py` * Removed hard-coded list of sample file names in `conftest.py` * Minor cleanup of docs and docstrings * Clarified "Adding New Data" instructions on `CONTRIBUTING.md` Co-authored-by: Niko Sirmpilatze <niko.sirbiladze@gmail.com> * Small edit to `getting_started.md` * Extended `fetch_sample_data_path()` to catch case in which file is not in the registry + added test for `list_sample_data()` * Fetch metadata using `pooch.retrieve()` * Fixed bug in `sample_data.py` * update fetch_metadata function * refactored test_sample_data using a fixture * refactor and test fetching of metadata * more explicit mention of sample metadata in contributing guide * renamed POSE_DATA to POSE_DATA_PATHS in testing suite, to be more explicit --------- Co-authored-by: Niko Sirmpilatze <niko.sirbiladze@gmail.com>
Showing
13 changed files
with
386 additions
and
144 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
"""Module for fetching and loading sample datasets. | ||
This module provides functions for fetching and loading sample data used in | ||
tests, examples, and tutorials. The data are stored in a remote repository | ||
on GIN and are downloaded to the user's local machine the first time they | ||
are used. | ||
""" | ||
|
||
import logging | ||
from pathlib import Path | ||
|
||
import pooch | ||
import xarray | ||
import yaml | ||
from requests.exceptions import RequestException | ||
|
||
from movement.io import load_poses | ||
from movement.logging import log_error, log_warning | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# URL to the remote data repository on GIN | ||
# noinspection PyInterpreter | ||
DATA_URL = ( | ||
"https://gin.g-node.org/neuroinformatics/movement-test-data/raw/master" | ||
) | ||
|
||
# Save data in ~/.movement/data | ||
DATA_DIR = Path("~", ".movement", "data").expanduser() | ||
# Create the folder if it doesn't exist | ||
DATA_DIR.mkdir(parents=True, exist_ok=True) | ||
|
||
|
||
def _download_metadata_file(file_name: str, data_dir: Path = DATA_DIR) -> Path: | ||
"""Download the yaml file containing sample metadata from the *movement* | ||
data repository and save it in the specified directory with a temporary | ||
filename - temp_{file_name} - to avoid overwriting any existing files. | ||
Parameters | ||
---------- | ||
file_name : str | ||
Name of the metadata file to fetch. | ||
data_dir : pathlib.Path, optional | ||
Directory to store the metadata file in. Defaults to the constant | ||
``DATA_DIR``. Can be overridden for testing purposes. | ||
Returns | ||
------- | ||
path : pathlib.Path | ||
Path to the downloaded file. | ||
""" | ||
local_file_path = pooch.retrieve( | ||
url=f"{DATA_URL}/{file_name}", | ||
known_hash=None, | ||
path=data_dir, | ||
fname=f"temp_{file_name}", | ||
progressbar=False, | ||
) | ||
logger.debug( | ||
f"Successfully downloaded sample metadata file {file_name} " | ||
f"from {DATA_URL} to {data_dir}" | ||
) | ||
return Path(local_file_path) | ||
|
||
|
||
def _fetch_metadata(file_name: str, data_dir: Path = DATA_DIR) -> list[dict]: | ||
"""Download the yaml file containing metadata from the *movement* sample | ||
data repository and load it as a list of dictionaries. | ||
Parameters | ||
---------- | ||
file_name : str | ||
Name of the metadata file to fetch. | ||
data_dir : pathlib.Path, optional | ||
Directory to store the metadata file in. Defaults to | ||
the constant ``DATA_DIR``. Can be overridden for testing purposes. | ||
Returns | ||
------- | ||
list[dict] | ||
A list of dictionaries containing metadata for each sample file. | ||
""" | ||
|
||
local_file_path = Path(data_dir / file_name) | ||
failed_msg = "Failed to download the newest sample metadata file." | ||
|
||
# try downloading the newest metadata file | ||
try: | ||
downloaded_file_path = _download_metadata_file(file_name, data_dir) | ||
# if download succeeds, replace any existing local metadata file | ||
downloaded_file_path.replace(local_file_path) | ||
# if download fails, try loading an existing local metadata file, | ||
# otherwise raise an error | ||
except RequestException as exc_info: | ||
if local_file_path.is_file(): | ||
log_warning( | ||
f"{failed_msg} Will use the existing local version instead." | ||
) | ||
else: | ||
raise log_error(RequestException, failed_msg) from exc_info | ||
|
||
with open(local_file_path, "r") as metadata_file: | ||
metadata = yaml.safe_load(metadata_file) | ||
return metadata | ||
|
||
|
||
metadata = _fetch_metadata("poses_files_metadata.yaml") | ||
|
||
# Create a download manager for the pose data | ||
SAMPLE_DATA = pooch.create( | ||
path=DATA_DIR / "poses", | ||
base_url=f"{DATA_URL}/poses/", | ||
retry_if_failed=0, | ||
registry={file["file_name"]: file["sha256sum"] for file in metadata}, | ||
) | ||
|
||
|
||
def list_sample_data() -> list[str]: | ||
"""Find available sample pose data in the *movement* data repository. | ||
Returns | ||
------- | ||
filenames : list of str | ||
List of filenames for available pose data.""" | ||
return list(SAMPLE_DATA.registry.keys()) | ||
|
||
|
||
def fetch_sample_data_path(filename: str) -> Path: | ||
"""Download sample pose data from the *movement* data repository and return | ||
its local filepath. | ||
The data are downloaded to the user's local machine the first time they are | ||
used and are stored in a local cache directory. The function returns the | ||
path to the downloaded file, not the contents of the file itself. | ||
Parameters | ||
---------- | ||
filename : str | ||
Name of the file to fetch. | ||
Returns | ||
------- | ||
path : pathlib.Path | ||
Path to the downloaded file. | ||
""" | ||
try: | ||
return Path(SAMPLE_DATA.fetch(filename, progressbar=True)) | ||
except ValueError: | ||
raise log_error( | ||
ValueError, | ||
f"File '{filename}' is not in the registry. Valid " | ||
f"filenames are: {list_sample_data()}", | ||
) | ||
|
||
|
||
def fetch_sample_data( | ||
filename: str, | ||
) -> xarray.Dataset: | ||
"""Download and return sample pose data from the *movement* data | ||
repository. | ||
The data are downloaded to the user's local machine the first time they are | ||
used and are stored in a local cache directory. Returns sample pose data as | ||
an xarray Dataset. | ||
Parameters | ||
---------- | ||
filename : str | ||
Name of the file to fetch. | ||
Returns | ||
------- | ||
ds : xarray.Dataset | ||
Pose data contained in the fetched sample file. | ||
""" | ||
|
||
file_path = fetch_sample_data_path(filename) | ||
file_metadata = next( | ||
file for file in metadata if file["file_name"] == filename | ||
) | ||
|
||
if file_metadata["source_software"] == "SLEAP": | ||
ds = load_poses.from_sleap_file(file_path, fps=file_metadata["fps"]) | ||
elif file_metadata["source_software"] == "DeepLabCut": | ||
ds = load_poses.from_dlc_file(file_path, fps=file_metadata["fps"]) | ||
elif file_metadata["source_software"] == "LightningPose": | ||
ds = load_poses.from_lp_file(file_path, fps=file_metadata["fps"]) | ||
return ds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
"""Test suite for the sample_data module.""" | ||
|
||
from unittest.mock import MagicMock, patch | ||
|
||
import pooch | ||
import pytest | ||
from requests.exceptions import RequestException | ||
from xarray import Dataset | ||
|
||
from movement.sample_data import ( | ||
_fetch_metadata, | ||
fetch_sample_data, | ||
list_sample_data, | ||
) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def valid_file_names_with_fps(): | ||
"""Return a dict containing one valid file name and the corresponding fps | ||
for each supported pose estimation tool.""" | ||
return { | ||
"SLEAP_single-mouse_EPM.analysis.h5": 30, | ||
"DLC_single-wasp.predictions.h5": 40, | ||
"LP_mouse-face_AIND.predictions.csv": 60, | ||
} | ||
|
||
|
||
def validate_metadata(metadata: list[dict]) -> None: | ||
"""Assert that the metadata is in the expected format.""" | ||
metadata_fields = [ | ||
"file_name", | ||
"sha256sum", | ||
"source_software", | ||
"fps", | ||
"species", | ||
"number_of_individuals", | ||
"shared_by", | ||
"video_frame_file", | ||
"note", | ||
] | ||
check_yaml_msg = "Check the format of the metadata yaml file." | ||
assert isinstance( | ||
metadata, list | ||
), f"Expected metadata to be a list. {check_yaml_msg}" | ||
assert all( | ||
isinstance(file, dict) for file in metadata | ||
), f"Expected metadata entries to be dicts. {check_yaml_msg}" | ||
assert all( | ||
set(file.keys()) == set(metadata_fields) for file in metadata | ||
), f"Expected all metadata entries to have the same keys. {check_yaml_msg}" | ||
|
||
# check that filenames are unique | ||
file_names = [file["file_name"] for file in metadata] | ||
assert len(file_names) == len(set(file_names)) | ||
|
||
# check that the first 3 fields are present and are strings | ||
required_fields = metadata_fields[:3] | ||
assert all( | ||
(isinstance(file[field], str)) | ||
for file in metadata | ||
for field in required_fields | ||
) | ||
|
||
|
||
# Mock pooch.retrieve with RequestException as side_effect | ||
mock_retrieve = MagicMock(pooch.retrieve, side_effect=RequestException) | ||
|
||
|
||
@pytest.mark.parametrize("download_fails", [True, False]) | ||
@pytest.mark.parametrize("local_exists", [True, False]) | ||
def test_fetch_metadata(tmp_path, caplog, download_fails, local_exists): | ||
"""Test the fetch_metadata function with different combinations of | ||
failed download and pre-existing local file. The expected behavior is | ||
that the function will try to download the metadata file, and if that | ||
fails, it will try to load an existing local file. If neither succeeds, | ||
an error is raised.""" | ||
metadata_file_name = "poses_files_metadata.yaml" | ||
local_file_path = tmp_path / metadata_file_name | ||
|
||
with patch("movement.sample_data.DATA_DIR", tmp_path): | ||
# simulate the existence of a local metadata file | ||
if local_exists: | ||
local_file_path.touch() | ||
|
||
if download_fails: | ||
# simulate a failed download | ||
with patch("movement.sample_data.pooch.retrieve", mock_retrieve): | ||
if local_exists: | ||
_fetch_metadata(metadata_file_name) | ||
# check that a warning was logged | ||
assert ( | ||
"Will use the existing local version instead" | ||
in caplog.records[-1].getMessage() | ||
) | ||
else: | ||
with pytest.raises( | ||
RequestException, match="Failed to download" | ||
): | ||
_fetch_metadata(metadata_file_name, data_dir=tmp_path) | ||
else: | ||
metadata = _fetch_metadata(metadata_file_name, data_dir=tmp_path) | ||
assert local_file_path.is_file() | ||
validate_metadata(metadata) | ||
|
||
|
||
def test_list_sample_data(valid_file_names_with_fps): | ||
assert isinstance(list_sample_data(), list) | ||
assert all( | ||
file in list_sample_data() for file in valid_file_names_with_fps | ||
) | ||
|
||
|
||
def test_fetch_sample_data(valid_file_names_with_fps): | ||
# test with valid files | ||
for file, fps in valid_file_names_with_fps.items(): | ||
ds = fetch_sample_data(file) | ||
assert isinstance(ds, Dataset) and ds.fps == fps | ||
|
||
# Test with an invalid file | ||
with pytest.raises(ValueError): | ||
fetch_sample_data("nonexistent_file") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters