diff --git a/.gitignore b/.gitignore index 8046550a..085834ba 100644 --- a/.gitignore +++ b/.gitignore @@ -57,7 +57,9 @@ local_settings.py instance/ # Sphinx documentation -docs/_build/ +docs/build/ +docs/source/auto_examples/ +docs/source/auto_api/ # MkDocs documentation /site/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9f9d06b8..3d5616c8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,10 @@ repos: hooks: - id: mypy additional_dependencies: + - attrs - types-setuptools + - pandas-stubs + - types-attrs - repo: https://github.com/mgedmin/check-manifest rev: "0.49" hooks: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..50581c55 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,274 @@ +# How to Contribute + +**Contributors to movement are absolutely encouraged**, whether to fix a bug, +develop a new feature, or improve the documentation. +If you're unsure about any part of the contributing process, please get in touch. +It's best to reach out in public, e.g. by [opening an issue](https://github.com/neuroinformatics-unit/movement/issues) +so that others can benefit from the discussion. + +## Contributing code + +### Creating a development environment + +It is recommended to use [conda](https://docs.conda.io/en/latest/) +or [mamba](https://mamba.readthedocs.io/en/latest/index.html) to create a +development environment for movement. In the following we assume you have +`conda` installed, but the same commands will also work with `mamba`/`micromamba`. + +First, create and activate a `conda` environment with some pre-requisites: + +```sh +conda create -n movement-dev -c conda-forge python=3.10 pytables +conda activate movement-dev +``` + +The above method ensures that you will get packages that often can't be +installed via `pip`, including [hdf5](https://www.hdfgroup.org/solutions/hdf5/). + +To install movement for development, clone the GitHub repository, +and then run from inside the repository: + +```sh +pip install -e .[dev] # works on most shells +pip install -e '.[dev]' # works on zsh (the default shell on macOS) +``` + +This will install the package in editable mode, including all dependencies +required for development. + +Finally, initialise the [pre-commit hooks](#formatting-and-pre-commit-hooks): + +```bash +pre-commit install +``` + +### Pull requests + +In all cases, please submit code to the main repository via a pull request (PR). +We recommend, and adhere, to the following conventions: + +- Please submit _draft_ PRs as early as possible to allow for discussion. +- The PR title should be descriptive e.g. "Add new function to do X" or "Fix bug in Y". +- The PR description should be used to provide context and motivation for the changes. +- One approval of a PR (by a repo owner) is enough for it to be merged. +- Unless someone approves the PR with optional comments, the PR is immediately merged by the approving reviewer. +- Ask for a review from someone specific if you think they would be a particularly suited reviewer. +- PRs are preferably merged via the ["squash and merge"](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-commits) option, to keep a clean commit history on the _main_ branch. + +A typical PR workflow would be: +* Create a new branch, make your changes, and stage them. 
+* When you try to commit, the [pre-commit hooks](#formatting-and-pre-commit-hooks) will be triggered. +* Stage any changes made by the hooks, and commit. +* You may also run the pre-commit hooks manually, at any time, with `pre-commit run -a`. +* Make sure to write tests for any new features or bug fixes. See [testing](#testing) below. +* Don't forget to update the documentation, if necessary. See [contributing documentation](#contributing-documentation) below. +* Push your changes to GitHub and open a draft pull request, with a meaningful title and a thorough description of the changes. +* If all checks (e.g. linting, type checking, testing) run successfully, you may mark the pull request as ready for review. +* Respond to review comments and implement any requested changes. +* Success 🎉 !! Your PR will be (squash-)merged into the _main_ branch. + +## Development guidelines + +### Formatting and pre-commit hooks + +Running `pre-commit install` will set up [pre-commit hooks](https://pre-commit.com/) to ensure a consistent formatting style. Currently, these include: +* [ruff](https://github.com/charliermarsh/ruff) does a number of jobs, including enforcing PEP8 and sorting imports +* [black](https://black.readthedocs.io/en/stable/) for auto-formatting +* [mypy](https://mypy.readthedocs.io/en/stable/index.html) as a static type checker +* [check-manifest](https://github.com/mgedmin/check-manifest) to ensure that the right files are included in the pip package. + +These will prevent code from being committed if any of these hooks fail. To run them individually (from the root of the repository), you can use: + +```sh +ruff . +black ./ +mypy -p movement +check-manifest +``` + +To run all the hooks before committing: + +```sh +pre-commit run # for staged files +pre-commit run -a # for all files in the repository +``` + +For docstrings, we adhere to the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) style. + +### Testing + +We use [pytest](https://docs.pytest.org/en/latest/) for testing and aim for +~100% test coverage (as far as is reasonable). +All new features should be tested. +Write your test methods and classes in the _tests_ folder. + +For some tests, you will need to use real experimental data. +Do not include these data in the repository, especially if they are large. +We store several sample datasets in an external data repository. +See [sample data](#sample-data) for more information. + + +### Continuous integration +All pushes and pull requests will be built by [GitHub actions](https://docs.github.com/en/actions). +This will usually include linting, testing and deployment. + +A GitHub actions workflow (`.github/workflows/test_and_deploy.yml`) has been set up to run (on each push/PR): +* Linting checks (pre-commit). +* Testing (only if linting checks pass) +* Release to PyPI (only if a git tag is present and if tests pass). + +### Versioning and releases +We use [semantic versioning](https://semver.org/), which includes `MAJOR`.`MINOR`.`PATCH` version numbers: + +* PATCH = small bugfix +* MINOR = new feature +* MAJOR = breaking change + +We use [setuptools_scm](https://github.com/pypa/setuptools_scm) to automatically version movement. +It has been pre-configured in the `pyproject.toml` file. +`setuptools_scm` will automatically [infer the version using git](https://github.com/pypa/setuptools_scm#default-versioning-scheme). +To manually set a new semantic version, create a tag and make sure the tag is pushed to GitHub. 
+Make sure you commit any changes you wish to be included in this version. E.g. to bump the version to `1.0.0`: + +```sh +git add . +git commit -m "Add new changes" +git tag -a v1.0.0 -m "Bump to version 1.0.0" +git push --follow-tags +``` +Alternatively, you can use the GitHub web interface to create a new release and tag. + +The addition of a GitHub tag triggers the package's deployment to PyPI. +The version number is automatically determined from the latest tag on the _main_ branch. + +## Contributing documentation + +The documentation is hosted via [GitHub pages](https://pages.github.com/) at +[neuroinformatics-unit.github.io/movement](https://neuroinformatics-unit.github.io/movement/). +Its source files are located in the `docs` folder of this repository. +They are written in either [reStructuredText](https://docutils.sourceforge.io/rst.html) or +[markdown](https://myst-parser.readthedocs.io/en/stable/syntax/typography.html). +The `index.md` file corresponds to the homepage of the documentation website. +Other `.rst` or `.md` files are linked to the homepage via the `toctree` directive. + +We use [Sphinx](https://www.sphinx-doc.org/en/master/) and the +[PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html) +to build the source files into HTML output. +This is handled by a GitHub actions workflow (`.github/workflows/docs_build_and_deploy.yml`). +The build job is triggered on each PR, ensuring that the documentation build is not broken by new changes. +The deployment job is only triggered whenever a tag is pushed to the _main_ branch, +ensuring that the documentation is published in sync with each PyPI release. + +### Editing the documentation + +To edit the documentation, first clone the repository, and install movement in a +[development environment](#creating-a-development-environment). + +Now create a new branch, edit the documentation source files (`.md` or `.rst` in the `docs` folder), +and commit your changes. Submit your documentation changes via a pull request, +following the [same guidelines as for code changes](#pull-requests). +Make sure that the header levels in your `.md` or `.rst` files are incremented +consistently (H1 > H2 > H3, etc.) without skipping any levels. + +If you create a new documentation source file (e.g. `my_new_file.md` or `my_new_file.rst`), +you will need to add it to the `toctree` directive in `index.md` +for it to be included in the documentation website: + +```rst +:maxdepth: 2 +:hidden: + +existing_file +my_new_file +``` + +### Updating the API reference +If your PR introduces new public-facing functions, classes, or methods, +make sure to add them to the `docs/source/api_index.rst` page, so that they are +included in the [API reference](https://neuroinformatics-unit.github.io/movement/api_index.html), +e.g.: + +```rst +My new module +-------------- +.. currentmodule:: movement.new_module +.. autosummary:: + :toctree: auto_api + + new_function + NewClass +``` + +For this to work, your functions/classes/methods will need to have docstrings +that follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) style. + +### Updating the examples +We use [sphinx-gallery](https://sphinx-gallery.github.io/stable/index.html) +to create the [examples](https://neuroinformatics-unit.github.io/movement/auto_examples/index.html). +To add new examples, you will need to create a new `.py` file in `examples/`.
+The file should be structured as specified in the relevant +[sphinx-gallery documentation](https://sphinx-gallery.github.io/stable/syntax.html). + + +### Building the documentation locally +We recommend that you build and view the documentation website locally, before you push it. +To do so, first install the requirements for building the documentation: +```sh +pip install -r docs/requirements.txt +``` + +Then, from the root of the repository, run: +```sh +sphinx-build docs/source docs/build +``` + +You can view the local build by opening `docs/build/index.html` in a browser. +To refresh the documentation, after making changes, remove the `docs/build` folder and re-run the above command: + +```sh +rm -rf docs/build && sphinx-build docs/source docs/build +``` + +## Sample data + +We maintain some sample data to be used for testing, examples and tutorials on an +[external data repository](https://gin.g-node.org/neuroinformatics/movement-test-data). +Our hosting platform of choice is called [GIN](https://gin.g-node.org/) and is maintained +by the [German Neuroinformatics Node](https://www.g-node.org/). +GIN has a GitHub-like interface and git-like +[CLI](https://gin.g-node.org/G-Node/Info/wiki/GIN+CLI+Setup#quickstart) functionalities. + +Currently, the data repository contains sample pose estimation data files +stored in the `poses` folder. Each file name starts with either "DLC" or "SLEAP", +depending on the pose estimation software used to generate the data. + +### Fetching data +To fetch the data from GIN, we use the [pooch](https://www.fatiando.org/pooch/latest/index.html) +Python package, which can download data from pre-specified URLs and store them +locally for all subsequent uses. It also provides some nice utilities, +like verification of sha256 hashes and decompression of archives. + +The relevant functionality is implemented in the `movement.datasets.py` module. +The most important parts of this module are: + +1. The `POSE_DATA` download manager object, which contains a list of stored files and their known hashes. +2. The `list_pose_data()` function, which returns a list of the available files in the data repository. +3. The `fetch_pose_data_path()` function, which downloads a file (if not already cached locally) and returns the local path to it. + +By default, the downloaded files are stored in the `~/.movement/data` folder. +This can be changed by setting the `DATA_DIR` variable in the `movement.datasets.py` module. + +### Adding new data +Only core movement developers may add new files to the external data repository. +To add a new file, you will need to: + +1. Create a [GIN](https://gin.g-node.org/) account +2. Ask to be added as a collaborator on the [movement data repository](https://gin.g-node.org/neuroinformatics/movement-test-data) (if not already) +3. Download the [GIN CLI](https://gin.g-node.org/G-Node/Info/wiki/GIN+CLI+Setup#quickstart) and set it up with your GIN credentials, by running `gin login` in a terminal. +4. Clone the movement data repository to your local machine, by running `gin get neuroinformatics/movement-test-data` in a terminal. +5. Add your new files and commit them with `gin commit -m `. +6. Upload the committed changes to the GIN repository, by running `gin upload`. Latest changes to the repository can be pulled via `gin download`. `gin sync` will synchronise the latest changes bidirectionally. +7. Determine the sha256 checksum hash of each new file, by running `sha256sum ` in a terminal.
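+If `sha256sum` is not available on your system, you can compute the same hash with a short Python snippet. This is a minimal sketch using only the standard library (the file path is a placeholder):
+
+```python
+import hashlib
+
+# Read the file in chunks so that large sample files do not need to fit in memory
+digest = hashlib.sha256()
+with open("/path/to/file", "rb") as f:
+    for chunk in iter(lambda: f.read(8192), b""):
+        digest.update(chunk)
+print(digest.hexdigest())
+```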
Alternatively, you can use `pooch` to do this for you: `python -c "import pooch; pooch.file_hash('/path/to/file')"`. If you wish to generate a text file containing the hashes of all the files in a given folder, you can use `python -c "import pooch; pooch.make_registry('/path/to/folder', 'sha256_registry.txt')`. +8. Update the `movement.datasets.py` module on the [movement GitHub repository](https://github.com/SainsburyWellcomeCentre/movement) by adding the new files to the `POSE_DATA` registry. Make sure to include the correct sha256 hash, as determined in the previous step. Follow all the usual [guidelines for contributing code](#contributing-code). Make sure to test whether the new files can be fetched successfully (see [fetching data](#fetching-data) above) before submitting your pull request. + +You can also perform steps 3-6 via the GIN web interface, if you prefer to avoid using the CLI. diff --git a/MANIFEST.in b/MANIFEST.in index 77097270..ff091745 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,9 +1,10 @@ include LICENSE -include README.md +include *.md exclude .pre-commit-config.yaml exclude .cruft.json recursive-exclude * __pycache__ recursive-exclude * *.py[co] recursive-exclude docs * +recursive-exclude examples * recursive-exclude tests * diff --git a/README.md b/README.md index 3d7905c1..ec7256c0 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,18 @@ Kinematic analysis of animal 🐝 🦀 🐀 🐒 body movements for neuroscience and ethology research 🔬. +- Read the [documentation](https://neuroinformatics-unit.github.io/movement/) for more information. +- If you wish to contribute, please read the [contributing guide](./CONTRIBUTING.md). + ## Status -The package is currently in early development 🏗️ and is not yet ready for use. Stay tuned ⌛ +> **Warning** +> - 🏗️ The package is currently in early development. Stay tuned ⌛ +> - It is not sufficiently tested to be used for scientific analysis +> - The interface is subject to changes. [Open an issue](https://github.com/neuroinformatics-unit/movement/issues) if you have suggestions. ## Aims -* Load keypoint tracks from pose estimation software (e.g. [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/)) -* Evaluate the quality of the tracks and perform data cleaning +* Load pose tracks from pose estimation software packages (e.g. [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/)) +* Evaluate the quality of the tracks and perform data cleaning operations * Calculate kinematic variables (e.g. speed, acceleration, joint angles, etc.) * Produce reports and visualise the results @@ -25,57 +31,8 @@ The following projects cover related needs and served as inspiration for this pr * [Kino](https://github.com/BrancoLab/Kino) * [WAZP](https://github.com/SainsburyWellcomeCentre/WAZP) -## How to contribute -### Setup -* We recommend you install `movement` inside a [conda](https://docs.conda.io/en/latest/) environment. -Assuming you have `conda` installed, the following will create and activate an environment containing Python 3 as well as the required `pytables` library. You can call your environment whatever you like, we've used `movement-env`. 
- - ```sh - conda create -n movement-env -c conda-forge python=3.11 pytables - conda activate movement-env - ``` - -* Next clone the repository and install the package in editable mode (including all `dev` dependencies): - - ```bash - git clone https://github.com/neuroinformatics-unit/movement - cd movement - pip install -e '.[dev]' - ``` -* Initialize the pre-commit hooks: - - ```bash - pre-commit install - ``` - -### Workflow -* Create a new branch, make your changes, and stage them. -* When you try to commit, the pre-commit hooks will be triggered. These include linting with [`ruff`](https://github.com/charliermarsh/ruff) and auto-formatting with [`black`](https://github.com/psf/black). Stage any changes made by the hooks, and commit. You may also run the pre-commit hooks manually, at any time, with `pre-commit run --all-files`. -* Push your changes to GitHub and open a draft pull request. -* If all checks (e.g. linting, type checking, testing) run successfully, you may mark the pull request as ready for review. -* For debugging purposes, you may also want to run the tests and the type checks locally, before pushing. This can be done with the following commands: - ```bash - cd movement - pytest - mypy -p movement - ``` -* When your pull request is approved, squash-merge it into the `main` branch and delete the feature branch. - -### Versioning and deployment -The package is deployed to PyPI automatically when a new release is created on GitHub. We use [semantic versioning](https://semver.org/), with `MAJOR`.`MINOR`.`PATCH` version numbers. - -We use [`setuptools_scm`](https://github.com/pypa/setuptools_scm), which automatically [infers the version using git](https://github.com/pypa/setuptools_scm#default-versioning-scheme). To manually set a new semantic version, create an appropriate tag and push it to GitHub. Make sure to commit any changes you wish to be included in this version. E.g. to bump the version to `1.0.0`: - -```bash -git add . -git commit -m "Add new changes" -git tag -a v1.0.0 -m "Bump to version 1.0.0" -git push --follow-tags -``` - ## License - ⚖️ [BSD 3-Clause](./LICENSE) ## Template -This package layout and configuration (including pre-commit hooks and GitHub actions) have been copied from the [python-cookiecutter](https://github.com/SainsburyWellcomeCentre/python-cookiecutter) template. +This package layout and configuration (including pre-commit hooks and GitHub actions) have been copied from the [python-cookiecutter](https://github.com/neuroinformatics-unit/python-cookiecutter) template. diff --git a/docs/requirements.txt b/docs/requirements.txt index b5a754c8..cb6d2bac 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,11 @@ +-e . linkify-it-py +matplotlib myst-parser nbsphinx pydata-sphinx-theme setuptools-scm -sphinx +sphinx<7.2 sphinx-autodoc-typehints +sphinx-design +sphinx-gallery diff --git a/docs/source/api_index.rst b/docs/source/api_index.rst new file mode 100644 index 00000000..32b9a7ce --- /dev/null +++ b/docs/source/api_index.rst @@ -0,0 +1,48 @@ +API Reference +============= + + +Input/Output +------------ +.. currentmodule:: movement.io.load_poses +.. autosummary:: + :toctree: auto_api + + from_sleap_file + from_dlc_file + from_dlc_df + +.. currentmodule:: movement.io.save_poses +.. autosummary:: + :toctree: auto_api + + to_dlc_file + to_dlc_df + +.. currentmodule:: movement.io.validators +.. autosummary:: + :toctree: auto_api + + ValidFile + ValidHDF5 + ValidPosesCSV + ValidPoseTracks + +Datasets +-------- +.. 
currentmodule:: movement.datasets +.. autosummary:: + :toctree: auto_api + + list_pose_data + fetch_pose_data_path + +Logging +------- +.. currentmodule:: movement.logging +.. autosummary:: + :toctree: auto_api + + configure_logging + log_error + log_warning diff --git a/docs/source/conf.py b/docs/source/conf.py index 9b7efb35..31c53e8a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -25,6 +25,7 @@ author = "Niko Sirmpilatze" try: release = setuptools_scm.get_version(root="../..", relative_to=__file__) + release = release.split("+")[0] # remove git hash except LookupError: # if git is not initialised, still allow local build # with a dummy version @@ -43,6 +44,8 @@ "sphinx.ext.intersphinx", "myst_parser", "nbsphinx", + "sphinx_design", + "sphinx_gallery.gen_gallery", ] # Configure the myst parser to enable cool markdown features @@ -80,8 +83,18 @@ # to ensure that include files (partial pages) aren't built, exclude them # https://github.com/sphinx-doc/sphinx/issues/1965#issuecomment-124732907 "**/includes/**", + # exclude .py and .ipynb files in auto_examples generated by sphinx-gallery + # this is to prevent sphinx from complaining about duplicate source files + "auto_examples/*.ipynb", + "auto_examples/*.py", ] +# Configure Sphinx gallery +sphinx_gallery_conf = { + "examples_dirs": ["../../examples"], + "filename_pattern": "/*.py", # which files to execute before inclusion +} + # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = "pydata_sphinx_theme" @@ -112,7 +125,7 @@ # The default is the URL of the GitHub pages # https://www.sphinx-doc.org/en/master/usage/extensions/githubpages.html github_user = "neuroinformatics-unit" -html_baseurl = f"https://neuroinformatics-unit.github.io/movement/" +html_baseurl = "https://neuroinformatics-unit.github.io/movement/" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst new file mode 100644 index 00000000..498202d4 --- /dev/null +++ b/docs/source/contributing.rst @@ -0,0 +1,13 @@ +.. include:: ../../CONTRIBUTING.md + :parser: myst_parser.sphinx_ + :end-before: **Contributors + +.. important:: + .. include:: ../../CONTRIBUTING.md + :parser: myst_parser.sphinx_ + :start-after: How to Contribute + :end-before: ## Contributing code + +.. include:: ../../CONTRIBUTING.md + :parser: myst_parser.sphinx_ + :start-after: from the discussion. diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md index 169a31d6..29bb21c0 100644 --- a/docs/source/getting_started.md +++ b/docs/source/getting_started.md @@ -1,11 +1,224 @@ -# Getting started +# Getting Started -Here you may demonstrate the basic functionalities your package. +## Installation -You can include code snippets using the usual Markdown syntax: +We recommend you install movement inside a [conda](https://docs.conda.io/en/latest/) +or [mamba](https://mamba.readthedocs.io/en/latest/index.html) environment. +In the following we assume you have `conda` installed, +but the same commands will also work with `mamba`/`micromamba`. + + +First, create and activate an environment. +You can call your environment whatever you like, we've used "movement-env". 
+ +```sh +conda create -n movement-env -c conda-forge python=3.10 pytables +conda activate movement-env +``` + +Next install the `movement` package: + +::::{tab-set} + +:::{tab-item} Users +To get the latest release from PyPI: + +```sh +pip install movement +``` +If you have an older version of `movement` installed in the same environment, +you can update to the latest version with: + +```sh +pip install --upgrade movement +``` +::: + +:::{tab-item} Developers +To get the latest development version, clone the +[GitHub repository](https://neuroinformatics-unit.github.io/movement/) +and then run from inside the repository: + +```sh +pip install -e .[dev] # works on most shells +pip install -e '.[dev]' # works on zsh (the default shell on macOS) +``` + +This will install the package in editable mode, including all `dev` dependencies. +Please see the [contributing guide](./contributing.rst) for more information. +::: + +:::: + + +## Loading data +You can load predicted pose tracks from the pose estimation software packages +[DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/). + +First import the `movement.io.load_poses` module: ```python from movement.io import load_poses +``` + +Then, use the `from_dlc_file` or `from_sleap_file` functions to load the data. + +::::{tab-set} + +:::{tab-item} SLEAP + +Load from [SLEAP analysis files](https://sleap.ai/tutorials/analysis.html) (`.h5`): +```python +ds = load_poses.from_sleap_file("/path/to/file.analysis.h5", fps=30) +``` +::: + +:::{tab-item} DeepLabCut + +Load pose estimation outputs from `.h5` files: +```python +ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30) +``` + +You may also load `.csv` files (assuming they are formatted as DeepLabCut expects them): +```python +ds = load_poses.from_dlc_file("/path/to/file.csv", fps=30) +``` + +If you have already imported the data into a pandas DataFrame, you can +convert it to a movement dataset with: +```python +import pandas as pd + +df = pd.read_hdf("/path/to/file.h5") +ds = load_poses.from_dlc_df(df, fps=30) +``` +::: + +:::: + +You can also try movement out on some sample data included in the package. + +:::{dropdown} Fetching sample data +:color: primary +:icon: unlock + +You can view the available sample data files with: -df = load_poses.from_dlc('path/to/file.h5') +```python +from movement import datasets + +file_names = datasets.list_pose_data() +print(file_names) +``` +This will print a list of file names containing sample pose data. +The files are prefixed with the name of the pose estimation software package, +either "DLC" or "SLEAP". + +To get the path to one of the sample files, +you can use the `fetch_pose_data_path` function: + +```python +file_path = datasets.fetch_pose_data_path("DLC_two-mice.predictions.csv") +``` +The first time you call this function, it will download the corresponding file +to your local machine and save it in the `~/.movement/data` directory. On +subsequent calls, it will simply return the path to that local file. + +You can feed the path to the `from_dlc_file` or `from_sleap_file` functions +and load the data, as shown above. +::: + +## Working with movement datasets + +Loaded pose estimation data are represented in movement as +[`xarray.Dataset`](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html) objects. 
+ +You can view information about the loaded dataset by printing it: +```python +ds = load_poses.from_dlc_file("/path/to/file.h5", fps=30) +print(ds) +``` +If you are working in a Jupyter notebook, you can also view an interactive +representation of the dataset by simply typing its name - e.g. `ds` - in a cell. + +### Dataset structure + +The movement `xarray.Dataset` has the following dimensions: +- `time`: the number of frames in the video +- `individuals`: the number of individuals in the video +- `keypoints`: the number of keypoints in the skeleton +- `space`: the number of spatial dimensions, either 2 or 3 + +Appropriate coordinate labels are assigned to each dimension: +list of unique names (str) for `individuals` and `keypoints`, +['x','y',('z')] for `space`. The coordinates of the `time` dimension are +in seconds if `fps` is provided, otherwise they are in frame numbers. + +The dataset contains two data variables stored as +[`xarray.DataArray`](https://docs.xarray.dev/en/latest/generated/xarray.DataArray.html#xarray.DataArray) objects: +- `pose_tracks`: with shape (`time`, `individuals`, `keypoints`, `space`) +- `confidence`: with shape (`time`, `individuals`, `keypoints`) + +You can think of a `DataArray` as a `numpy.ndarray` with `pandas`-style +indexing and labelling. To learn more about `xarray` data structures, see the +relevant [documentation](https://docs.xarray.dev/en/latest/user-guide/data-structures.html). + +The dataset may also contain the following attributes as metadata: +- `fps`: the number of frames per second in the video +- `time_unit`: the unit of the `time` coordinates, frames or seconds +- `source_software`: the software from which the pose tracks were loaded +- `source_file`: the file from which the pose tracks were loaded + +### Indexing and selection +You can access the data variables and attributes of the dataset as follows: +```python +pose_tracks = ds.pose_tracks # ds["pose_tracks"] also works +confidence = ds.confidence + +fps = ds.fps # ds.attrs["fps"] also works +``` + +You can select subsets of the data using the `sel` method: +```python +# select the first 100 seconds of data +ds_sel = ds.sel(time=slice(0, 100)) + +# select specific individuals or keypoints +ds_sel = ds.sel(individuals=["individual1", "individual2"]) +ds_sel = ds.sel(keypoints="snout") + +# combine selections +ds_sel = ds.sel(time=slice(0, 100), individuals=["individual1", "individual2"], keypoints="snout") +``` +All of the above selections can also be applied to the data variables, +resulting in a `DataArray` rather than a `Dataset`: + +```python +pose_tracks = ds.pose_tracks.sel(individuals="individual1", keypoints="snout") +``` +You may also use all the other powerful [indexing and selection](https://docs.xarray.dev/en/latest/user-guide/indexing.html) methods provided by `xarray`. + +### Plotting + +You can also use the built-in [`xarray` plotting methods](https://docs.xarray.dev/en/latest/user-guide/plotting.html) +to visualise the data. Check out the [Load and explore pose tracks](./auto_examples/load_and_explore_poses.rst) +example for inspiration. + +## Saving data +You can save movement datasets to disk in a variety of formats. +Currently, only saving to DeepLabCut-style files is supported. 
+ +```python +from movement.io import save_poses + +save_poses.to_dlc_file(ds, "/path/to/file.h5") # preferred +save_poses.to_dlc_file(ds, "/path/to/file.csv") +``` + +Instead of saving to file directly, you can also convert the dataset to a +DeepLabCut-style `pandas.DataFrame` first: +```python +df = save_poses.to_dlc_df(ds) ``` +and then save it to file using any `pandas` method, e.g. `to_hdf` or `to_csv`. diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 00000000..6399c5ed --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,60 @@ +# movement + +Kinematic analysis of animal 🐝 🦀 🐀 🐒 body movements for neuroscience and ethology research. + +::::{grid} 1 2 2 3 +:gutter: 3 + +:::{grid-item-card} {fas}`rocket;sd-text-primary` Getting Started +:link: getting_started +:link-type: doc + +Install and try it out. +::: + +:::{grid-item-card} {fas}`chalkboard-user;sd-text-primary` Examples +:link: auto_examples/index +:link-type: doc + +Example use cases. +::: + +:::{grid-item-card} {fas}`code;sd-text-primary` API Reference +:link: api_index +:link-type: doc + +Index of all functions, classes, and methods. +::: +:::: + +## Status +:::{warning} +- 🏗️ The package is currently in early development. Stay tuned ⌛ +- It is not sufficiently tested to be used for scientific analysis +- The interface is subject to changes. [Open an issue](https://github.com/neuroinformatics-unit/movement/issues) if you have suggestions. +::: + + +## Aims +* Load pose tracks from pose estimation software packages (e.g. [DeepLabCut](http://www.mackenziemathislab.org/deeplabcut) or [SLEAP](https://sleap.ai/)) +* Evaluate the quality of the tracks and perform data cleaning operations +* Calculate kinematic variables (e.g. speed, acceleration, joint angles, etc.) +* Produce reports and visualise the results + +## Related projects +The following projects cover related needs and served as inspiration for this project: +* [DLC2Kinematics](https://github.com/AdaptiveMotorControlLab/DLC2Kinematics) +* [PyRat](https://github.com/pyratlib/pyrat) +* [Kino](https://github.com/BrancoLab/Kino) +* [WAZP](https://github.com/SainsburyWellcomeCentre/WAZP) + + +```{toctree} +:maxdepth: 2 +:hidden: + +getting_started +auto_examples/index +api_index +contributing +``` diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index f9776ca5..00000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,13 +0,0 @@ -Welcome to movement's documentation! -========================================================= - -.. toctree:: - :maxdepth: 2 - :caption: Contents: - - getting_started - -Index & Search --------------- -* :ref:`genindex` -* :ref:`search` diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 00000000..311ded86 --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,4 @@ +Examples +======== + +Below is a gallery of examples using `movement`. diff --git a/examples/load_and_explore_poses.py b/examples/load_and_explore_poses.py new file mode 100644 index 00000000..8562b532 --- /dev/null +++ b/examples/load_and_explore_poses.py @@ -0,0 +1,86 @@ +""" +Load and explore pose tracks +============================ + +Load and explore an example dataset of pose tracks. 
+""" + +# %% +# Imports +# ------- +from matplotlib import pyplot as plt + +from movement import datasets +from movement.io import load_poses + +# %% +# Fetch an example dataset +# ------------------------ +# Print a list of available datasets: + +for file_name in datasets.list_pose_data(): + print(file_name) + +# %% +# Fetch the path to an example dataset. +# Feel free to replace this with the path to your own dataset. +# e.g., ``file_path = "/path/to/my/data.h5"``) +file_path = datasets.fetch_pose_data_path( + "SLEAP_three-mice_Aeon_proofread.analysis.h5" +) + +# %% +# Load the dataset +# ---------------- + +ds = load_poses.from_sleap_file(file_path, fps=50) +print(ds) + +# %% +# The loaded dataset contains two data variables: +# ``pose_tracks`` and ``confidence``` +# To get the pose tracks: +pose_tracks = ds.pose_tracks + +# %% +# Select and plot data with xarray +# -------------------------------- +# You can use the ``sel`` method to index into ``xarray`` objects. +# For example, we can get a ``DataArray`` containing only data +# for a single keypoint of the first individual: + +da = pose_tracks.sel(individuals="AEON3B_NTP", keypoints="centroid") +print(da) + +# %% +# We could plot the x, y coordinates of this keypoint over time, +# using ``xarray``'s built-in plotting methods: +da.plot.line(x="time", row="space", aspect=2, size=2.5) + +# %% +# Similarly we could plot the same keypoint's x, y coordinates +# for all individuals: + +da = pose_tracks.sel(keypoints="centroid") +da.plot.line(x="time", row="individuals", aspect=2, size=2.5) + +# %% +# Trajectory plots +# ---------------- +# We are not limited to ``xarray``'s built-in plots. +# For example, we can use ``matplotlib`` to plot trajectories +# (using scatter plots): + +mouse_name = "AEON3B_TP1" + +plt.scatter( + da.sel(individuals=mouse_name, space="x"), + da.sel(individuals=mouse_name, space="y"), + s=2, + c=da.time, + cmap="viridis", +) +plt.title(f"Trajectory of {mouse_name}") +plt.xlabel("x") +plt.ylabel("y") +plt.colorbar(label="time (sec)") diff --git a/movement/__init__.py b/movement/__init__.py index 4e64c32d..9ac02704 100644 --- a/movement/__init__.py +++ b/movement/__init__.py @@ -1,5 +1,5 @@ from importlib.metadata import PackageNotFoundError, version -from movement.log_config import configure_logging +from movement.logging import configure_logging try: __version__ = version("movement") @@ -7,6 +7,10 @@ # package is not installed pass +# set xarray global options +import xarray as xr + +xr.set_options(keep_attrs=True, display_expand_data=False) # initialize logger upon import configure_logging() diff --git a/movement/datasets.py b/movement/datasets.py index ae0b5dc8..30760125 100644 --- a/movement/datasets.py +++ b/movement/datasets.py @@ -6,6 +6,7 @@ """ from pathlib import Path +from typing import List import pooch @@ -40,8 +41,18 @@ ) +def list_pose_data() -> List[str]: + """Find available sample pose data in the *movement* data repository. + + Returns + ------- + filenames : list of str + List of filenames for available pose data.""" + return list(POSE_DATA.registry.keys()) + + def fetch_pose_data_path(filename: str) -> Path: - """Fetch sample pose data from the remote repository. + """Fetch sample pose data from the *movement* data repository. The data are downloaded to the user's local machine the first time they are used and are stored in a local cache directory. 
The function returns the diff --git a/movement/io/__init__.py b/movement/io/__init__.py index e69de29b..c95035f8 100644 --- a/movement/io/__init__.py +++ b/movement/io/__init__.py @@ -0,0 +1 @@ +from .poses_accessor import PosesAccessor diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index e7b92f6f..2ed8d4de 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -2,75 +2,309 @@ from pathlib import Path from typing import Optional, Union +import h5py +import numpy as np import pandas as pd +import xarray as xr +from sleap_io.io.slp import read_labels -from movement.io.validators import DeepLabCutPosesFile +from movement.io.poses_accessor import PosesAccessor +from movement.io.validators import ( + ValidFile, + ValidHDF5, + ValidPosesCSV, + ValidPoseTracks, +) +from movement.logging import log_error -# get logger logger = logging.getLogger(__name__) -def from_dlc(file_path: Union[Path, str]) -> Optional[pd.DataFrame]: - """Load pose estimation results from a DeepLabCut (DLC) files. - Files must be in .h5 format or .csv format. +def from_dlc_df(df: pd.DataFrame, fps: Optional[float] = None) -> xr.Dataset: + """Create an xarray.Dataset from a DeepLabCut-style pandas DataFrame. Parameters ---------- - file_path : pathlib Path or str - Path to the file containing the DLC poses. + df : pandas.DataFrame + DataFrame containing the pose tracks and confidence scores. Must + be formatted as in DeepLabCut output files (see Notes). + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. Returns ------- - pandas DataFrame - DataFrame containing the DLC poses + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. + + Notes + ----- + The DataFrame must have a multi-index column with the following levels: + "scorer", ("individuals"), "bodyparts", "coords". The "individuals" + level may be omitted if there is only one individual in the video. + The "coords" level contains the spatial coordinates "x", "y", + as well as "likelihood" (point-wise confidence scores). + The row index corresponds to the frame number. + + See Also + -------- + movement.io.load_poses.from_dlc_file : Load pose tracks directly from file. + """ + + # read names of individuals and keypoints from the DataFrame + if "individuals" in df.columns.names: + individual_names = ( + df.columns.get_level_values("individuals").unique().to_list() + ) + else: + individual_names = ["individual_0"] + + keypoint_names = ( + df.columns.get_level_values("bodyparts").unique().to_list() + ) + + # reshape the data into (n_frames, n_individuals, n_keypoints, 3) + # where the last axis contains "x", "y", "likelihood" + tracks_with_scores = df.to_numpy().reshape( + (-1, len(individual_names), len(keypoint_names), 3) + ) + + valid_data = ValidPoseTracks( + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], + individual_names=individual_names, + keypoint_names=keypoint_names, + fps=fps, + ) + return _from_valid_data(valid_data) + + +def from_sleap_file( + file_path: Union[Path, str], fps: Optional[float] = None +) -> xr.Dataset: + """Load pose tracking data from a SLEAP file into an xarray Dataset. + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the file containing the SLEAP predictions in ".h5" + (analysis) format. Alternatively, an ".slp" (labels) file can + also be supplied (but this feature is experimental, see Notes). 
+ fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. + + Notes + ----- + The SLEAP predictions are normally saved in ".slp" files, e.g. + "v1.predictions.slp". An analysis file, suffixed with ".h5", can be exported + from the ".slp" file, using either the command line tool `sleap-convert` + (with the "--format analysis" option enabled) or the SLEAP GUI (Choose + "Export Analysis HDF5…" from the "File" menu) [1]_. This is the + preferred format for loading pose tracks from SLEAP into *movement*. + + You can also try directly loading the ".slp" file, but this feature is + experimental and does not work in all cases. If the ".slp" file contains + both user-labeled and predicted instances, only the predicted ones will be + loaded. If there are multiple videos in the file, only the first one will + be used. + + *movement* expects the tracks to be assigned and proofread before loading + them, meaning each track is interpreted as a single individual/animal. + Follow the SLEAP guide for tracking and proofreading [2]_. + + References + ---------- + .. [1] https://sleap.ai/tutorials/analysis.html + .. [2] https://sleap.ai/guides/proofreading.html + + Examples + -------- + >>> from movement.io import load_poses + >>> ds = load_poses.from_sleap_file("path/to/file.analysis.h5", fps=30) + """ + + file = ValidFile( + file_path, + expected_permission="r", + expected_suffix=[".h5", ".slp"], + ) + + # Load and validate data + if file.path.suffix == ".h5": + valid_data = _load_from_sleap_analysis_file(file.path, fps=fps) + else: # file.path.suffix == ".slp" + valid_data = _load_from_sleap_labels_file(file.path, fps=fps) + logger.debug(f"Validated pose tracks from {file.path}.") + + # Initialize an xarray dataset from the validated data + ds = _from_valid_data(valid_data) + + # Add metadata as attrs + ds.attrs["source_software"] = "SLEAP" + ds.attrs["source_file"] = file.path.as_posix() + + logger.info(f"Loaded pose tracks from {file.path}:") + logger.info(ds) + return ds + + +def from_dlc_file( + file_path: Union[Path, str], fps: Optional[float] = None +) -> xr.Dataset: + """Load pose tracking data from a DeepLabCut (DLC) output file + into an xarray Dataset. + + Parameters + ---------- + file_path : pathlib.Path or str + Path to the file containing the DLC predicted poses, either in ".h5" + or ".csv" format. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. + + See Also + -------- + movement.io.load_poses.from_dlc_df : Load pose tracks from a DataFrame.
Examples -------- >>> from movement.io import load_poses - >>> poses = load_poses.from_dlc("path/to/file.h5") + >>> ds = load_poses.from_dlc_file("path/to/file.h5", fps=30) """ - # Validate the input file path - dlc_poses_file = DeepLabCutPosesFile(file_path=file_path) # type: ignore - file_suffix = dlc_poses_file.file_path.suffix + file = ValidFile( + file_path, + expected_permission="r", + expected_suffix=[".csv", ".h5"], + ) - # Load the DLC poses - try: - if file_suffix == ".csv": - df = _parse_dlc_csv_to_dataframe(dlc_poses_file.file_path) - else: # file can only be .h5 at this point - df = pd.read_hdf(dlc_poses_file.file_path) - # above line does not necessarily return a DataFrame - df = pd.DataFrame(df) - except (OSError, TypeError, ValueError) as e: - error_msg = ( - f"Could not load poses from {file_path}. " - "Please check that the file is valid and readable." + # Load the DLC poses into a DataFrame + if file.path.suffix == ".csv": + df = _parse_dlc_csv_to_df(file.path) + else: # file.path.suffix == ".h5" + df = _load_df_from_dlc_h5(file.path) + + logger.debug(f"Loaded poses from {file.path} into a DataFrame.") + # Convert the DataFrame to an xarray dataset + ds = from_dlc_df(df=df, fps=fps) + + # Add metadata as attrs + ds.attrs["source_software"] = "DeepLabCut" + ds.attrs["source_file"] = file.path.as_posix() + + logger.info(f"Loaded pose tracks from {file.path}:") + logger.info(ds) + return ds + + +def _load_from_sleap_analysis_file( + file_path: Path, fps: Optional[float] +) -> ValidPoseTracks: + """Load and validate pose tracks and confidence scores from a SLEAP + analysis file. + + Parameters + ---------- + file_path : pathlib.Path + Path to the SLEAP analysis file containing predicted pose tracks. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame units. + + Returns + ------- + movement.io.tracks_validators.ValidPoseTracks + The validated pose tracks and confidence scores. + """ + + file = ValidHDF5(file_path, expected_datasets=["tracks"]) + + with h5py.File(file.path, "r") as f: + # transpose to shape: (n_frames, n_tracks, n_keypoints, n_space) + tracks = f["tracks"][:].transpose((3, 0, 2, 1)) + # Create an array of NaNs for the confidence scores + scores = np.full(tracks.shape[:-1], np.nan, dtype="float32") + # If present, read the point-wise scores, + # and transpose to shape: (n_frames, n_tracks, n_keypoints) + if "point_scores" in f.keys(): + scores = f["point_scores"][:].transpose((2, 0, 1)) + + return ValidPoseTracks( + tracks_array=tracks, + scores_array=scores, + individual_names=[n.decode() for n in f["track_names"][:]], + keypoint_names=[n.decode() for n in f["node_names"][:]], + fps=fps, ) - logger.error(error_msg) - raise OSError from e - logger.info(f"Loaded poses from {file_path}") - return df -def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: - """If poses are loaded from a DeepLabCut.csv file, the resulting DataFrame +def _load_from_sleap_labels_file( + file_path: Path, fps: Optional[float] +) -> ValidPoseTracks: + """Load and validate pose tracks and confidence scores from a SLEAP + labels file. + + Parameters + ---------- + file_path : pathlib.Path + Path to the SLEAP labels file containing predicted pose tracks. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame units. 
+ + Returns + ------- + movement.io.tracks_validators.ValidPoseTracks + The validated pose tracks and confidence scores. + """ + + file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"]) + + labels = read_labels(file.path.as_posix()) + tracks_with_scores = labels.numpy(untracked=False, return_confidence=True) + + return ValidPoseTracks( + tracks_array=tracks_with_scores[:, :, :, :-1], + scores_array=tracks_with_scores[:, :, :, -1], + individual_names=[track.name for track in labels.tracks], + keypoint_names=[kp.name for kp in labels.skeletons[0].nodes], + fps=fps, + ) + + +def _parse_dlc_csv_to_df(file_path: Path) -> pd.DataFrame: + """If poses are loaded from a DeepLabCut .csv file, the DataFrame lacks the multi-index columns that are present in the .h5 file. This - function parses the csv file to a DataFrame with multi-index columns. + function parses the csv file to a pandas DataFrame with multi-index + columns, i.e. the same format as in the .h5 file. Parameters ---------- - file_path : pathlib Path - Path to the file containing the DLC poses, in .csv format. + file_path : pathlib.Path + Path to the DeepLabCut-style CSV file. Returns ------- - pandas DataFrame - DataFrame containing the DLC poses, with multi-index columns. + pandas.DataFrame + DeepLabCut-style DataFrame with multi-index columns. """ + file = ValidPosesCSV(file_path) + possible_level_names = ["scorer", "individuals", "bodyparts", "coords"] - with open(file_path, "r") as f: + with open(file.path, "r") as f: # if line starts with a possible level name, split it into a list # of strings, and add it to the list of header lines header_lines = [ @@ -86,7 +320,81 @@ def _parse_dlc_csv_to_dataframe(file_path: Path) -> pd.DataFrame: # Import the DLC poses as a DataFrame df = pd.read_csv( - file_path, skiprows=len(header_lines), index_col=0, names=columns + file.path, + skiprows=len(header_lines), + index_col=0, + names=np.array(columns), ) df.columns.rename(level_names, inplace=True) return df + + +def _load_df_from_dlc_h5(file_path: Path) -> pd.DataFrame: + """Load pose tracks and likelihood scores from a DeepLabCut .h5 file + into a pandas DataFrame. + + Parameters + ---------- + file_path : pathlib.Path + Path to the DeepLabCut-style HDF5 file containing pose tracks. + + Returns + ------- + pandas.DataFrame + DeepLabCut-style Dataframe. + """ + + file = ValidHDF5(file_path, expected_datasets=["df_with_missing"]) + + try: + # pd.read_hdf does not always return a DataFrame + df = pd.DataFrame(pd.read_hdf(file.path, key="df_with_missing")) + except Exception as error: + raise log_error(error, f"Could not load a dataframe from {file.path}.") + return df + + +def _from_valid_data(data: ValidPoseTracks) -> xr.Dataset: + """Convert already validated pose tracking data to an xarray Dataset. + + Parameters + ---------- + data : movement.io.tracks_validators.ValidPoseTracks + The validated data object. + + Returns + ------- + xarray.Dataset + Dataset containing the pose tracks, confidence scores, and metadata. 
+ """ + + n_frames = data.tracks_array.shape[0] + n_space = data.tracks_array.shape[-1] + + # Create the time coordinate, depending on the value of fps + time_coords = np.arange(n_frames, dtype=int) + time_unit = "frames" + if data.fps is not None: + time_coords = time_coords / data.fps + time_unit = "seconds" + + DIM_NAMES = PosesAccessor.dim_names + # Convert data to an xarray.Dataset + return xr.Dataset( + data_vars={ + "pose_tracks": xr.DataArray(data.tracks_array, dims=DIM_NAMES), + "confidence": xr.DataArray(data.scores_array, dims=DIM_NAMES[:-1]), + }, + coords={ + DIM_NAMES[0]: time_coords, + DIM_NAMES[1]: data.individual_names, + DIM_NAMES[2]: data.keypoint_names, + DIM_NAMES[3]: ["x", "y", "z"][:n_space], + }, + attrs={ + "fps": data.fps, + "time_unit": time_unit, + "source_software": None, + "source_file": None, + }, + ) diff --git a/movement/io/poses_accessor.py b/movement/io/poses_accessor.py new file mode 100644 index 00000000..4a92476a --- /dev/null +++ b/movement/io/poses_accessor.py @@ -0,0 +1,83 @@ +import logging +from typing import ClassVar + +import xarray as xr + +from movement.io.validators import ValidPoseTracks + +logger = logging.getLogger(__name__) + +# Preserve the attributes (metadata) of xarray objects after operations +xr.set_options(keep_attrs=True) + + +@xr.register_dataset_accessor("poses") +class PosesAccessor: + """An accessor that extends an `xarray.Dataset` object. + + The `xarray.Dataset` has the following dimensions: + - `time`: the number of frames in the video + - `individuals`: the number of individuals in the video + - `keypoints`: the number of keypoints in the skeleton + - `space`: the number of spatial dimensions, either 2 or 3 + + Appropriate coordinate labels are assigned to each dimension: + list of unique names (str) for `individuals` and `keypoints`, + ['x','y',('z')] for `space`. The coordinates of the `time` dimension are + in seconds if `fps` is provided, otherwise they are in frame numbers. + + The dataset contains two data variables (`xarray.DataArray` objects): + - `pose_tracks`: with shape (`time`, `individuals`, `keypoints`, `space`) + - `confidence`: with shape (`time`, `individuals`, `keypoints`) + + The dataset may also contain following attributes as metadata: + - `fps`: the number of frames per second in the video + - `time_unit`: the unit of the `time` coordinates, frames or seconds + - `source_software`: the software from which the pose tracks were loaded + - `source_file`: the file from which the pose tracks were loaded + + Notes + ----- + Using an acessor is the recommended way to extend xarray objects. + See [1]_ for more details. + + Methods/properties that are specific to this class can be used via + the `.poses` accessor, e.g. `ds.poses.to_dlc_df()`. + + References + ---------- + .. 
_1: https://docs.xarray.dev/en/stable/internals/extending-xarray.html + """ + + # Names of the expected dimensions in the dataset + dim_names: ClassVar[tuple] = ( + "time", + "individuals", + "keypoints", + "space", + ) + + # Names of the expected data variables in the dataset + var_names: ClassVar[tuple] = ( + "pose_tracks", + "confidence", + ) + + def __init__(self, ds: xr.Dataset): + self._obj = ds + + def validate(self) -> None: + """Validate the PoseTracks dataset.""" + fps = self._obj.attrs.get("fps", None) + try: + ValidPoseTracks( + tracks_array=self._obj[self.var_names[0]].values, + scores_array=self._obj[self.var_names[1]].values, + individual_names=self._obj.coords[self.dim_names[1]].values, + keypoint_names=self._obj.coords[self.dim_names[2]].values, + fps=fps, + ) + except Exception as e: + error_msg = "The dataset does not contain valid pose tracks." + logger.error(error_msg) + raise ValueError(error_msg) from e diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py new file mode 100644 index 00000000..e1be3f70 --- /dev/null +++ b/movement/io/save_poses.py @@ -0,0 +1,112 @@ +import logging +from pathlib import Path +from typing import Union + +import numpy as np +import pandas as pd +import xarray as xr + +from movement.io.validators import ValidFile + +logger = logging.getLogger(__name__) + + +def to_dlc_df(ds: xr.Dataset) -> pd.DataFrame: + """Convert an xarray dataset containing pose tracks into a + DeepLabCut-style pandas DataFrame with multi-index columns. + + Parameters + ---------- + ds : xarray Dataset + Dataset containing pose tracks, confidence scores, and metadata. + + Returns + ------- + pandas DataFrame + + Notes + ----- + The DataFrame will have a multi-index column with the following levels: + "scorer", "individuals", "bodyparts", "coords" (even if there is only + one individual present). Regardless of the provenance of the + points-wise confidence scores, they will be referred to as + "likelihood", and stored in the "coords" level (as DeepLabCut expects). + + See Also + -------- + to_dlc_file : Save the xarray dataset containing pose tracks directly + to a DeepLabCut-style ".h5" or ".csv" file. + """ + + if not isinstance(ds, xr.Dataset): + error_msg = f"Expected an xarray Dataset, but got {type(ds)}. " + logger.error(error_msg) + raise ValueError(error_msg) + + ds.poses.validate() # validate the dataset + + # Concatenate the pose tracks and confidence scores into one array + tracks_with_scores = np.concatenate( + ( + ds.pose_tracks.data, + ds.confidence.data[..., np.newaxis], + ), + axis=-1, + ) + + # Create the DLC-style multi-index columns + # Use the DLC terminology: scorer, individuals, bodyparts, coords + scorer = ["movement"] + individuals = ds.coords["individuals"].data.tolist() + bodyparts = ds.coords["keypoints"].data.tolist() + # The confidence scores in DLC are referred to as "likelihood" + coords = ds.coords["space"].data.tolist() + ["likelihood"] + + index_levels = ["scorer", "individuals", "bodyparts", "coords"] + columns = pd.MultiIndex.from_product( + [scorer, individuals, bodyparts, coords], names=index_levels + ) + df = pd.DataFrame( + data=tracks_with_scores.reshape(ds.dims["time"], -1), + index=np.arange(ds.dims["time"], dtype=int), + columns=columns, + dtype=float, + ) + logger.info("Converted PoseTracks dataset to DLC-style DataFrame.") + return df + + +def to_dlc_file(ds: xr.Dataset, file_path: Union[str, Path]) -> None: + """Save the xarray dataset containing pose tracks to a + DeepLabCut-style ".h5" or ".csv" file. 
+ + Parameters + ---------- + ds : xarray Dataset + Dataset containing pose tracks, confidence scores, and metadata. + file_path : pathlib Path or str + Path to the file to save the DLC poses to. The file extension + must be either ".h5" (recommended) or ".csv". + + See Also + -------- + to_dlc_df : Convert an xarray dataset containing pose tracks into a + DeepLabCut-style pandas DataFrame with multi-index columns. + """ + + try: + file = ValidFile( + file_path, + expected_permission="w", + expected_suffix=[".csv", ".h5"], + ) + except (OSError, ValueError) as error: + logger.error(error) + raise error + + df = to_dlc_df(ds) # convert to pandas DataFrame + if file.path.suffix == ".csv": + df.to_csv(file.path, sep=",") + else: # file.path.suffix == ".h5" + df.to_hdf(file.path, key="df_with_missing") + logger.info(f"Saved PoseTracks dataset to {file.path}.") diff --git a/movement/io/validators.py b/movement/io/validators.py index 8984b0aa..c56f5bb7 100644 --- a/movement/io/validators.py +++ b/movement/io/validators.py @@ -1,39 +1,348 @@ -import logging +import os from pathlib import Path +from typing import Any, Iterable, List, Literal, Optional, Union -from pydantic import BaseModel, field_validator +import h5py +import numpy as np +from attrs import converters, define, field, validators -# initialize logger -logger = logging.getLogger(__name__) +from movement.logging import log_error, log_warning -class DeepLabCutPosesFile(BaseModel): - """Pydantic class for validating files containing - pose estimation results from DeepLabCut (DLC). +@define +class ValidFile: + """Class for validating file paths. - Pydantic will enforce the input data type. - This class additionally checks that the file exists - and has a valid suffix. + Parameters + ---------- + path : str or pathlib.Path + Path to the file. + expected_permission : {'r', 'w', 'rw'} + Expected access permission(s) for the file. If 'r', the file is + expected to be readable. If 'w', the file is expected to be writable. + If 'rw', the file is expected to be both readable and writable. + Default: 'r'. + expected_suffix : list of str + Expected suffix(es) for the file. If an empty list (default), this + check is skipped. + + Raises + ------ + IsADirectoryError + If the path points to a directory. + PermissionError + If the file does not have the expected access permission(s). + FileNotFoundError + If the file does not exist when `expected_permission` is 'r' or 'rw'. + FileExistsError + If the file exists when `expected_permission` is 'w'. + ValueError + If the file does not have one of the expected suffix(es). + """ + + path: Path = field(converter=Path, validator=validators.instance_of(Path)) + expected_permission: Literal["r", "w", "rw"] = field( + default="r", validator=validators.in_(["r", "w", "rw"]), kw_only=True + ) + expected_suffix: List[str] = field(factory=list, kw_only=True) + + @path.validator + def path_is_not_dir(self, attribute, value): + """Ensures that the path does not point to a directory.""" + if value.is_dir(): + raise log_error( + IsADirectoryError, + f"Expected a file path but got a directory: {value}.", + ) + + @path.validator + def file_exists_when_expected(self, attribute, value): + """Ensures that the file exists (or not) depending on the expected + usage (read and/or write).""" + if "r" in self.expected_permission: + if not value.exists(): + raise log_error( + FileNotFoundError, f"File {value} does not exist." 
+ ) + else: # expected_permission is 'w' + if value.exists(): + raise log_error( + FileExistsError, f"File {value} already exists." + ) + + @path.validator + def file_has_access_permissions(self, attribute, value): + """Ensures that the file has the expected access permission(s). + Raises a PermissionError if not.""" + file_is_readable = os.access(value, os.R_OK) + parent_is_writeable = os.access(value.parent, os.W_OK) + if ("r" in self.expected_permission) and (not file_is_readable): + raise log_error( + PermissionError, + f"Unable to read file: {value}. " + "Make sure that you have read permissions.", + ) + if ("w" in self.expected_permission) and (not parent_is_writeable): + raise log_error( + PermissionError, + f"Unable to write to file: {value}. " + "Make sure that you have write permissions.", + ) + + @path.validator + def file_has_expected_suffix(self, attribute, value): + """Ensures that the file has one of the expected suffix(es).""" + if self.expected_suffix: # list is not empty + if value.suffix not in self.expected_suffix: + raise log_error( + ValueError, + f"Expected file with suffix(es) {self.expected_suffix} " + f"but got suffix {value.suffix} instead.", + ) + + +@define +class ValidHDF5: + """Class for validating HDF5 files. + + Parameters + ---------- + path : pathlib.Path + Path to the HDF5 file. + expected_datasets : list of str or None + List of names of the expected datasets in the HDF5 file. If an empty + list (default), this check is skipped. + + Raises + ------ + ValueError + If the file is not in HDF5 format or if it does not contain the + expected datasets. + """ + + path: Path = field(validator=validators.instance_of(Path)) + expected_datasets: List[str] = field(factory=list, kw_only=True) + + @path.validator + def file_is_h5(self, attribute, value): + """Ensure that the file is indeed in HDF5 format.""" + try: + with h5py.File(value, "r") as f: + f.close() + except Exception as e: + raise log_error( + ValueError, + f"File {value} does not seem to be in valid" "HDF5 format.", + ) from e + + @path.validator + def file_contains_expected_datasets(self, attribute, value): + """Ensure that the HDF5 file contains the expected datasets.""" + if self.expected_datasets: + with h5py.File(value, "r") as f: + diff = set(self.expected_datasets).difference(set(f.keys())) + if len(diff) > 0: + raise log_error( + ValueError, + f"Could not find the expected dataset(s) {diff} " + f"in file: {value}. ", + ) + + +@define +class ValidPosesCSV: + """Class for validating CSV files that contain pose estimation outputs. + in DeepLabCut format. + + Parameters + ---------- + path : pathlib.Path + Path to the CSV file. + + Raises + ------ + ValueError + If the CSV file does not contain the expected DeepLabCut index column + levels among its top rows. 
+ """ + + path: Path = field(validator=validators.instance_of(Path)) + + @path.validator + def csv_file_contains_expected_levels(self, attribute, value): + """Ensure that the CSV file contains the expected index column levels + among its top rows.""" + expected_levels = ["scorer", "bodyparts", "coords"] + + with open(value, "r") as f: + top4_row_starts = [f.readline().split(",")[0] for _ in range(4)] + + if top4_row_starts[3].isdigit(): + # if 4th row starts with a digit, assume single-animal DLC file + expected_levels.append(top4_row_starts[3]) + else: + # otherwise, assume multi-animal DLC file + expected_levels.insert(1, "individuals") + + if top4_row_starts != expected_levels: + raise log_error( + ValueError, + "CSV header rows do not match the known format for " + "DeepLabCut pose estimation output files.", + ) + + +def _list_of_str(value: Union[str, Iterable[Any]]) -> List[str]: + """Try to coerce the value into a list of strings. + Otherwise, raise a ValueError.""" + if isinstance(value, str): + log_warning( + f"Invalid value ({value}). Expected a list of strings. " + "Converting to a list of length 1." + ) + return [value] + elif isinstance(value, Iterable): + return [str(item) for item in value] + else: + raise log_error( + ValueError, f"Invalid value ({value}). Expected a list of strings." + ) + + +def _ensure_type_ndarray(value: Any) -> None: + """Raise ValueError the value is a not numpy array.""" + if not isinstance(value, np.ndarray): + raise log_error( + ValueError, f"Expected a numpy array, but got {type(value)}." + ) + + +def _set_fps_to_none_if_invalid(fps: Optional[float]) -> Optional[float]: + """Set fps to None if a non-positive float is passed.""" + if fps is not None and fps <= 0: + log_warning( + f"Invalid fps value ({fps}). Expected a positive number. " + "Setting fps to None." + ) + return None + return fps + + +def _validate_list_length( + attribute: str, value: Optional[List], expected_length: int +): + """Raise a ValueError if the list does not have the expected length.""" + if (value is not None) and (len(value) != expected_length): + raise log_error( + ValueError, + f"Expected `{attribute}` to have length {expected_length}, " + f"but got {len(value)}.", + ) + + +@define(kw_only=True) +class ValidPoseTracks: + """Class for validating pose tracking data imported from a file. + + Attributes + ---------- + tracks_array : np.ndarray + Array of shape (n_frames, n_individuals, n_keypoints, n_space) + containing the pose tracks. It will be converted to a + `xarray.DataArray` object named "pose_tracks". + scores_array : np.ndarray, optional + Array of shape (n_frames, n_individuals, n_keypoints) containing + the point-wise confidence scores. It will be converted to a + `xarray.DataArray` object named "confidence". + If None (default), the scores will be set to an array of NaNs. + individual_names : list of str, optional + List of unique names for the individuals in the video. If None + (default), the individuals will be named "individual_0", + "individual_1", etc. + keypoint_names : list of str, optional + List of unique names for the keypoints in the skeleton. If None + (default), the keypoints will be named "keypoint_0", "keypoint_1", + etc. + fps : float, optional + Frames per second of the video. Defaults to None. 
""" - file_path: Path - - @field_validator("file_path") - def file_must_exist(cls, value): - if not value.is_file(): - error_msg = f"File not found: {value}" - logger.error(error_msg) - raise FileNotFoundError(error_msg) - return value - - @field_validator("file_path") - def file_must_have_valid_suffix(cls, value): - if value.suffix not in (".h5", ".csv"): - error_msg = ( - "Expected a file with pose estimation results from " - "DeepLabCut, in one of '.h5' or '.csv' formats. " - f"Received a file with suffix '{value.suffix}' instead." - ) - logger.error(error_msg) - raise ValueError(error_msg) - return value + # Define class attributes + tracks_array: np.ndarray = field() + scores_array: Optional[np.ndarray] = field(default=None) + individual_names: Optional[List[str]] = field( + default=None, + converter=converters.optional(_list_of_str), + ) + keypoint_names: Optional[List[str]] = field( + default=None, + converter=converters.optional(_list_of_str), + ) + fps: Optional[float] = field( + default=None, + converter=converters.pipe( # type: ignore + converters.optional(float), _set_fps_to_none_if_invalid + ), + ) + + # Add validators + @tracks_array.validator + def _validate_tracks_array(self, attribute, value): + _ensure_type_ndarray(value) + if value.ndim != 4: + raise log_error( + ValueError, + f"Expected `{attribute}` to have 4 dimensions, " + f"but got {value.ndim}.", + ) + if value.shape[-1] not in [2, 3]: + raise log_error( + ValueError, + f"Expected `{attribute}` to have 2 or 3 spatial dimensions, " + f"but got {value.shape[-1]}.", + ) + + @scores_array.validator + def _validate_scores_array(self, attribute, value): + if value is not None: + _ensure_type_ndarray(value) + if value.shape != self.tracks_array.shape[:-1]: + raise log_error( + ValueError, + f"Expected `{attribute}` to have shape " + f"{self.tracks_array.shape[:-1]}, but got {value.shape}.", + ) + + @individual_names.validator + def _validate_individual_names(self, attribute, value): + _validate_list_length(attribute, value, self.tracks_array.shape[1]) + + @keypoint_names.validator + def _validate_keypoint_names(self, attribute, value): + _validate_list_length(attribute, value, self.tracks_array.shape[2]) + + def __attrs_post_init__(self): + """Assign default values to optional attributes (if None)""" + if self.scores_array is None: + self.scores_array = np.full( + (self.tracks_array.shape[:-1]), np.nan, dtype="float32" + ) + log_warning( + "Scores array was not provided. Setting to an array of NaNs." + ) + if self.individual_names is None: + self.individual_names = [ + f"individual_{i}" for i in range(self.tracks_array.shape[1]) + ] + log_warning( + "Individual names were not provided. " + f"Setting to {self.individual_names}." + ) + if self.keypoint_names is None: + self.keypoint_names = [ + f"keypoint_{i}" for i in range(self.tracks_array.shape[2]) + ] + log_warning( + "Keypoint names were not provided. " + f"Setting to {self.keypoint_names}." + ) diff --git a/movement/log_config.py b/movement/logging.py similarity index 68% rename from movement/log_config.py rename to movement/logging.py index b788fa65..994b2d81 100644 --- a/movement/log_config.py +++ b/movement/logging.py @@ -22,7 +22,7 @@ def configure_logging( The logging level to use. Defaults to logging.INFO. logger_name : str, optional The name of the logger to configure. - Defaults to 'movement'. + Defaults to "movement". log_directory : pathlib.Path, optional The directory to store the log file in. Defaults to ~/.movement. 
A different directory can be specified, @@ -59,3 +59,39 @@ def configure_logging( # Add the handler to the logger logger.addHandler(handler) + + +def log_error(error, message: str, logger_name: str = "movement"): + """Log an error message and return the Exception. + + Parameters + ---------- + error : Exception + The error to log and return. + message : str + The error message. + logger_name : str, optional + The name of the logger to use. Defaults to "movement". + + Returns + ------- + Exception + The error that was passed in. + """ + logger = logging.getLogger(logger_name) + logger.error(message) + return error(message) + + +def log_warning(message: str, logger_name: str = "movement"): + """Log a warning message. + + Parameters + ---------- + message : str + The warning message. + logger_name : str, optional + The name of the logger to use. Defaults to "movement". + """ + logger = logging.getLogger(logger_name) + logger.warning(message) diff --git a/pyproject.toml b/pyproject.toml index bdcbbea4..6674e52b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,9 +15,11 @@ dependencies = [ "numpy", "pandas", "h5py", - "pydantic", + "attrs", "pooch", "tqdm", + "sleap-io", + "xarray", ] classifiers = [ @@ -50,6 +52,9 @@ dev = [ "pre-commit", "ruff", "setuptools_scm", + "pandas-stubs", + "types-attrs", + "check-manifest", ] [build-system] @@ -89,6 +94,13 @@ ignore = [ "docs/source/", ] +[[tool.mypy.overrides]] +module = [ + "pooch.*", + "h5py.*", + "sleap_io.*", +] +ignore_missing_imports = true [tool.ruff] line-length = 79 diff --git a/tests/conftest.py b/tests/conftest.py index 2b56fbae..1949167a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import pytest -from movement.log_config import configure_logging +from movement.logging import configure_logging @pytest.fixture(autouse=True) diff --git a/tests/test_unit/test_io.py b/tests/test_unit/test_io.py index fec91eca..bf656f5a 100644 --- a/tests/test_unit/test_io.py +++ b/tests/test_unit/test_io.py @@ -1,46 +1,148 @@ import os +import stat import h5py +import numpy as np import pandas as pd import pytest -from pandas.testing import assert_frame_equal -from pydantic import ValidationError -from tables import HDF5ExtError +import xarray as xr from movement.datasets import fetch_pose_data_path -from movement.io import load_poses +from movement.io import PosesAccessor, load_poses, save_poses +from movement.io.validators import ( + ValidFile, + ValidHDF5, + ValidPosesCSV, + ValidPoseTracks, +) -class TestLoadPoses: - """Test the load_poses module.""" +class TestPosesIO: + """Test the IO functionalities of the PoseTracks class.""" @pytest.fixture - def valid_dlc_files(self): - """Return the paths to valid DLC poses files, - in .h5 format. - - Returns - ------- - dict - Dictionary containing the paths. 
- - h5_path: pathlib Path to a valid .h5 file - - h5_str: path as str to a valid .h5 file - """ - h5_file = fetch_pose_data_path("DLC_single-wasp.predictions.h5") - csv_file = fetch_pose_data_path("DLC_single-wasp.predictions.csv") + def valid_tracks_array(self): + """Return a valid tracks array.""" + return np.zeros((10, 2, 2, 2)) + + @pytest.fixture + def valid_pose_dataset(self, valid_tracks_array): + """Return a valid pose tracks dataset.""" + dim_names = PosesAccessor.dim_names + return xr.Dataset( + data_vars={ + "pose_tracks": xr.DataArray( + valid_tracks_array, dims=dim_names + ), + "confidence": xr.DataArray( + valid_tracks_array[..., 0], dims=dim_names[:-1] + ), + }, + coords={ + "time": np.arange(valid_tracks_array.shape[0]), + "individuals": ["ind1", "ind2"], + "keypoints": ["key1", "key2"], + "space": ["x", "y"], + }, + attrs={ + "fps": None, + "time_unit": "frames", + "source_software": "SLEAP", + "source_file": "test.h5", + }, + ) + + @pytest.fixture + def invalid_pose_datasets(self, valid_pose_dataset): + """Return a list of invalid pose tracks datasets.""" return { - "h5_path": h5_file, - "h5_str": h5_file.as_posix(), - "csv_path": csv_file, - "csv_str": csv_file.as_posix(), + "not_a_dataset": [1, 2, 3], + "empty_dataset": xr.Dataset(), + "missing_var": valid_pose_dataset.drop_vars("pose_tracks"), + "missing_dim": valid_pose_dataset.drop_dims("time"), + } + + @pytest.fixture + def dlc_file_h5_single(self): + """Return the path to a valid DLC h5 file containing pose data + for a single animal.""" + return fetch_pose_data_path("DLC_single-wasp.predictions.h5") + + @pytest.fixture + def dlc_file_csv_single(self): + """Return the path to a valid DLC .csv file containing pose data + for a single animal. The underlying data is the same as in the + `dlc_file_h5_single` fixture.""" + return fetch_pose_data_path("DLC_single-wasp.predictions.csv") + + @pytest.fixture + def dlc_file_csv_multi(self): + """Return the path to a valid DLC .csv file containing pose data + for multiple animals.""" + return fetch_pose_data_path("DLC_two-mice.predictions.csv") + + @pytest.fixture + def sleap_file_h5_single(self): + """Return the path to a valid SLEAP "analysis" .h5 file containing + pose data for a single animal.""" + return fetch_pose_data_path("SLEAP_single-mouse_EPM.analysis.h5") + + @pytest.fixture + def sleap_file_slp_single(self): + """Return the path to a valid SLEAP .slp file containing + predicted poses (labels) for a single animal.""" + return fetch_pose_data_path("SLEAP_single-mouse_EPM.predictions.slp") + + @pytest.fixture + def sleap_file_h5_multi(self): + """Return the path to a valid SLEAP "analysis" .h5 file containing + pose data for multiple animals.""" + return fetch_pose_data_path( + "SLEAP_three-mice_Aeon_proofread.analysis.h5" + ) + + @pytest.fixture + def sleap_file_slp_multi(self): + """Return the path to a valid SLEAP .slp file containing + predicted poses (labels) for multiple animals.""" + return fetch_pose_data_path( + "SLEAP_three-mice_Aeon_proofread.predictions.slp" + ) + + @pytest.fixture + def valid_files( + self, + dlc_file_h5_single, + dlc_file_csv_single, + dlc_file_csv_multi, + sleap_file_h5_single, + sleap_file_slp_single, + sleap_file_h5_multi, + sleap_file_slp_multi, + ): + """Aggregate all valid files in a dictionary, for convenience.""" + return { + "DLC_h5_single": dlc_file_h5_single, + "DLC_csv_single": dlc_file_csv_single, + "DLC_csv_multi": dlc_file_csv_multi, + "SLEAP_h5_single": sleap_file_h5_single, + "SLEAP_slp_single": sleap_file_slp_single, 
+ "SLEAP_h5_multi": sleap_file_h5_multi, + "SLEAP_slp_multi": sleap_file_slp_multi, } @pytest.fixture def invalid_files(self, tmp_path): + """Return a dictionary containing paths to invalid files.""" unreadable_file = tmp_path / "unreadable.h5" with open(unreadable_file, "w") as f: f.write("unreadable data") - os.chmod(f.name, 0o000) + os.chmod(f.name, not stat.S_IRUSR) + + unwriteable_dir = tmp_path / "no_write" + unwriteable_dir.mkdir() + os.chmod(unwriteable_dir, not stat.S_IWUSR) + unwritable_file = unwriteable_dir / "unwritable.h5" wrong_ext_file = tmp_path / "wrong_extension.txt" with open(wrong_ext_file, "w") as f: @@ -52,44 +154,296 @@ def invalid_files(self, tmp_path): nonexistent_file = tmp_path / "nonexistent.h5" + directory = tmp_path / "directory" + directory.mkdir() + + fake_h5_file = tmp_path / "fake.h5" + with open(fake_h5_file, "w") as f: + f.write("") + + fake_csv_file = tmp_path / "fake.csv" + with open(fake_csv_file, "w") as f: + f.write("some,columns\n") + f.write("1,2") + return { "unreadable": unreadable_file, + "unwritable": unwritable_file, "wrong_ext": wrong_ext_file, "no_dataframe": h5_file_no_dataframe, "nonexistent": nonexistent_file, + "directory": directory, + "fake_h5": fake_h5_file, + "fake_csv": fake_csv_file, } - def test_load_valid_dlc_files(self, valid_dlc_files): - """Test loading valid DLC poses files.""" - for file_type, file_path in valid_dlc_files.items(): - df = load_poses.from_dlc(file_path) - assert isinstance(df, pd.DataFrame) - assert not df.empty + @pytest.fixture + def dlc_style_df(self, dlc_file_h5_single): + """Return a valid DLC-style DataFrame.""" + df = pd.read_hdf(dlc_file_h5_single) + return df + + def test_load_from_valid_files(self, valid_files): + """Test that loading pose tracks from a wide variety of valid files + returns a proper Dataset.""" + abbrev_expand = {"DLC": "DeepLabCut", "SLEAP": "SLEAP"} + + for file_type, file_path in valid_files.items(): + if file_type.startswith("DLC"): + ds = load_poses.from_dlc_file(file_path) + elif file_type.startswith("SLEAP"): + ds = load_poses.from_sleap_file(file_path) + + assert isinstance(ds, xr.Dataset) + # Expected variables are present and of right shape/type + for var in ["pose_tracks", "confidence"]: + assert var in ds.data_vars + assert isinstance(ds[var], xr.DataArray) + assert ds.pose_tracks.ndim == 4 + assert ds.confidence.shape == ds.pose_tracks.shape[:-1] + # Check the dims and coords + DIM_NAMES = PosesAccessor.dim_names + assert all([i in ds.dims for i in DIM_NAMES]) + for d, dim in enumerate(DIM_NAMES[1:]): + assert ds.dims[dim] == ds.pose_tracks.shape[d + 1] + assert all([isinstance(s, str) for s in ds.coords[dim].values]) + assert all([i in ds.coords["space"] for i in ["x", "y"]]) + # Check the metadata attributes + assert ds.source_software == abbrev_expand[file_type.split("_")[0]] + assert ds.source_file == file_path.as_posix() + assert ds.fps is None + + def test_load_from_invalid_files(self, invalid_files): + """Test that loading pose tracks from a wide variety of invalid files + raises the appropriate errors.""" + for file_path in invalid_files.values(): + with pytest.raises((OSError, ValueError)): + load_poses.from_dlc_file(file_path) + with pytest.raises((OSError, ValueError)): + load_poses.from_sleap_file(file_path) + + @pytest.mark.parametrize("file_path", [1, 1.0, True, None, [], {}]) + def test_load_with_incorrect_file_path_types(self, file_path): + """Test loading poses from a file_path with an incorrect type.""" + with pytest.raises(TypeError): + 
load_poses.from_dlc_file(file_path) + with pytest.raises(TypeError): + load_poses.from_sleap_file(file_path) - def test_load_invalid_dlc_files(self, invalid_files): - """Test loading invalid DLC poses files.""" + def test_file_validator(self, invalid_files): + """Test that the file validator class raises the right errors.""" for file_type, file_path in invalid_files.items(): - if file_type == "nonexistent": - with pytest.raises(FileNotFoundError): - load_poses.from_dlc(file_path) + if file_type == "unreadable": + with pytest.raises(PermissionError): + ValidFile(path=file_path, expected_permission="r") + elif file_type == "unwritable": + with pytest.raises(PermissionError): + ValidFile(path=file_path, expected_permission="w") elif file_type == "wrong_ext": with pytest.raises(ValueError): - load_poses.from_dlc(file_path) - else: - with pytest.raises((OSError, HDF5ExtError)): - load_poses.from_dlc(file_path) + ValidFile( + path=file_path, + expected_permission="r", + expected_suffix=["h5", "csv"], + ) + elif file_type == "nonexistent": + with pytest.raises(FileNotFoundError): + ValidFile(path=file_path, expected_permission="r") + elif file_type == "directory": + with pytest.raises(IsADirectoryError): + ValidFile(path=file_path, expected_permission="r") + elif file_type in ["fake_h5", "no_dataframe"]: + with pytest.raises(ValueError): + ValidHDF5(path=file_path, expected_datasets=["dataframe"]) + elif file_type == "fake_csv": + with pytest.raises(ValueError): + ValidPosesCSV(path=file_path) - @pytest.mark.parametrize("file_path", [1, 1.0, True, None, [], {}]) - def test_load_from_dlc_with_incorrect_file_path_types(self, file_path): - """Test loading poses from a file_path with an incorrect type.""" - with pytest.raises(ValidationError): - load_poses.from_dlc(file_path) + def test_load_and_save_to_dlc_df(self, dlc_style_df): + """Test that loading pose tracks from a DLC-style DataFrame and + converting back to a DataFrame returns the same data values.""" + ds = load_poses.from_dlc_df(dlc_style_df) + df = save_poses.to_dlc_df(ds) + np.testing.assert_allclose(df.values, dlc_style_df.values) + + def test_save_and_load_dlc_file(self, valid_pose_dataset, tmp_path): + """Test that saving pose tracks to DLC .h5 and .csv files and then + loading them back in returns the same Dataset.""" + save_poses.to_dlc_file(valid_pose_dataset, tmp_path / "dlc.h5") + save_poses.to_dlc_file(valid_pose_dataset, tmp_path / "dlc.csv") + ds_from_h5 = load_poses.from_dlc_file(tmp_path / "dlc.h5") + ds_from_csv = load_poses.from_dlc_file(tmp_path / "dlc.csv") + xr.testing.assert_allclose(ds_from_h5, valid_pose_dataset) + xr.testing.assert_allclose(ds_from_csv, valid_pose_dataset) + + def test_save_valid_dataset_to_invalid_file_paths( + self, valid_pose_dataset, invalid_files, tmp_path + ): + with pytest.raises(FileExistsError): + save_poses.to_dlc_file( + valid_pose_dataset, invalid_files["fake_h5"] + ) + with pytest.raises(ValueError): + save_poses.to_dlc_file(valid_pose_dataset, tmp_path / "dlc.txt") + with pytest.raises(IsADirectoryError): + save_poses.to_dlc_file( + valid_pose_dataset, invalid_files["directory"] + ) + + def test_load_from_dlc_file_csv_or_h5_file_returns_same( + self, dlc_file_h5_single, dlc_file_csv_single + ): + """Test that loading pose tracks from DLC .csv and .h5 files + return the same Dataset.""" + ds_from_h5 = load_poses.from_dlc_file(dlc_file_h5_single) + ds_from_csv = load_poses.from_dlc_file(dlc_file_csv_single) + xr.testing.assert_allclose(ds_from_h5, ds_from_csv) + + 
@pytest.mark.parametrize("fps", [None, -5, 0, 30, 60.0]) + def test_fps_and_time_coords(self, sleap_file_h5_multi, fps): + """Test that time coordinates are set according to the fps.""" + ds = load_poses.from_sleap_file(sleap_file_h5_multi, fps=fps) + if (fps is None) or (fps <= 0): + assert ds.fps is None + assert ds.time_unit == "frames" + else: + assert ds.fps == fps + assert ds.time_unit == "seconds" + np.testing.assert_allclose( + ds.coords["time"].data, + np.arange(ds.dims["time"], dtype=int) / ds.attrs["fps"], + ) + + def test_load_from_str_path(self, sleap_file_h5_single): + """Test that file paths provided as strings are accepted as input.""" + xr.testing.assert_allclose( + load_poses.from_sleap_file(sleap_file_h5_single), + load_poses.from_sleap_file(sleap_file_h5_single.as_posix()), + ) + + def test_save_invalid_pose_datasets(self, invalid_pose_datasets, tmp_path): + """Test that saving invalid pose datasets raises ValueError.""" + for ds in invalid_pose_datasets.values(): + with pytest.raises(ValueError): + save_poses.to_dlc_file(ds, tmp_path / "test.h5") + + @pytest.mark.parametrize( + "tracks_array", + [ + None, # invalid, argument is non-optional + [1, 2, 3], # not an ndarray + np.zeros((10, 2, 3)), # not 4d + np.zeros((10, 2, 3, 4)), # last dim not 2 or 3 + ], + ) + def test_tracks_array_validation(self, tracks_array): + """Test that invalid tracks arrays raise the appropriate errors.""" + with pytest.raises(ValueError): + ValidPoseTracks(tracks_array=tracks_array) + + @pytest.mark.parametrize( + "scores_array", + [ + None, # valid, should default to array of NaNs + np.ones((10, 3, 2)), # will not match tracks_array shape + [1, 2, 3], # not an ndarray, should raise ValueError + ], + ) + def test_scores_array_validation(self, valid_tracks_array, scores_array): + """Test that invalid scores arrays raise the appropriate errors.""" + if scores_array is None: + poses = ValidPoseTracks(tracks_array=valid_tracks_array) + assert np.all(np.isnan(poses.scores_array)) + else: + with pytest.raises(ValueError): + ValidPoseTracks( + tracks_array=valid_tracks_array, scores_array=scores_array + ) + + @pytest.mark.parametrize( + "individual_names", + [ + None, # generate default names + ["ind1", "ind2"], # valid input + ("ind1", "ind2"), # valid input + [1, 2], # will be converted to ["1", "2"] + "ind1", # will be converted to ["ind1"] + 5, # invalid, should raise ValueError + ], + ) + def test_individual_names_validation( + self, valid_tracks_array, individual_names + ): + if individual_names is None: + poses = ValidPoseTracks( + tracks_array=valid_tracks_array, + individual_names=individual_names, + ) + assert poses.individual_names == ["individual_0", "individual_1"] + elif isinstance(individual_names, (list, tuple)): + poses = ValidPoseTracks( + tracks_array=valid_tracks_array, + individual_names=individual_names, + ) + assert poses.individual_names == [str(i) for i in individual_names] + elif isinstance(individual_names, str): + poses = ValidPoseTracks( + tracks_array=np.zeros((10, 1, 2, 2)), + individual_names=individual_names, + ) + assert poses.individual_names == [individual_names] + # raises error if not 1 individual + with pytest.raises(ValueError): + ValidPoseTracks( + tracks_array=valid_tracks_array, + individual_names=individual_names, + ) + else: + with pytest.raises(ValueError): + ValidPoseTracks( + tracks_array=valid_tracks_array, + individual_names=individual_names, + ) - def test_load_from_dlc_csv_or_h5_file_returns_same_df( - self, valid_dlc_files + 
@pytest.mark.parametrize( + "keypoint_names", + [ + None, # generate default names + ["key1", "key2"], # valid input + ("key", "key2"), # valid input + [1, 2], # will be converted to ["1", "2"] + "key1", # will be converted to ["ind1"] + 5, # invalid, should raise ValueError + ], + ) + def test_keypoint_names_validation( + self, valid_tracks_array, keypoint_names ): - """Test that loading poses from DLC .csv and .h5 files - return the same DataFrame.""" - df_from_h5 = load_poses.from_dlc(valid_dlc_files["h5_path"]) - df_from_csv = load_poses.from_dlc(valid_dlc_files["csv_path"]) - assert_frame_equal(df_from_h5, df_from_csv) + if keypoint_names is None: + poses = ValidPoseTracks( + tracks_array=valid_tracks_array, keypoint_names=keypoint_names + ) + assert poses.keypoint_names == ["keypoint_0", "keypoint_1"] + elif isinstance(keypoint_names, (list, tuple)): + poses = ValidPoseTracks( + tracks_array=valid_tracks_array, keypoint_names=keypoint_names + ) + assert poses.keypoint_names == [str(i) for i in keypoint_names] + elif isinstance(keypoint_names, str): + poses = ValidPoseTracks( + tracks_array=np.zeros((10, 2, 1, 2)), + keypoint_names=keypoint_names, + ) + assert poses.keypoint_names == [keypoint_names] + # raises error if not 1 keypoint + with pytest.raises(ValueError): + ValidPoseTracks( + tracks_array=valid_tracks_array, + keypoint_names=keypoint_names, + ) + else: + with pytest.raises(ValueError): + ValidPoseTracks( + tracks_array=valid_tracks_array, + keypoint_names=keypoint_names, + ) diff --git a/tests/test_unit/test_logging.py b/tests/test_unit/test_logging.py index 6dd4b9ae..bd4d2f6e 100644 --- a/tests/test_unit/test_logging.py +++ b/tests/test_unit/test_logging.py @@ -2,6 +2,8 @@ import pytest +from movement.logging import log_error, log_warning + log_messages = { "DEBUG": "This is a debug message", "INFO": "This is an info message", @@ -21,3 +23,20 @@ def test_logfile_contains_message(level, message): last_line = f.readlines()[-1] assert level in last_line assert message in last_line + + +def test_log_error(caplog): + """Check if the log_error function + logs the error message and returns an Exception.""" + with pytest.raises(ValueError): + raise log_error(ValueError, "This is a test error") + assert caplog.records[0].message == "This is a test error" + assert caplog.records[0].levelname == "ERROR" + + +def test_log_warning(caplog): + """Check if the log_warning function + logs the warning message.""" + log_warning("This is a test warning") + assert caplog.records[0].message == "This is a test warning" + assert caplog.records[0].levelname == "WARNING"
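
The diff above introduces a `poses` Dataset accessor with a `validate()` method, plus the `load_poses`/`save_poses` modules. As a quick orientation for reviewers, here is a minimal usage sketch of the intended round trip; it is not part of the patch, the file names are hypothetical, and only calls that appear in the diff (`from_dlc_file`, `to_dlc_df`, `to_dlc_file`, `ds.poses.validate()`) are used.

```python
# Illustrative only -- not part of the diff. File names are hypothetical.
from movement.io import load_poses, save_poses

# Load DeepLabCut predictions into an xarray Dataset with dims
# ("time", "individuals", "keypoints", "space").
ds = load_poses.from_dlc_file("wasp_DLC_predictions.h5")

# Re-run the accessor-based validation added in this diff.
ds.poses.validate()

# Convert to a DLC-style DataFrame, or write straight to .h5/.csv.
df = save_poses.to_dlc_df(ds)
save_poses.to_dlc_file(ds, "exported_predictions.csv")
```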
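
For reviewers unfamiliar with the DeepLabCut header layout that `to_dlc_df` targets, the standalone snippet below (with made-up animal and bodypart names; only the "movement" scorer comes from the diff) shows what `pd.MultiIndex.from_product` builds and why the tracks-plus-scores array is reshaped to `(n_frames, -1)`.

```python
import pandas as pd

# Made-up names; in to_dlc_df the real values come from the Dataset's coords.
columns = pd.MultiIndex.from_product(
    [
        ["movement"],                  # scorer
        ["mouse_0"],                   # individuals (hypothetical)
        ["snout", "tail_base"],        # bodyparts (hypothetical)
        ["x", "y", "likelihood"],      # coords = space dims + "likelihood"
    ],
    names=["scorer", "individuals", "bodyparts", "coords"],
)

print(columns.nlevels)  # 4 levels, as expected by DeepLabCut
print(len(columns))     # 1 * 1 * 2 * 3 = 6 columns, i.e.
                        # n_individuals * n_keypoints * (n_space + 1),
                        # matching tracks_with_scores.reshape(n_frames, -1)
```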
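
Finally, a sketch of how the new attrs-based validators are meant to be driven from loader/saver code, mirroring the `try/except` pattern used in `to_dlc_file` above. The file path and array shapes here are made up; `ValidPoseTracks` is keyword-only and, as shown in the diff, fills in NaN confidence scores and default names when they are omitted.

```python
import numpy as np

from movement.io.validators import ValidFile, ValidPoseTracks

# Hypothetical input file; ValidFile logs and raises if any check fails
# (directory, existence, permissions, or suffix).
try:
    file = ValidFile(
        "poses.h5", expected_permission="r", expected_suffix=[".h5", ".csv"]
    )
except (OSError, ValueError) as error:
    raise error

# Shape is (n_frames, n_individuals, n_keypoints, n_space); scores_array is
# omitted, so it defaults to an array of NaNs and a warning is logged.
tracks = ValidPoseTracks(
    tracks_array=np.zeros((100, 2, 3, 2)),
    individual_names=["mouse_0", "mouse_1"],
    keypoint_names=["snout", "centroid", "tail_base"],
    fps=30,
)
```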