Merge branch 'main' into smg/review-vector-tests
sfmig committed Sep 16, 2024
2 parents 67ceaec + 644c1b1 commit 6742bae
Showing 32 changed files with 1,169 additions and 506 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
@@ -52,7 +52,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
auto-update-conda: true
channels: conda-forge
channels: conda-forge,nodefaults
activate-environment: movement-env
- uses: neuroinformatics-unit/actions/test@v2
with:
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -29,12 +29,12 @@ repos:
- id: rst-directive-colons
- id: rst-inline-touching-normal
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.6
rev: v0.6.3
hooks:
- id: ruff
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.1
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies:
85 changes: 63 additions & 22 deletions CONTRIBUTING.md
@@ -12,15 +12,12 @@ development environment for movement. In the following we assume you have
First, create and activate a `conda` environment with some prerequisites:

```sh
conda create -n movement-dev -c conda-forge python=3.10 pytables
conda create -n movement-dev -c conda-forge python=3.11 pytables
conda activate movement-dev
```

The above method ensures that you will get packages that often can't be
installed via `pip`, including [hdf5](https://www.hdfgroup.org/solutions/hdf5/).

To install movement for development, clone the GitHub repository,
and then run from inside the repository:
To install movement for development, clone the [GitHub repository](movement-github:),
and then run from within the repository:

```sh
pip install -e .[dev] # works on most shells
@@ -162,13 +159,12 @@ The version number is automatically determined from the latest tag on the _main_
The documentation is hosted via [GitHub pages](https://pages.github.com/) at
[movement.neuroinformatics.dev](target-movement).
Its source files are located in the `docs` folder of this repository.
They are written in either [reStructuredText](https://docutils.sourceforge.io/rst.html) or
[markdown](myst-parser:syntax/typography.html).
They are written in either [Markdown](myst-parser:syntax/typography.html)
or [reStructuredText](https://docutils.sourceforge.io/rst.html).
The `index.md` file corresponds to the homepage of the documentation website.
Other `.rst` or `.md` files are linked to the homepage via the `toctree` directive.
Other `.md` or `.rst` files are linked to the homepage via the `toctree` directive.

We use [Sphinx](https://www.sphinx-doc.org/en/master/) and the
[PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html)
We use [Sphinx](sphinx-doc:) and the [PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html)
to build the source files into HTML output.
This is handled by a GitHub actions workflow (`.github/workflows/docs_build_and_deploy.yml`).
The build job is triggered on each PR, ensuring that the documentation build is not broken by new changes.
@@ -199,17 +195,16 @@ existing_file
my_new_file
```

#### Adding external links
If you are adding references to an external link (e.g. `https://github.com/neuroinformatics-unit/movement/issues/1`) in a `.md` file, you will need to check if a matching URL scheme (e.g. `https://github.com/neuroinformatics-unit/movement/`) is defined in `myst_url_schemes` in `docs/source/conf.py`. If it is, the following `[](scheme:loc)` syntax will be converted to the [full URL](movement-github:issues/1) during the build process:
#### Linking to external URLs
If you are adding references to an external URL (e.g. `https://github.com/neuroinformatics-unit/movement/issues/1`) in a `.md` file, you will need to check if a matching URL scheme (e.g. `https://github.com/neuroinformatics-unit/movement/`) is defined in `myst_url_schemes` in `docs/source/conf.py`. If it is, the following `[](scheme:loc)` syntax will be converted to the [full URL](movement-github:issues/1) during the build process:
```markdown
[link text](movement-github:issues/1)
```

If it is not yet defined and you have multiple external links pointing to the same base URL, you will need to [add the URL scheme](myst-parser:syntax/cross-referencing.html#customising-external-url-resolution) to `myst_url_schemes` in `docs/source/conf.py`.

If it is not yet defined and you have multiple external URLs pointing to the same base URL, you will need to [add the URL scheme](myst-parser:syntax/cross-referencing.html#customising-external-url-resolution) to `myst_url_schemes` in `docs/source/conf.py`.
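For illustration, such an entry in `docs/source/conf.py` might look like the following sketch (the key and URL mirror the `movement-github` scheme used above, but treat the exact dictionary contents as an example rather than the file's actual state):

```python
# docs/source/conf.py (sketch): map short URL schemes to full URLs.
# "{{path}}" and "{{fragment}}" are filled in from the link target, so
# [link text](movement-github:issues/1) expands to the full issue URL.
myst_url_schemes = {
    "movement-github": "https://github.com/neuroinformatics-unit/movement/{{path}}",
}
```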

### Updating the API reference
The API reference is auto-generated by the `docs/make_api_index.py` script, and the `sphinx-autodoc` and `sphinx-autosummary` plugins.
The [API reference](target-api) is auto-generated by the `docs/make_api_index.py` script, and the [sphinx-autodoc](sphinx-doc:extensions/autodoc.html) and [sphinx-autosummary](sphinx-doc:extensions/autosummary.html) extensions.
The script generates the `docs/source/api_index.rst` file containing the list of modules to be included in the [API reference](target-api).
The plugins then generate the API reference pages for each module listed in `api_index.rst`, based on the docstrings in the source code.
So make sure that all your public functions/classes/methods have valid docstrings following the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) style.
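As a sketch of the expected docstring style (the function below is hypothetical):

```python
def compute_speed(data):
    """Compute the speed of each keypoint.

    Parameters
    ----------
    data : xarray.DataArray
        The input data array containing velocity vectors in cartesian
        coordinates, with ``time`` as a dimension.

    Returns
    -------
    xarray.DataArray
        An xarray DataArray holding the magnitude of the velocity vectors.

    """
```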
@@ -224,14 +219,60 @@ To add new examples, you will need to create a new `.py` file in `examples/`.
The file should be structured as specified in the relevant
[sphinx-gallery documentation](sphinx-gallery:syntax).

We are using sphinx-gallery's [integration with binder](https://sphinx-gallery.github.io/stable/configuration.html#binder-links)
We are using sphinx-gallery's [integration with binder](sphinx-gallery:configuration#binder-links)
to provide interactive versions of the examples.
If your examples rely on packages that are not among movement's dependencies,
you will need to add them to the `docs/source/environment.yml` file.
That file is used by binder to create the conda environment in which the
examples are run. See the relevant section of the
[binder documentation](https://mybinder.readthedocs.io/en/latest/using/config_files.html).

### Cross-referencing Python objects
:::{note}
Docstrings in the `.py` files for the [API reference](target-api) and the [examples](target-examples) are converted into `.rst` files, so these should use reStructuredText syntax.
:::

#### Internal references
::::{tab-set}
:::{tab-item} Markdown
For referencing movement objects in `.md` files, use the `` {role}`target` `` syntax with the appropriate [Python object role](sphinx-doc:domains/python.html#cross-referencing-python-objects).

For example, to reference the {mod}`movement.io.load_poses` module, use:
```markdown
{mod}`movement.io.load_poses`
```
:::
:::{tab-item} RestructuredText
For referencing movement objects in `.rst` files, use the `` :role:`target` `` syntax with the appropriate [Python object role](sphinx-doc:domains/python.html#cross-referencing-python-objects).

For example, to reference the {mod}`movement.io.load_poses` module, use:
```rst
:mod:`movement.io.load_poses`
```
:::
::::

#### External references
For referencing external Python objects using [intersphinx](sphinx-doc:extensions/intersphinx.html),
ensure the mapping between module names and their documentation URLs is defined in [`intersphinx_mapping`](sphinx-doc:extensions/intersphinx.html#confval-intersphinx_mapping) in `docs/source/conf.py`.
Once the module is included in the mapping, use the same syntax as for [internal references](#internal-references).

::::{tab-set}
:::{tab-item} Markdown
For example, to reference the {meth}`xarray.Dataset.update` method, use:
```markdown
{meth}`xarray.Dataset.update`
```
:::

:::{tab-item} RestructuredText
For example, to reference the {meth}`xarray.Dataset.update` method, use:
```rst
:meth:`xarray.Dataset.update`
```
:::
::::

### Building the documentation locally
We recommend that you build and view the documentation website locally, before you push it.
To do so, first navigate to `docs/`.
@@ -256,7 +297,7 @@ The local build can be viewed by opening `docs/build/html/index.html` in a brows

:::{tab-item} All platforms
```sh
python make_api_index.py && sphinx-build source build
python make_api_index.py && sphinx-build source build -W --keep-going
```
The local build can be viewed by opening `docs/build/index.html` in a browser.
:::
@@ -276,7 +317,7 @@ make clean html
:::{tab-item} All platforms
```sh
rm -f source/api_index.rst && rm -rf build && rm -rf source/api && rm -rf source/examples
python make_api_index.py && sphinx-build source build
python make_api_index.py && sphinx-build source build -W --keep-going
```
:::
::::
@@ -292,7 +333,7 @@ make linkcheck

:::{tab-item} All platforms
```sh
sphinx-build source build -b linkcheck
sphinx-build source build -b linkcheck -W --keep-going
```
:::
::::
@@ -338,7 +379,7 @@ The most important parts of this module are:
1. The `SAMPLE_DATA` download manager object.
2. The `list_datasets()` function, which returns a list of the available poses and bounding boxes datasets (file names of the data files).
3. The `fetch_dataset_paths()` function, which returns a dictionary containing local paths to the files associated with a particular sample dataset: `poses` or `bboxes`, `frame`, `video`. If the relevant files are not already cached locally, they will be downloaded.
4. The `fetch_dataset()` function, which downloads the files associated with a given sample dataset (same as `fetch_dataset_paths()`) and additionally loads the pose or bounding box data into `movement`, returning an `xarray.Dataset` object. If available, the local paths to the associated video and frame files are stored as dataset attributes, with names `video_path` and `frame_path`, respectively.
4. The `fetch_dataset()` function, which downloads the files associated with a given sample dataset (same as `fetch_dataset_paths()`) and additionally loads the pose or bounding box data into movement, returning an `xarray.Dataset` object. If available, the local paths to the associated video and frame files are stored as dataset attributes, with names `video_path` and `frame_path`, respectively.

By default, the downloaded files are stored in the `~/.movement/data` folder.
This can be changed by setting the `DATA_DIR` variable in the `movement.sample_data.py` module.
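A minimal usage sketch (the dataset name below is illustrative; pick one returned by `list_datasets()`):

```python
from movement import sample_data

# List the file names of the available sample datasets
print(sample_data.list_datasets())

# Download (if not already cached) and load one dataset;
# the dataset name here is illustrative.
ds = sample_data.fetch_dataset("SLEAP_three-mice_Aeon_proofread.analysis.h5")

# If available, paths to the associated video and frame files
# are stored as dataset attributes
print(ds.attrs.get("video_path"), ds.attrs.get("frame_path"))
```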
@@ -372,7 +413,7 @@ To add a new file, you will need to:
```
:::
::::
For convenience, we've included a `get_sha256_hashes.py` script in the [movement data repository](gin:neuroinformatics/movement-test-data). If you run this from the root of the data repository, within a Python environment with `movement` installed, it will calculate the sha256 hashes for all files in the `poses`, `bboxes`, `videos` and `frames` folders and write them to files named `poses_hashes.txt`, `bboxes_hashes.txt`, `videos_hashes.txt`, and `frames_hashes.txt` respectively.
For convenience, we've included a `get_sha256_hashes.py` script in the [movement data repository](gin:neuroinformatics/movement-test-data). If you run this from the root of the data repository, within a Python environment with movement installed, it will calculate the sha256 hashes for all files in the `poses`, `bboxes`, `videos` and `frames` folders and write them to files named `poses_hashes.txt`, `bboxes_hashes.txt`, `videos_hashes.txt`, and `frames_hashes.txt` respectively. A minimal stand-alone hashing sketch is also shown below this list.
7. Add metadata for your new files to `metadata.yaml`, including the sha256 hashes you've calculated. See the example entry below for guidance.
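If you prefer to compute a hash for a single file by hand, here is a minimal stand-alone sketch using only the Python standard library (the file path is illustrative):

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path) -> str:
    """Return the sha256 hex digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()


print(sha256_of(Path("poses/my_new_file.h5")))  # illustrative path
```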

11 changes: 3 additions & 8 deletions README.md
@@ -16,17 +16,12 @@ A Python toolbox for analysing body movements across space and time, to aid the

## Quick install

First, create and activate a conda environment with the required dependencies:
Create and activate a conda environment with movement installed:
```
conda create -n movement-env -c conda-forge python=3.11 pytables
conda create -n movement-env -c conda-forge movement
conda activate movement-env
```

Then install the `movement` package:
```
pip install movement
```

> [!Note]
> Read the [documentation](https://movement.neuroinformatics.dev) for more information, including [full installation instructions](https://movement.neuroinformatics.dev/getting_started/installation.html) and [examples](https://movement.neuroinformatics.dev/examples/index.html).
@@ -52,7 +47,7 @@ You are welcome to chat with the team on [zulip](https://neuroinformatics.zulipc

## Citation

If you use `movement` in your work, please cite the following Zenodo DOI:
If you use movement in your work, please cite the following Zenodo DOI:

> Nikoloz Sirmpilatze, Chang Huan Lo, Sofía Miñano, Brandon D. Peri, Dhruv Sharma, Laura Porta, Iván Varela & Adam L. Tyson (2024). neuroinformatics-unit/movement. Zenodo. https://zenodo.org/doi/10.5281/zenodo.12755724
4 changes: 3 additions & 1 deletion docs/Makefile
@@ -3,7 +3,9 @@

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
# -W: if there are warnings, treat them as errors and exit with status 1.
# --keep-going: run sphinx-build to completion and exit with status 1 if there were errors.
SPHINXOPTS ?= -W --keep-going
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
9 changes: 9 additions & 0 deletions docs/source/_static/css/custom.css
@@ -30,3 +30,12 @@ display: flex;
flex-wrap: wrap;
justify-content: space-between;
}

/* Disable decoration for all but movement backrefs */
a[class^="sphx-glr-backref-module-"],
a[class^="sphx-glr-backref-type-"] {
text-decoration: none;
}
a[class^="sphx-glr-backref-module-movement"] {
text-decoration: underline;
}
Binary file modified docs/source/_static/dataset_structure.png
Binary file modified docs/source/_static/movement_overview.png
2 changes: 1 addition & 1 deletion docs/source/community/roadmaps.md
@@ -24,5 +24,5 @@ We plan to release version `v0.1` of movement in early 2024, providing a minimal
- [x] Ability to compute velocity and acceleration from pose tracks.
- [x] Public website with [documentation](target-movement).
- [x] Package released on [PyPI](https://pypi.org/project/movement/).
- [ ] Package released on [conda-forge](https://conda-forge.org/).
- [x] Package released on [conda-forge](https://anaconda.org/conda-forge/movement).
- [ ] Ability to visualise pose tracks using [napari](napari:). We aim to represent pose tracks via napari's [Points](napari:howtos/layers/points) and [Tracks](napari:howtos/layers/tracks) layers and overlay them on video frames.
11 changes: 9 additions & 2 deletions docs/source/conf.py
@@ -68,7 +68,7 @@
"tasklist",
]
# Automatically add anchors to markdown headings
myst_heading_anchors = 3
myst_heading_anchors = 4

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@@ -108,6 +108,7 @@
"binderhub_url": "https://mybinder.org",
"dependencies": ["environment.yml"],
},
"reference_url": {"movement": None},
"remove_config_comments": True,
# do not render config params set as # sphinx_gallery_config [= value]
}
@@ -191,8 +192,14 @@
"napari": "https://napari.org/dev/{{path}}",
"setuptools-scm": "https://setuptools-scm.readthedocs.io/en/latest/{{path}}#{{fragment}}",
"sleap": "https://sleap.ai/{{path}}#{{fragment}}",
"sphinx-gallery": "https://sphinx-gallery.github.io/stable/{{path}}",
"sphinx-doc": "https://www.sphinx-doc.org/en/master/usage/{{path}}#{{fragment}}",
"sphinx-gallery": "https://sphinx-gallery.github.io/stable/{{path}}#{{fragment}}",
"xarray": "https://docs.xarray.dev/en/stable/{{path}}#{{fragment}}",
"lp": "https://lightning-pose.readthedocs.io/en/stable/{{path}}#{{fragment}}",
"via": "https://www.robots.ox.ac.uk/~vgg/software/via/{{path}}#{{fragment}}",
}

intersphinx_mapping = {
"xarray": ("https://docs.xarray.dev/en/stable/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
}
2 changes: 1 addition & 1 deletion docs/source/environment.yml
@@ -3,7 +3,7 @@
channels:
- conda-forge
dependencies:
- python=3.10
- python=3.11
- pytables
- pip:
- movement
82 changes: 38 additions & 44 deletions docs/source/getting_started/installation.md
@@ -1,68 +1,62 @@
(target-installation)=
# Installation

## Create a conda environment

## Install the package
:::{admonition} Use a conda environment
:class: note
We recommend you install movement inside a [conda](conda:)
or [mamba](mamba:) environment, to avoid dependency conflicts with other packages.
In the following we assume you have `conda` installed,
but the same commands will also work with `mamba`/`micromamba`.
To avoid dependency conflicts with other packages, it is best practice to install Python packages within a virtual environment.
We recommend using [conda](conda:) or [mamba](mamba:) to create and manage this environment, as they simplify the installation process.
The following instructions assume that you have conda installed, but the same commands will also work with `mamba`/`micromamba`.
:::

First, create and activate an environment with some prerequisites.
You can call your environment whatever you like, we've used `movement-env`.
### Users
To install movement in a new environment, follow one of the options below.
We will use `movement-env` as the environment name, but you can choose any name you prefer.

::::{tab-set}
:::{tab-item} Conda
Create and activate an environment with movement installed:
```sh
conda create -n movement-env -c conda-forge python=3.11 pytables
conda create -n movement-env -c conda-forge movement
conda activate movement-env
```

## Install the package

Then install the `movement` package as described below.

::::{tab-set}

:::{tab-item} Users
To get the latest release from PyPI:

:::
:::{tab-item} Pip
Create and activate an environment with some prerequisites:
```sh
pip install movement
conda create -n movement-env -c conda-forge python=3.11 pytables
conda activate movement-env
```
If you have an older version of `movement` installed in the same environment,
you can update to the latest version with:

Install the latest movement release from PyPI:
```sh
pip install --upgrade movement
pip install movement
```
:::
::::

:::{tab-item} Developers
To get the latest development version, clone the
[GitHub repository](movement-github:)
and then run from inside the repository:
### Developers
If you are a developer looking to contribute to movement, please refer to our [contributing guide](target-contributing) for detailed setup instructions and guidelines.

## Check the installation
To verify that the installation was successful, run (with `movement-env` activated):
```sh
pip install -e .[dev] # works on most shells
pip install -e '.[dev]' # works on zsh (the default shell on macOS)
movement info
```
You should see a printout including the version numbers of movement
and some of its dependencies.

This will install the package in editable mode, including all `dev` dependencies.
Please see the [contributing guide](target-contributing) for more information.
:::

::::

## Check the installation

To verify that the installation was successful, you can run the following
command (with the `movement-env` activated):
## Update the package
To update movement to the latest version, we recommend installing it in a new environment,
as this prevents potential compatibility issues caused by changes in dependency versions.

To uninstall an existing environment named `movement-env`:
```sh
movement info
conda env remove -n movement-env
```

You should see a printout including the version numbers of `movement`
and some of its dependencies.
:::{tip}
If you are unsure about the environment name, you can get a list of the environments on your system with:
```sh
conda env list
```
:::
Once the environment has been removed, you can create a new one following the [installation instructions](#install-the-package) above.
6 changes: 5 additions & 1 deletion docs/source/getting_started/movement_dataset.md
@@ -19,7 +19,11 @@ To learn more about `xarray` data structures in general, see the relevant

## Dataset structure

![](../_static/dataset_structure.png)
```{figure} ../_static/dataset_structure.png
:alt: movement dataset structure

An {class}`xarray.Dataset` is a collection of several data arrays that share some dimensions. The schematic shows the data arrays that make up the `poses` and `bboxes` datasets in `movement`.
```

The structure of a `movement` dataset `ds` can be easily inspected by simply
printing it.
40 changes: 16 additions & 24 deletions examples/compute_kinematics.py
@@ -10,14 +10,13 @@
# Imports
# -------

import numpy as np

# For interactive plots: install ipympl with `pip install ipympl` and uncomment
# the following line in your notebook
# %matplotlib widget
from matplotlib import pyplot as plt

from movement import sample_data
from movement.utils.vector import compute_norm

# %%
# Load sample dataset
@@ -105,7 +104,7 @@
# %%
# We can also easily plot the components of the position vector against time
# using ``xarray``'s built-in plotting methods. We use
# :py:meth:`xarray.DataArray.squeeze` to
# :meth:`xarray.DataArray.squeeze` to
# remove the dimension of length 1 from the data (the ``keypoints`` dimension).
position.squeeze().plot.line(x="time", row="individuals", aspect=2, size=2.5)
plt.gcf().show()
@@ -131,7 +130,7 @@

# %%
# Notice that we could also compute the displacement (and all the other
# kinematic variables) using the :py:mod:`movement.analysis.kinematics` module:
# kinematic variables) using the :mod:`movement.analysis.kinematics` module:

# %%
import movement.analysis.kinematics as kin
@@ -255,13 +254,12 @@
# mouse along its trajectory.

# length of each displacement vector
displacement_vectors_lengths = np.linalg.norm(
displacement.sel(individuals=mouse_name, space=["x", "y"]).squeeze(),
axis=1,
displacement_vectors_lengths = compute_norm(
displacement.sel(individuals=mouse_name)
)

# sum of all displacement vectors
total_displacement = np.sum(displacement_vectors_lengths, axis=0) # in pixels
# sum the lengths of all displacement vectors (in pixels)
total_displacement = displacement_vectors_lengths.sum(dim="time").values[0]

print(
f"The mouse {mouse_name}'s trajectory is {total_displacement:.2f} "
@@ -284,7 +282,7 @@
# %%
# We can plot the components of the velocity vector against time
# using ``xarray``'s built-in plotting methods. We use
# :py:meth:`xarray.DataArray.squeeze` to
# :meth:`xarray.DataArray.squeeze` to
# remove the dimension of length 1 from the data (the ``keypoints`` dimension).

velocity.squeeze().plot.line(x="time", row="individuals", aspect=2, size=2.5)
@@ -299,14 +297,12 @@
# uses second order central differences.

# %%
# We can also visualise the speed, as the norm of the velocity vector:
# We can also visualise the speed, as the magnitude (norm)
# of the velocity vector:
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
for mouse_name, ax in zip(velocity.individuals.values, axes, strict=False):
# compute the norm of the velocity vector for one mouse
speed_one_mouse = np.linalg.norm(
velocity.sel(individuals=mouse_name, space=["x", "y"]).squeeze(),
axis=1,
)
# compute the magnitude of the velocity vector for one mouse
speed_one_mouse = compute_norm(velocity.sel(individuals=mouse_name))
# plot speed against time
ax.plot(speed_one_mouse)
ax.set_title(mouse_name)
@@ -379,16 +375,12 @@
fig.tight_layout()

# %%
# The norm of the acceleration vector is the magnitude of the
# acceleration.
# We can also represent this for each individual.
# We can also represent the magnitude (norm) of the acceleration vector
# for each individual:
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
for mouse_name, ax in zip(accel.individuals.values, axes, strict=False):
# compute norm of the acceleration vector for one mouse
accel_one_mouse = np.linalg.norm(
accel.sel(individuals=mouse_name, space=["x", "y"]).squeeze(),
axis=1,
)
# compute magnitude of the acceleration vector for one mouse
accel_one_mouse = compute_norm(accel.sel(individuals=mouse_name))

# plot acceleration against time
ax.plot(accel_one_mouse)
26 changes: 13 additions & 13 deletions examples/filter_and_interpolate.py
@@ -27,7 +27,7 @@
# Visualise the pose tracks
# -------------------------
# Since the data contains only a single wasp, we use
# :py:meth:`xarray.DataArray.squeeze` to remove
# :meth:`xarray.DataArray.squeeze` to remove
# the dimension of length 1 from the data (the ``individuals`` dimension).

ds.position.squeeze().plot.line(
@@ -51,7 +51,7 @@
# it's always a good idea to inspect the actual confidence values in the data.
#
# Let's first look at a histogram of the confidence scores. As before, we use
# :py:meth:`xarray.DataArray.squeeze` to remove the ``individuals`` dimension
# :meth:`xarray.DataArray.squeeze` to remove the ``individuals`` dimension
# from the data.

ds.confidence.squeeze().plot.hist(bins=20)
@@ -74,28 +74,28 @@
# Filter out points with low confidence
# -------------------------------------
# Using the
# :py:meth:`filter_by_confidence()\
# :meth:`filter_by_confidence()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method of the ``move`` accessor,
# we can filter out points with confidence scores below a certain threshold.
# The default ``threshold=0.6`` will be used when ``threshold`` is not
# provided.
# This method will also report the number of NaN values in the dataset before
# and after the filtering operation by default (``print_report=True``).
# We will use :py:meth:`xarray.Dataset.update` to update ``ds`` in-place
# We will use :meth:`xarray.Dataset.update` to update ``ds`` in-place
# with the filtered ``position``.

ds.update({"position": ds.move.filter_by_confidence()})

# %%
# .. note::
# The ``move`` accessor :py:meth:`filter_by_confidence()\
# The ``move`` accessor :meth:`filter_by_confidence()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method is a convenience method that applies
# :py:func:`movement.filtering.filter_by_confidence`,
# :func:`movement.filtering.filter_by_confidence`,
# which takes ``position`` and ``confidence`` as arguments.
# The equivalent function call using the
# :py:mod:`movement.filtering` module would be:
# :mod:`movement.filtering` module would be:
#
# .. code-block:: python
#
@@ -121,7 +121,7 @@
# Interpolate over missing values
# -------------------------------
# Using the
# :py:meth:`interpolate_over_time()\
# :meth:`interpolate_over_time()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method of the ``move`` accessor,
# we can interpolate over the gaps we've introduced in the pose tracks.
@@ -135,13 +135,13 @@

# %%
# .. note::
# The ``move`` accessor :py:meth:`interpolate_over_time()\
# The ``move`` accessor :meth:`interpolate_over_time()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# is also a convenience method that applies
# :py:func:`movement.filtering.interpolate_over_time`
# :func:`movement.filtering.interpolate_over_time`
# to the ``position`` data variable.
# The equivalent function call using the
# :py:mod:`movement.filtering` module would be:
# :mod:`movement.filtering` module would be:
#
# .. code-block:: python
#
@@ -176,7 +176,7 @@
# %%
# Filtering multiple data variables
# ---------------------------------
# All :py:mod:`movement.filtering` functions are available via the
# All :mod:`movement.filtering` functions are available via the
# ``move`` accessor. These ``move`` accessor methods operate on the
# ``position`` data variable in the dataset ``ds`` by default.
# There is also an additional argument ``data_vars`` that allows us to
@@ -192,7 +192,7 @@
# in ``ds``, based on the confidence scores, we can specify
# ``data_vars=["position", "velocity"]`` in the method call.
# As the filtered data variables are returned as a dictionary, we can once
# again use :py:meth:`xarray.Dataset.update` to update ``ds`` in-place
# again use :meth:`xarray.Dataset.update` to update ``ds`` in-place
# with the filtered data variables.

ds["velocity"] = ds.move.compute_velocity()
14 changes: 7 additions & 7 deletions examples/smooth.py
@@ -109,7 +109,7 @@ def plot_raw_and_smooth_timeseries_and_psd(
# Smoothing with a median filter
# ------------------------------
# Using the
# :py:meth:`median_filter()\
# :meth:`median_filter()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method of the ``move`` accessor,
# we apply a rolling window median filter over a 0.1-second window
@@ -125,13 +125,13 @@ def plot_raw_and_smooth_timeseries_and_psd(

# %%
# .. note::
# The ``move`` accessor :py:meth:`median_filter()\
# The ``move`` accessor :meth:`median_filter()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method is a convenience method that applies
# :py:func:`movement.filtering.median_filter`
# :func:`movement.filtering.median_filter`
# to the ``position`` data variable.
# The equivalent function call using the
# :py:mod:`movement.filtering` module would be:
# :mod:`movement.filtering` module would be:
#
# .. code-block:: python
#
@@ -249,11 +249,11 @@ def plot_raw_and_smooth_timeseries_and_psd(
# Smoothing with a Savitzky-Golay filter
# --------------------------------------
# Here we use the
# :py:meth:`savgol_filter()\
# :meth:`savgol_filter()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method of the ``move`` accessor, which is a convenience method that applies
# :py:func:`movement.filtering.savgol_filter`
# (a wrapper around :py:func:`scipy.signal.savgol_filter`),
# :func:`movement.filtering.savgol_filter`
# (a wrapper around :func:`scipy.signal.savgol_filter`),
# to the ``position`` data variable.
# The Savitzky-Golay filter is a polynomial smoothing filter that can be
# applied to time series data on a rolling window basis.
125 changes: 86 additions & 39 deletions movement/analysis/kinematics.py
@@ -1,28 +1,45 @@
"""Compute kinematic variables like velocity and acceleration."""

import numpy as np
import xarray as xr

from movement.utils.logging import log_error


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
"""Compute displacement between consecutive positions.
"""Compute displacement array in cartesian coordinates.
This is the difference between consecutive positions of each keypoint for
each individual across time. At each time point ``t``, it's defined as a
vector in cartesian ``(x,y)`` coordinates, pointing from the previous
``(t-1)`` to the current ``(t)`` position.
The displacement array is defined as the difference between the position
array at time point ``t`` and the position array at time point ``t-1``.
As a result, for a given individual and keypoint, the displacement vector
at time point ``t``, is the vector pointing from the previous
``(t-1)`` to the current ``(t)`` position, in cartesian coordinates.
Parameters
----------
data : xarray.DataArray
The input data containing ``time`` as a dimension.
The input data array containing position vectors in cartesian
coordinates, with ``time`` as a dimension.
Returns
-------
xarray.DataArray
An xarray DataArray containing the computed displacement.
An xarray DataArray containing displacement vectors in cartesian
coordinates.
Notes
-----
For the ``position`` array of a ``poses`` dataset, the ``displacement``
array will hold the displacement vectors for every keypoint and every
individual.
For the ``position`` array of a ``bboxes`` dataset, the ``displacement``
array will hold the displacement vectors for the centroid of every
individual bounding box.
For the ``shape`` array of a ``bboxes`` dataset, the
``displacement`` array will hold vectors with the change in width and
height per bounding box, between consecutive time points.
"""
_validate_time_dimension(data)
@@ -32,66 +49,101 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:


def compute_velocity(data: xr.DataArray) -> xr.DataArray:
"""Compute the velocity in cartesian ``(x,y)`` coordinates.
"""Compute velocity array in cartesian coordinates.
Velocity is the first derivative of position for each keypoint
and individual across time. It's computed using numerical differentiation
and assumes equidistant time spacing.
The velocity array is the first time-derivative of the position
array. It is computed by applying the second-order accurate central
differences method on the position array.
Parameters
----------
data : xarray.DataArray
The input data containing ``time`` as a dimension.
The input data array containing position vectors in cartesian
coordinates, with ``time`` as a dimension.
Returns
-------
xarray.DataArray
An xarray DataArray containing the computed velocity.
An xarray DataArray containing velocity vectors in cartesian
coordinates.
Notes
-----
For the ``position`` array of a ``poses`` dataset, the ``velocity`` array
will hold the velocity vectors for every keypoint and every individual.
For the ``position`` array of a ``bboxes`` dataset, the ``velocity`` array
will hold the velocity vectors for the centroid of every individual
bounding box.
See Also
--------
:meth:`xarray.DataArray.differentiate` : The underlying method used.
"""
return _compute_approximate_derivative(data, order=1)
return _compute_approximate_time_derivative(data, order=1)


def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
"""Compute acceleration in cartesian ``(x,y)`` coordinates.
"""Compute acceleration array in cartesian coordinates.
Acceleration represents the second derivative of position for each keypoint
and individual across time. It's computed using numerical differentiation
and assumes equidistant time spacing.
The acceleration array is the second time-derivative of the
position array. It is computed by applying the second-order accurate
central differences method on the velocity array.
Parameters
----------
data : xarray.DataArray
The input data containing ``time`` as a dimension.
The input data array containing position vectors in cartesian
coordinates, with ``time`` as a dimension.
Returns
-------
xarray.DataArray
An xarray DataArray containing the computed acceleration.
An xarray DataArray containing acceleration vectors in cartesian
coordinates.
Notes
-----
For the ``position`` array of a ``poses`` dataset, the ``acceleration``
array will hold the acceleration vectors for every keypoint and every
individual.
For the ``position`` array of a ``bboxes`` dataset, the ``acceleration``
array will hold the acceleration vectors for the centroid of every
individual bounding box.
See Also
--------
:meth:`xarray.DataArray.differentiate` : The underlying method used.
"""
return _compute_approximate_derivative(data, order=2)
return _compute_approximate_time_derivative(data, order=2)


def _compute_approximate_derivative(
def _compute_approximate_time_derivative(
data: xr.DataArray, order: int
) -> xr.DataArray:
"""Compute the derivative using numerical differentiation.
"""Compute the time-derivative of an array using numerical differentiation.
This assumes equidistant time spacing.
This function uses :meth:`xarray.DataArray.differentiate`,
which differentiates the array with the second-order
accurate central differences method.
Parameters
----------
data : xarray.DataArray
The input data containing ``time`` as a dimension.
The input data array containing ``time`` as a dimension.
order : int
The order of the derivative. 1 for velocity, 2 for
acceleration. Value must be a positive integer.
The order of the time-derivative. For an input containing position
data, use 1 to compute velocity, and 2 to compute acceleration. Value
must be a positive integer.
Returns
-------
xarray.DataArray
An xarray DataArray containing the derived variable.
An xarray DataArray containing the time-derivative of the
input data.
"""
if not isinstance(order, int):
@@ -100,17 +152,12 @@ def _compute_approximate_derivative(
)
if order <= 0:
raise log_error(ValueError, "Order must be a positive integer.")

_validate_time_dimension(data)

result = data
dt = data["time"].values[1] - data["time"].values[0]
for _ in range(order):
result = xr.apply_ufunc(
np.gradient,
result,
dt,
kwargs={"axis": 0},
)
result = result.reindex_like(data)
result = result.differentiate("time")
return result


@@ -124,11 +171,11 @@ def _validate_time_dimension(data: xr.DataArray) -> None:
Raises
------
AttributeError
ValueError
If the input data does not contain a ``time`` dimension.
"""
if "time" not in data.dims:
raise log_error(
AttributeError, "Input data must contain 'time' as a dimension."
ValueError, "Input data must contain 'time' as a dimension."
)
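As a rough usage sketch for the public kinematics functions above (the toy array stands in for a real `position` array; movement is assumed to be installed):

```python
import numpy as np
import xarray as xr

import movement.analysis.kinematics as kin

# Toy position array: 5 time points, 1 individual, 1 keypoint, 2D space,
# with dimension names matching a movement poses dataset.
position = xr.DataArray(
    np.linspace(0, 8, 10).reshape(5, 1, 1, 2),
    dims=["time", "individuals", "keypoints", "space"],
    coords={"time": np.arange(5), "space": ["x", "y"]},
)

displacement = kin.compute_displacement(position)
velocity = kin.compute_velocity(position)  # first time-derivative
acceleration = kin.compute_acceleration(position)  # second time-derivative
```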
44 changes: 22 additions & 22 deletions movement/filtering.py
@@ -1,4 +1,4 @@
"""Filter and interpolate pose tracks in ``movement`` datasets."""
"""Filter and interpolate tracks in ``movement`` datasets."""

import xarray as xr
from scipy import signal
@@ -40,14 +40,14 @@ def filter_by_confidence(
Notes
-----
The point-wise confidence values reported by various pose estimation
frameworks are not standardised, and the range of values can vary.
For example, DeepLabCut reports a likelihood value between 0 and 1, whereas
the point confidence reported by SLEAP can range above 1.
Therefore, the default threshold value will not be appropriate for all
datasets and does not have the same meaning across pose estimation
frameworks. We advise users to inspect the confidence values
in their dataset and adjust the threshold accordingly.
For the poses dataset case, note that the point-wise confidence values
reported by various pose estimation frameworks are not standardised, and
the range of values can vary. For example, DeepLabCut reports a likelihood
value between 0 and 1, whereas the point confidence reported by SLEAP can
range above 1. Therefore, the default threshold value will not be
appropriate for all datasets and does not have the same meaning across
pose estimation frameworks. We advise users to inspect the confidence
values in their dataset and adjust the threshold accordingly.
"""
data_filtered = data.where(confidence >= threshold)
@@ -66,7 +66,7 @@ def interpolate_over_time(
) -> xr.DataArray:
"""Fill in NaN values by interpolating over the ``time`` dimension.
This method uses :py:meth:`xarray.DataArray.interpolate_na` under the
This method uses :meth:`xarray.DataArray.interpolate_na` under the
hood and passes the ``method`` and ``max_gap`` parameters to it.
See the xarray documentation for more details on these parameters.
@@ -88,14 +88,14 @@ def interpolate_over_time(
Returns
-------
xr.DataArray
xarray.DataArray
The data where NaN values have been interpolated over
using the parameters provided.
Notes
-----
The ``max_gap`` parameter differs slightly from that in
:py:meth:`xarray.DataArray.interpolate_na`, in which the gap size
:meth:`xarray.DataArray.interpolate_na`, in which the gap size
is defined as the difference between the ``time`` coordinate values
at the first data point after a gap and the last value before a gap.
@@ -127,17 +127,17 @@ def median_filter(
data : xarray.DataArray
The input data to be smoothed.
window : int
The size of the filter window, representing the fixed number
The size of the smoothing window, representing the fixed number
of observations used for each window.
min_periods : int
Minimum number of observations in the window required to have
a value (otherwise result is NaN). The default, None, is
equivalent to setting ``min_periods`` equal to the size of the window.
This argument is directly passed to the ``min_periods`` parameter of
:py:meth:`xarray.DataArray.rolling`.
:meth:`xarray.DataArray.rolling`.
print_report : bool
Whether to print a report on the number of NaNs in the dataset
before and after filtering. Default is ``True``.
before and after smoothing. Default is ``True``.
Returns
-------
@@ -146,7 +146,7 @@ def median_filter(
Notes
-----
By default, whenever one or more NaNs are present in the filter window,
By default, whenever one or more NaNs are present in the smoothing window,
a NaN is returned to the output array. As a result, any
stretch of NaNs present in the input data will be propagated
proportionally to the size of the window (specifically, by
@@ -194,18 +194,18 @@ def savgol_filter(
data : xarray.DataArray
The input data to be smoothed.
window : int
The size of the filter window, representing the fixed number
The size of the smoothing window, representing the fixed number
of observations used for each window.
polyorder : int
The order of the polynomial used to fit the samples. Must be
less than ``window``. By default, a ``polyorder`` of
2 is used.
print_report : bool
Whether to print a report on the number of NaNs in the dataset
before and after filtering. Default is ``True``.
before and after smoothing. Default is ``True``.
**kwargs : dict
Additional keyword arguments are passed to
:py:func:`scipy.signal.savgol_filter`.
:func:`scipy.signal.savgol_filter`.
Note that the ``axis`` keyword argument may not be overridden.
@@ -217,15 +217,15 @@ def savgol_filter(
Notes
-----
Uses the :py:func:`scipy.signal.savgol_filter` function to apply a
Uses the :func:`scipy.signal.savgol_filter` function to apply a
Savitzky-Golay filter to the input data.
See the SciPy documentation for more information on that function.
Whenever one or more NaNs are present in a filter window of the
Whenever one or more NaNs are present in a smoothing window of the
input data, a NaN is returned to the output array. As a result, any
stretch of NaNs present in the input data will be propagated
proportionally to the size of the window (specifically, by
``floor(window/2)``). Note that, unlike
:py:func:`movement.filtering.median_filter()`, there is no ``min_periods``
:func:`movement.filtering.median_filter`, there is no ``min_periods``
option to control this behaviour.
"""
19 changes: 13 additions & 6 deletions movement/io/load_bboxes.py
@@ -35,19 +35,19 @@ def from_numpy(
position_array : np.ndarray
Array of shape (n_frames, n_individuals, n_space)
containing the tracks of the bounding boxes' centroids.
It will be converted to a :py:class:`xarray.DataArray` object
It will be converted to a :class:`xarray.DataArray` object
named "position".
shape_array : np.ndarray
Array of shape (n_frames, n_individuals, n_space)
containing the shape of the bounding boxes. The shape of a bounding
box is its width (extent along the x-axis of the image) and height
(extent along the y-axis of the image). It will be converted to a
:py:class:`xarray.DataArray` object named "shape".
:class:`xarray.DataArray` object named "shape".
confidence_array : np.ndarray, optional
Array of shape (n_frames, n_individuals) containing
the confidence scores of the bounding boxes. If None (default), the
confidence scores are set to an array of NaNs. It will be converted
to a :py:class:`xarray.DataArray` object named "confidence".
to a :class:`xarray.DataArray` object named "confidence".
individual_names : list of str, optional
List of individual names for the tracked bounding boxes in the video.
If None (default), bounding boxes are assigned names based on the size
@@ -402,6 +402,11 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:

array_dict[key] = np.stack(list_arrays, axis=1).squeeze()

# Transform position_array to represent centroid of bbox,
# rather than top-left corner
# (top left corner: corner of the bbox with minimum x and y coordinates)
array_dict["position_array"] += array_dict["shape_array"] / 2

# Add remaining arrays to dict
array_dict["ID_array"] = df["ID"].unique().reshape(-1, 1)
array_dict["frame_array"] = df["frame_number"].unique().reshape(-1, 1)
@@ -415,14 +420,16 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
Read the VIA tracks .csv file as a pandas dataframe with columns:
- ID: the integer ID of the tracked bounding box.
- frame_number: the frame number of the tracked bounding box.
- x: the x-coordinate of the tracked bounding box centroid.
- y: the y-coordinate of the tracked bounding box centroid.
- x: the x-coordinate of the tracked bounding box's top-left corner.
- y: the y-coordinate of the tracked bounding box's top-left corner.
- w: the width of the tracked bounding box.
- h: the height of the tracked bounding box.
- confidence: the confidence score of the tracked bounding box.
The dataframe is sorted by ID and frame number, and for each ID,
empty frames are filled in with NaNs.
empty frames are filled in with NaNs. The coordinates of the bboxes
are assumed to be in the image coordinate system (i.e., the top-left
corner of a bbox is its corner with minimum x and y coordinates).
"""
# Read VIA tracks .csv file as a pandas dataframe
df_file = pd.read_csv(file_path, sep=",", header=0)
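The corner-to-centroid conversion applied in `_numpy_arrays_from_via_tracks_file` amounts to an element-wise shift by half the bounding box extent; a toy sketch:

```python
import numpy as np

# One bounding box: top-left corner at (10, 20), width 4, height 6
position = np.array([[10.0, 20.0]])  # top-left (x, y)
shape = np.array([[4.0, 6.0]])  # (width, height)

# Shift by half the width/height to get the centroid, as in the loader
centroid = position + shape / 2
print(centroid)  # [[12. 23.]]
```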
4 changes: 2 additions & 2 deletions movement/io/load_poses.py
@@ -34,11 +34,11 @@ def from_numpy(
position_array : np.ndarray
Array of shape (n_frames, n_individuals, n_keypoints, n_space)
containing the poses. It will be converted to a
:py:class:`xarray.DataArray` object named "position".
:class:`xarray.DataArray` object named "position".
confidence_array : np.ndarray, optional
Array of shape (n_frames, n_individuals, n_keypoints) containing
the point-wise confidence scores. It will be converted to a
:py:class:`xarray.DataArray` object named "confidence".
:class:`xarray.DataArray` object named "confidence".
If None (default), the scores will be set to an array of NaNs.
individual_names : list of str, optional
List of unique names for the individuals in the video. If None
2 changes: 1 addition & 1 deletion movement/io/save_poses.py
@@ -241,7 +241,7 @@ def to_lp_file(
-----
LightningPose saves pose estimation outputs as .csv files, using the same
format as single-animal DeepLabCut projects. Therefore, under the hood,
this function calls :py:func:`movement.io.save_poses.to_dlc_file`
this function calls :func:`movement.io.save_poses.to_dlc_file`
with ``split_individuals=True``. This setting means that each individual
is saved to a separate file, with the individual's name appended to the
file path, just before the file extension,
14 changes: 7 additions & 7 deletions movement/move_accessor.py
@@ -1,4 +1,4 @@
"""Accessor for extending :py:class:`xarray.Dataset` objects."""
"""Accessor for extending :class:`xarray.Dataset` objects."""

import logging
from typing import ClassVar
@@ -18,9 +18,9 @@

@xr.register_dataset_accessor("move")
class MovementDataset:
"""An :py:class:`xarray.Dataset` accessor for ``movement`` data.
"""An :class:`xarray.Dataset` accessor for ``movement`` data.
A ``movement`` dataset is an :py:class:`xarray.Dataset` with a specific
A ``movement`` dataset is an :class:`xarray.Dataset` with a specific
structure to represent pose tracks or bounding boxes data,
associated confidence scores and relevant metadata.
@@ -66,8 +66,8 @@ def __getattr__(self, name: str) -> xr.DataArray:
This method currently only forwards kinematic property computation
and filtering operations to the respective functions in
:py:mod:`movement.analysis.kinematics` and
:py:mod:`movement.filtering`.
:mod:`movement.analysis.kinematics` and
:mod:`movement.filtering`.
Parameters
----------
@@ -106,7 +106,7 @@ def kinematics_wrapper(
"""Provide convenience method for computing kinematic properties.
This method forwards kinematic property computation
to the respective functions in :py:mod:`movement.analysis.kinematics`.
to the respective functions in :mod:`movement.analysis.kinematics`.
Parameters
----------
@@ -161,7 +161,7 @@ def filtering_wrapper(
"""Provide convenience method for filtering data variables.
This method forwards filtering and/or smoothing to the respective
functions in :py:mod:`movement.filtering`. The data variables to
functions in :mod:`movement.filtering`. The data variables to
filter can be specified in ``data_vars``. If ``data_vars`` is not
specified, the ``position`` data variable is selected by default.
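For instance, a sketch of the forwarding in practice (the sample dataset name is illustrative):

```python
from movement import sample_data

# Load a sample poses dataset (the name is illustrative)
ds = sample_data.fetch_dataset("SLEAP_three-mice_Aeon_proofread.analysis.h5")

# Kinematics computation is forwarded to movement.analysis.kinematics
ds["velocity"] = ds.move.compute_velocity()

# Filtering is forwarded to movement.filtering; by default it
# operates on the "position" data variable
ds.update({"position": ds.move.filter_by_confidence()})
```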
4 changes: 2 additions & 2 deletions movement/utils/logging.py
@@ -113,8 +113,8 @@ def log_to_attrs(func):
"""Log the operation performed by the wrapped function.
This decorator appends log entries to the data's ``log``
attribute. The wrapped function must accept an :py:class:`xarray.Dataset`
or :py:class:`xarray.DataArray` as its first argument and return an
attribute. The wrapped function must accept an :class:`xarray.Dataset`
or :class:`xarray.DataArray` as its first argument and return an
object of the same type.
"""

102 changes: 96 additions & 6 deletions movement/utils/vector.py
@@ -6,6 +6,93 @@
from movement.utils.logging import log_error


def compute_norm(data: xr.DataArray) -> xr.DataArray:
"""Compute the norm of the vectors along the spatial dimension.
The norm of a vector is its magnitude, also called Euclidean norm, 2-norm
or Euclidean length. Note that if the input data is expressed in polar
coordinates, the magnitude of a vector is the same as its radial coordinate
``rho``.
Parameters
----------
data : xarray.DataArray
The input data array containing either ``space`` or ``space_pol``
as a dimension.
Returns
-------
xarray.DataArray
A data array holding the norm of the input vectors.
Note that this output array has no spatial dimension but preserves
all other dimensions of the input data array (see Notes).
Notes
-----
If the input data array is a ``position`` array, this function will compute
the magnitude of the position vectors, for every individual and keypoint,
at every timestep. If the input data array is a ``shape`` array of a
bounding boxes dataset, it will compute the magnitude of the shape
vectors (i.e., the diagonal of the bounding box),
for every individual and at every timestep.
"""
if "space" in data.dims:
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
return xr.apply_ufunc(
np.linalg.norm,
data,
input_core_dims=[["space"]],
kwargs={"axis": -1},
)
elif "space_pol" in data.dims:
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
return data.sel(space_pol="rho", drop=True)
else:
_raise_error_for_missing_spatial_dim()


def convert_to_unit(data: xr.DataArray) -> xr.DataArray:
"""Convert the vectors along the spatial dimension into unit vectors.
A unit vector is a vector pointing in the same direction as the original
vector but with norm = 1.
Parameters
----------
data : xarray.DataArray
The input data array containing either ``space`` or ``space_pol``
as a dimension.
Returns
-------
xarray.DataArray
A data array holding the unit vectors of the input data array
(all input dimensions are preserved).
Notes
-----
Note that the unit vector for the null vector is undefined, since the null
vector has 0 norm and no direction associated with it.
"""
if "space" in data.dims:
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
return data / compute_norm(data)
elif "space_pol" in data.dims:
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
# Set both rho and phi values to NaN at null vectors (where rho = 0)
new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data)
# Set the rho values to 1 for non-null vectors (phi is preserved)
new_data.loc[{"space_pol": "rho"}] = xr.where(
new_data.sel(space_pol="rho").isnull(), np.nan, 1
)
return new_data
else:
_raise_error_for_missing_spatial_dim()


def cart2pol(data: xr.DataArray) -> xr.DataArray:
"""Transform Cartesian coordinates to polar.
@@ -25,12 +112,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray:
"""
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
rho = xr.apply_ufunc(
np.linalg.norm,
data,
input_core_dims=[["space"]],
kwargs={"axis": -1},
)
rho = compute_norm(data)
phi = xr.apply_ufunc(
np.arctan2,
data.sel(space="y"),
@@ -122,3 +204,11 @@ def _validate_dimension_coordinates(
)
if error_message:
raise log_error(ValueError, error_message)


def _raise_error_for_missing_spatial_dim() -> None:
raise log_error(
ValueError,
"Input data array must contain either 'space' or 'space_pol' "
"as dimensions.",
)
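A rough usage sketch for `compute_norm` and `convert_to_unit` (toy data with a named ``space`` dimension):

```python
import numpy as np
import xarray as xr

from movement.utils.vector import compute_norm, convert_to_unit

# Toy array of 2D vectors with "space" as a dimension
data = xr.DataArray(
    np.array([[3.0, 4.0], [0.0, 2.0]]),
    dims=["time", "space"],
    coords={"time": [0, 1], "space": ["x", "y"]},
)

print(compute_norm(data).values)  # [5. 2.]
unit = convert_to_unit(data)  # each vector rescaled to norm 1
```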
1 change: 0 additions & 1 deletion pyproject.toml
@@ -14,7 +14,6 @@ license = { text = "BSD-3-Clause" }

dependencies = [
"numpy",
"pandas<2.2.2;python_version>='3.12'",
"pandas",
"h5py",
"attrs",
156 changes: 144 additions & 12 deletions tests/conftest.py
@@ -39,6 +39,7 @@ def setup_logging(tmp_path):
)


# --------- File validator fixtures ---------------------------------
@pytest.fixture
def unreadable_file(tmp_path):
"""Return a dictionary containing the file path and
@@ -213,13 +214,42 @@ def sleap_file(request):
return pytest.DATA_PATHS.get(request.param)


# ------------ Dataset validator fixtures ---------------------------------


@pytest.fixture
def valid_bboxes_array():
"""Return a dictionary of valid non-zero arrays for a
def valid_bboxes_arrays_all_zeros():
"""Return a dictionary of valid zero arrays (in terms of shape) for a
ValidBboxesDataset.
"""
# define the shape of the arrays
n_frames, n_individuals, n_space = (10, 2, 2)

# build a valid array for position or shape with all zeros
valid_bbox_array_all_zeros = np.zeros((n_frames, n_individuals, n_space))

# return as a dict
return {
"position": valid_bbox_array_all_zeros,
"shape": valid_bbox_array_all_zeros,
"individual_names": ["id_" + str(id) for id in range(n_individuals)],
}


# --------------------- Bboxes dataset fixtures ----------------------------
@pytest.fixture
def valid_bboxes_arrays():
"""Return a dictionary of valid arrays for a
ValidBboxesDataset representing a uniform linear motion.
It represents 2 individuals for 10 frames, in 2D space.
- Individual 0 moves along the x=y line from the origin.
- Individual 1 moves along the x=-y line from the origin.
Contains realistic data for 10 frames, 2 individuals, in 2D
with 5 low confidence bounding boxes.
All confidence values are set to 0.9 except the following which are set
to 0.1:
- Individual 0 at frames 2, 3, 4
- Individual 1 at frames 2, 3
"""
# define the shape of the arrays
n_frames, n_individuals, n_space = (10, 2, 2)
@@ -252,22 +282,21 @@ def valid_bboxes_array():
"position": position,
"shape": shape,
"confidence": confidence,
"individual_names": ["id_" + str(id) for id in range(n_individuals)],
}


@pytest.fixture
def valid_bboxes_dataset(
valid_bboxes_array,
valid_bboxes_arrays,
):
"""Return a valid bboxes dataset with low confidence values and
time in frames.
"""Return a valid bboxes dataset for two individuals moving in uniform
linear motion, with 5 frames with low confidence values and time in frames.
"""
dim_names = MovementDataset.dim_names["bboxes"]

position_array = valid_bboxes_array["position"]
shape_array = valid_bboxes_array["shape"]
confidence_array = valid_bboxes_array["confidence"]
position_array = valid_bboxes_arrays["position"]
shape_array = valid_bboxes_arrays["shape"]
confidence_array = valid_bboxes_arrays["confidence"]

n_frames, n_individuals, _ = position_array.shape

@@ -315,6 +344,7 @@ def valid_bboxes_dataset_with_nan(valid_bboxes_dataset):
return valid_bboxes_dataset


# --------------------- Poses dataset fixtures ----------------------------
@pytest.fixture
def valid_position_array():
"""Return a function that generates different kinds
@@ -384,12 +414,111 @@ def valid_poses_dataset(valid_position_array, request):
@pytest.fixture
def valid_poses_dataset_with_nan(valid_poses_dataset):
"""Return a valid pose tracks dataset with NaN values."""
# Sets position for all keypoints in individual ind1 to NaN
# at timepoints 3, 7, 8
valid_poses_dataset.position.loc[
{"individuals": "ind1", "time": [3, 7, 8]}
] = np.nan
return valid_poses_dataset


@pytest.fixture
def valid_poses_array_uniform_linear_motion():
"""Return a dictionary of valid arrays for a
ValidPosesDataset representing a uniform linear motion.
It represents 2 individuals with 3 keypoints, for 10 frames, in 2D space.
- Individual 0 moves along the x=y line from the origin.
    - Individual 1 moves along the x=-y line from the origin.
All confidence values for all keypoints are set to 0.9 except
for the keypoints at the following frames which are set to 0.1:
- Individual 0 at frames 2, 3, 4
- Individual 1 at frames 2, 3
"""
# define the shape of the arrays
n_frames, n_individuals, n_keypoints, n_space = (10, 2, 3, 2)

# define centroid (index=0) trajectory in position array
# for each individual, the centroid moves along
# the x=+/-y line, starting from the origin.
# - individual 0 moves along x = y line
# - individual 1 moves along x = -y line
# They move one unit along x and y axes in each frame
frames = np.arange(n_frames)
position = np.empty((n_frames, n_individuals, n_keypoints, n_space))
position[:, :, 0, 0] = frames[:, None] # reshape to (n_frames, 1)
position[:, 0, 0, 1] = frames
position[:, 1, 0, 1] = -frames

# define trajectory of left and right keypoints
# for individual 0, at each timepoint:
# - the left keypoint (index=1) is at x_centroid, y_centroid + 1
# - the right keypoint (index=2) is at x_centroid + 1, y_centroid
# for individual 1, at each timepoint:
# - the left keypoint (index=1) is at x_centroid - 1, y_centroid
# - the right keypoint (index=2) is at x_centroid, y_centroid + 1
offsets = [
[(0, 1), (1, 0)], # individual 0: left, right keypoints (x,y) offsets
[(-1, 0), (0, 1)], # individual 1: left, right keypoints (x,y) offsets
]
for i in range(n_individuals):
for kpt in range(1, n_keypoints):
position[:, i, kpt, 0] = (
position[:, i, 0, 0] + offsets[i][kpt - 1][0]
)
position[:, i, kpt, 1] = (
position[:, i, 0, 1] + offsets[i][kpt - 1][1]
)

# build an array of confidence values, all 0.9
confidence = np.full((n_frames, n_individuals, n_keypoints), 0.9)
# set 5 low-confidence values
# - set 3 confidence values for individual id_0's centroid to 0.1
# - set 2 confidence values for individual id_1's centroid to 0.1
idx_start = 2
confidence[idx_start : idx_start + 3, 0, 0] = 0.1
confidence[idx_start : idx_start + 2, 1, 0] = 0.1

return {"position": position, "confidence": confidence}


@pytest.fixture
def valid_poses_dataset_uniform_linear_motion(
valid_poses_array_uniform_linear_motion,
):
"""Return a valid poses dataset for two individuals moving in uniform
    linear motion, with low-confidence values in 5 frames and time in frames.
"""
dim_names = MovementDataset.dim_names["poses"]

position_array = valid_poses_array_uniform_linear_motion["position"]
confidence_array = valid_poses_array_uniform_linear_motion["confidence"]

n_frames, n_individuals, _, _ = position_array.shape

return xr.Dataset(
data_vars={
"position": xr.DataArray(position_array, dims=dim_names),
"confidence": xr.DataArray(confidence_array, dims=dim_names[:-1]),
},
coords={
dim_names[0]: np.arange(n_frames),
            dim_names[1]: [f"id_{i}" for i in range(n_individuals)],
dim_names[2]: ["centroid", "left", "right"],
dim_names[3]: ["x", "y"],
},
attrs={
"fps": None,
"time_unit": "frames",
"source_software": "test",
"source_file": "test_poses.h5",
"ds_type": "poses",
},
)


# -------------------- Invalid datasets fixtures ------------------------------
@pytest.fixture
def not_a_dataset():
"""Return data that is not a pose tracks dataset."""
@@ -444,7 +573,7 @@ def kinematic_property(request):
return request.param


# VIA tracks CSV fixtures
# ---------------- VIA tracks CSV file fixtures ----------------------------
@pytest.fixture
def via_tracks_csv_with_invalid_header(tmp_path):
"""Return the file path for a file with invalid header."""
@@ -705,6 +834,9 @@ def count_consecutive_nans(da):
return (da.isnull().astype(int).diff("time") == 1).sum().item()
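
The one-liner above counts transitions from valid to NaN along time, i.e. the number of NaN runs rather than the number of NaN values (a run starting at the very first frame would not be counted). A toy illustration of the same idiom:

```python
import numpy as np
import xarray as xr

# toy series with two NaN runs: length 2 (frames 1-2) and length 1 (frame 4)
values = np.array([0.0, np.nan, np.nan, 3.0, np.nan, 5.0])
da = xr.DataArray(values, dims="time", coords={"time": np.arange(6)})

# 0 -> 1 transitions in the isnull mask mark the start of each run
n_runs = (da.isnull().astype(int).diff("time") == 1).sum().item()
print(n_runs)  # 2
```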


# ----------------- Helper fixture -----------------


@pytest.fixture
def helpers():
"""Return an instance of the ``Helpers`` class."""
284 changes: 201 additions & 83 deletions tests/test_unit/test_filtering.py
Original file line number Diff line number Diff line change
@@ -10,115 +10,233 @@
savgol_filter,
)

# Dataset fixtures
list_valid_datasets_without_nans = [
"valid_poses_dataset",
"valid_bboxes_dataset",
]
list_valid_datasets_with_nans = [
f"{dataset}_with_nan" for dataset in list_valid_datasets_without_nans
]
list_all_valid_datasets = (
list_valid_datasets_without_nans + list_valid_datasets_with_nans
)


@pytest.mark.parametrize(
"max_gap, expected_n_nans", [(None, 0), (1, 8), (2, 0)]
"valid_dataset_with_nan",
list_valid_datasets_with_nans,
)
@pytest.mark.parametrize(
"max_gap, expected_n_nans_in_position", [(None, 0), (0, 3), (1, 2), (2, 0)]
)
def test_interpolate_over_time(
valid_poses_dataset_with_nan, helpers, max_gap, expected_n_nans
def test_interpolate_over_time_on_position(
valid_dataset_with_nan,
max_gap,
expected_n_nans_in_position,
helpers,
request,
):
"""Test that the number of NaNs decreases after interpolating
"""Test that the number of NaNs decreases after linearly interpolating
over time and that the resulting number of NaNs is as expected
for different values of ``max_gap``.
"""
# First dataset with time unit in frames
data_in_frames = valid_poses_dataset_with_nan.position
# Create second dataset with time unit in seconds
data_in_seconds = data_in_frames.copy()
data_in_seconds["time"] = data_in_seconds["time"] * 0.1
data_interp_frames = interpolate_over_time(data_in_frames, max_gap=max_gap)
data_interp_seconds = interpolate_over_time(
data_in_seconds, max_gap=max_gap
valid_dataset_in_frames = request.getfixturevalue(valid_dataset_with_nan)

# Get position array with time unit in frames & seconds
# assuming 10 fps = 0.1 s per frame
valid_dataset_in_seconds = valid_dataset_in_frames.copy()
valid_dataset_in_seconds.coords["time"] = (
valid_dataset_in_seconds.coords["time"] * 0.1
)
n_nans_before = helpers.count_nans(data_in_frames)
n_nans_after_frames = helpers.count_nans(data_interp_frames)
n_nans_after_seconds = helpers.count_nans(data_interp_seconds)
position = {
"frames": valid_dataset_in_frames.position,
"seconds": valid_dataset_in_seconds.position,
}

# Count number of NaNs before and after interpolating position
n_nans_before = helpers.count_nans(position["frames"])
n_nans_after_per_time_unit = {}
for time_unit in ["frames", "seconds"]:
# interpolate
position_interp = interpolate_over_time(
position[time_unit], method="linear", max_gap=max_gap
)
# count nans
n_nans_after_per_time_unit[time_unit] = helpers.count_nans(
position_interp
)

# The number of NaNs should be the same for both datasets
# as max_gap is based on number of missing observations (NaNs)
assert n_nans_after_frames == n_nans_after_seconds
assert n_nans_after_frames < n_nans_before
assert n_nans_after_frames == expected_n_nans
assert (
n_nans_after_per_time_unit["frames"]
== n_nans_after_per_time_unit["seconds"]
)

# The number of NaNs should decrease after interpolation
n_nans_after = n_nans_after_per_time_unit["frames"]
if max_gap == 0:
assert n_nans_after == n_nans_before
else:
assert n_nans_after < n_nans_before

# The number of NaNs after interpolating should be as expected
assert n_nans_after == (
valid_dataset_in_frames.sizes["space"]
* valid_dataset_in_frames.sizes.get("keypoints", 1)
        # the bboxes dataset has no keypoints dimension
* expected_n_nans_in_position
)
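
As the comment above notes, ``max_gap`` here is counted in missing observations, which is why the frames and seconds variants agree. For intuition, plain ``xarray.DataArray.interpolate_na`` measures ``max_gap`` in coordinate units instead; a toy sketch of that behaviour (not ``movement``'s implementation):

```python
import numpy as np
import xarray as xr

# 10 frames with a 1-frame gap (frame 3) and a 2-frame gap (frames 7-8)
values = np.arange(10, dtype=float)
values[[3, 7, 8]] = np.nan
da = xr.DataArray(values, dims="time", coords={"time": np.arange(10)})

# in plain xarray, a gap of n missing frames spans a coordinate
# distance of n + 1 between the surrounding valid frames
print(int(da.interpolate_na("time", max_gap=2).isnull().sum()))  # 2
print(int(da.interpolate_na("time", max_gap=3).isnull().sum()))  # 0
```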


def test_filter_by_confidence(valid_poses_dataset, helpers):
@pytest.mark.parametrize(
"valid_dataset_no_nans, n_low_confidence_kpts",
[
("valid_poses_dataset", 20),
("valid_bboxes_dataset", 5),
],
)
def test_filter_by_confidence_on_position(
valid_dataset_no_nans, n_low_confidence_kpts, helpers, request
):
"""Test that points below the default 0.6 confidence threshold
are converted to NaN.
"""
data = valid_poses_dataset.position
confidence = valid_poses_dataset.confidence
data_filtered = filter_by_confidence(data, confidence)
n_nans = helpers.count_nans(data_filtered)
assert isinstance(data_filtered, xr.DataArray)
# 5 timepoints * 2 individuals * 2 keypoints * 2 space dimensions
# have confidence below 0.6
assert n_nans == 40


@pytest.mark.parametrize("window_size", [2, 4])
def test_median_filter(valid_poses_dataset_with_nan, window_size):
"""Test that applying the median filter returns
a different xr.DataArray than the input data.
"""
data = valid_poses_dataset_with_nan.position
data_smoothed = median_filter(data, window_size)
del data_smoothed.attrs["log"]
assert isinstance(data_smoothed, xr.DataArray) and not (
data_smoothed.equals(data)
# Filter position by confidence
valid_input_dataset = request.getfixturevalue(valid_dataset_no_nans)
position_filtered = filter_by_confidence(
valid_input_dataset.position,
confidence=valid_input_dataset.confidence,
threshold=0.6,
)

# Count number of NaNs in the full array
n_nans = helpers.count_nans(position_filtered)

def test_median_filter_with_nans(valid_poses_dataset_with_nan, helpers):
"""Test NaN behaviour of the median filter. The input data
contains NaNs in all keypoints of the first individual at timepoints
3, 7, and 8 (0-indexed, 10 total timepoints). The median filter
should propagate NaNs within the windows of the filter,
but it should not introduce any NaNs for the second individual.
"""
data = valid_poses_dataset_with_nan.position
data_smoothed = median_filter(data, window=3)
# All points of the first individual are converted to NaNs except
# at timepoints 0, 1, and 5.
assert not (
data_smoothed.isel(individuals=0, time=[0, 1, 5]).isnull().any()
)
# 7 timepoints * 1 individual * 2 keypoints * 2 space dimensions
assert helpers.count_nans(data_smoothed) == 28
# No NaNs should be introduced for the second individual
assert not data_smoothed.isel(individuals=1).isnull().any()
# expected number of nans for poses:
# 5 timepoints * 2 individuals * 2 keypoints
# Note: we count the number of nans in the array, so we multiply
# the number of low confidence keypoints by the number of
# space dimensions
assert isinstance(position_filtered, xr.DataArray)
assert n_nans == valid_input_dataset.sizes["space"] * n_low_confidence_kpts
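
The behaviour under test amounts to masking positions wherever confidence falls below the threshold, which in xarray is a one-liner with ``where``. A minimal sketch of the idea (not necessarily ``movement``'s exact implementation):

```python
import numpy as np
import xarray as xr

position = xr.DataArray(np.arange(5, dtype=float), dims="time")
confidence = xr.DataArray([0.9, 0.1, 0.9, 0.1, 0.9], dims="time")

# keep positions with confidence >= 0.6, set the rest to NaN
filtered = position.where(confidence >= 0.6)
print(filtered.values)  # [ 0. nan  2. nan  4.]
```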


@pytest.mark.parametrize("window, polyorder", [(2, 1), (4, 2)])
def test_savgol_filter(valid_poses_dataset_with_nan, window, polyorder):
"""Test that applying the Savitzky-Golay filter returns
a different xr.DataArray than the input data.
@pytest.mark.parametrize(
"valid_dataset",
list_all_valid_datasets,
)
@pytest.mark.parametrize(
("filter_func, filter_kwargs"),
[
(median_filter, {"window": 2}),
(median_filter, {"window": 4}),
(savgol_filter, {"window": 2, "polyorder": 1}),
(savgol_filter, {"window": 4, "polyorder": 2}),
],
)
def test_filter_on_position(
filter_func, filter_kwargs, valid_dataset, request
):
"""Test that applying a filter to the position data returns
a different xr.DataArray than the input position data.
"""
data = valid_poses_dataset_with_nan.position
data_smoothed = savgol_filter(data, window, polyorder=polyorder)
del data_smoothed.attrs["log"]
assert isinstance(data_smoothed, xr.DataArray) and not (
data_smoothed.equals(data)
# Filter position
valid_input_dataset = request.getfixturevalue(valid_dataset)
position_filtered = filter_func(
valid_input_dataset.position, **filter_kwargs
)

del position_filtered.attrs["log"]

# filtered array is an xr.DataArray
assert isinstance(position_filtered, xr.DataArray)

# filtered data should not be equal to the original data
assert not position_filtered.equals(valid_input_dataset.position)


def test_savgol_filter_with_nans(valid_poses_dataset_with_nan, helpers):
"""Test NaN behaviour of the Savitzky-Golay filter. The input data
contains NaN values in all keypoints of the first individual at times
3, 7, and 8 (0-indexed, 10 total timepoints).
The Savitzky-Golay filter should propagate NaNs within the windows of
the filter, but it should not introduce any NaNs for the second individual.
# Expected number of nans in the position array per
# individual, after applying a filter with window size 3
@pytest.mark.parametrize(
("valid_dataset, expected_nans_in_filtered_position_per_indiv"),
[
(
"valid_poses_dataset",
{0: 0, 1: 0},
), # filtering should not introduce nans if input has no nans
("valid_bboxes_dataset", {0: 0, 1: 0}),
("valid_poses_dataset_with_nan", {0: 7, 1: 0}),
("valid_bboxes_dataset_with_nan", {0: 7, 1: 0}),
],
)
@pytest.mark.parametrize(
("filter_func, filter_kwargs"),
[
(median_filter, {"window": 3}),
(savgol_filter, {"window": 3, "polyorder": 2}),
],
)
def test_filter_with_nans_on_position(
filter_func,
filter_kwargs,
valid_dataset,
expected_nans_in_filtered_position_per_indiv,
helpers,
request,
):
"""Test NaN behaviour of the selected filter. The median and SG filters
    should return NaN whenever any element of the sliding window is NaN.
"""
data = valid_poses_dataset_with_nan.position
data_smoothed = savgol_filter(data, window=3, polyorder=2)
# There should be 28 NaNs in total for the first individual, i.e.
# at 7 timepoints, 2 keypoints, 2 space dimensions
# all except for timepoints 0, 1 and 5
assert helpers.count_nans(data_smoothed) == 28
assert not (
data_smoothed.isel(individuals=0, time=[0, 1, 5]).isnull().any()

def _assert_n_nans_in_position_per_individual(
valid_input_dataset,
position_filtered,
expected_nans_in_filt_position_per_indiv,
):
# compute n nans in position after filtering per individual
n_nans_after_filtering_per_indiv = {
i: helpers.count_nans(position_filtered.isel(individuals=i))
for i in range(valid_input_dataset.sizes["individuals"])
}

# check number of nans per indiv is as expected
for i in range(valid_input_dataset.sizes["individuals"]):
assert n_nans_after_filtering_per_indiv[i] == (
expected_nans_in_filt_position_per_indiv[i]
* valid_input_dataset.sizes["space"]
* valid_input_dataset.sizes.get("keypoints", 1)
)

# Filter position
valid_input_dataset = request.getfixturevalue(valid_dataset)
position_filtered = filter_func(
valid_input_dataset.position, **filter_kwargs
)
assert not data_smoothed.isel(individuals=1).isnull().any()

# check number of nans per indiv is as expected
_assert_n_nans_in_position_per_individual(
valid_input_dataset,
position_filtered,
expected_nans_in_filtered_position_per_indiv,
)

    # if the input had nans, the position of the first individual
    # at timepoints 0, 1 and 5 should not be nan
n_nans_input = helpers.count_nans(valid_input_dataset.position)
if n_nans_input != 0:
assert not (
position_filtered.isel(individuals=0, time=[0, 1, 5])
.isnull()
.any()
)
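
The NaN propagation asserted here can be reproduced with a plain centred rolling median, which returns NaN for every window that contains a NaN (or is incomplete at the edges). A toy sketch, not ``movement``'s ``median_filter`` itself:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.array([0.0, 1.0, 2.0, np.nan, 4.0, 5.0, 6.0]), dims="time"
)

# window=3, centred: the NaN at index 3 knocks out indices 2-4,
# and the incomplete edge windows are NaN as well
print(da.rolling(time=3, center=True).median().values)
# [nan  1. nan nan nan  5. nan]
```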


@pytest.mark.parametrize(
"valid_dataset",
list_all_valid_datasets,
)
@pytest.mark.parametrize(
"override_kwargs",
[
@@ -128,7 +246,7 @@ def test_savgol_filter_with_nans(valid_poses_dataset_with_nan, helpers):
],
)
def test_savgol_filter_kwargs_override(
valid_poses_dataset_with_nan, override_kwargs
valid_dataset, override_kwargs, request
):
"""Test that overriding keyword arguments in the Savitzky-Golay filter
works, except for the ``axis`` argument, which should raise a ValueError.
@@ -140,7 +258,7 @@ def test_savgol_filter_kwargs_override(
)
with expected_exception:
savgol_filter(
valid_poses_dataset_with_nan.position,
request.getfixturevalue(valid_dataset).position,
window=3,
**override_kwargs,
)
278 changes: 175 additions & 103 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,112 +1,184 @@
from contextlib import nullcontext as does_not_raise

import numpy as np
import pytest
import xarray as xr

from movement.analysis import kinematics


class TestKinematics:
"""Test suite for the kinematics module."""

@pytest.fixture
def expected_dataarray(self, valid_poses_dataset):
"""Return a function to generate the expected dataarray
for different kinematic properties.
"""

def _expected_dataarray(property):
"""Return an xarray.DataArray with default values and
the expected dimensions and coordinates.
"""
# Expected x,y values for velocity
x_vals = np.array(
[1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 17.0]
)
y_vals = np.full((10, 2, 2, 1), 4.0)
if property == "acceleration":
x_vals = np.array(
[1.0, 1.5, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.5, 1.0]
)
y_vals = np.full((10, 2, 2, 1), 0)
elif property == "displacement":
x_vals = np.array(
[0.0, 1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0]
)
y_vals[0] = 0

x_vals = x_vals.reshape(-1, 1, 1, 1)
# Repeat the x_vals to match the shape of the position
x_vals = np.tile(x_vals, (1, 2, 2, 1))
return xr.DataArray(
np.concatenate(
[x_vals, y_vals],
axis=-1,
),
dims=valid_poses_dataset.dims,
coords=valid_poses_dataset.coords,
)

return _expected_dataarray

kinematic_test_params = [
("valid_poses_dataset", does_not_raise()),
("valid_poses_dataset_with_nan", does_not_raise()),
("missing_dim_poses_dataset", pytest.raises(AttributeError)),
]
@pytest.mark.parametrize(
"valid_dataset_uniform_linear_motion",
[
"valid_poses_dataset_uniform_linear_motion",
"valid_bboxes_dataset",
],
)
@pytest.mark.parametrize(
"kinematic_variable, expected_kinematics",
[
(
"displacement",
[
np.vstack([np.zeros((1, 2)), np.ones((9, 2))]), # Individual 0
np.multiply(
np.vstack([np.zeros((1, 2)), np.ones((9, 2))]),
np.array([1, -1]),
), # Individual 1
],
),
(
"velocity",
[
np.ones((10, 2)), # Individual 0
np.multiply(
np.ones((10, 2)), np.array([1, -1])
), # Individual 1
],
),
(
"acceleration",
[
np.zeros((10, 2)), # Individual 0
np.zeros((10, 2)), # Individual 1
],
),
],
)
def test_kinematics_uniform_linear_motion(
valid_dataset_uniform_linear_motion,
kinematic_variable,
expected_kinematics, # 2D: n_frames, n_space_dims
request,
):
"""Test computed kinematics for a uniform linear motion case.
Uniform linear motion means the individuals move along a line
at constant velocity.
We consider 2 individuals ("id_0" and "id_1"),
tracked for 10 frames, along x and y:
- id_0 moves along x=y line from the origin
- id_1 moves along x=-y line from the origin
- they both move one unit (pixel) along each axis in each frame
    If the dataset is a poses dataset, we consider 3 keypoints per individual
    (centroid, left, right); the left and right keypoints are always in front
    of the centroid, at 45 degrees from the direction of travel.
"""
# Compute kinematic array from input dataset
position = request.getfixturevalue(
valid_dataset_uniform_linear_motion
).position
kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")(
position
)

@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
def test_displacement(
self, ds, expected_exception, expected_dataarray, request
):
"""Test displacement computation."""
ds = request.getfixturevalue(ds)
with expected_exception:
result = kinematics.compute_displacement(ds.position)
expected = expected_dataarray("displacement")
if ds.position.isnull().any():
expected.loc[
{"individuals": "ind1", "time": [3, 4, 7, 8, 9]}
] = np.nan
xr.testing.assert_allclose(result, expected)

@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
def test_velocity(
self, ds, expected_exception, expected_dataarray, request
):
"""Test velocity computation."""
ds = request.getfixturevalue(ds)
with expected_exception:
result = kinematics.compute_velocity(ds.position)
expected = expected_dataarray("velocity")
if ds.position.isnull().any():
expected.loc[
{"individuals": "ind1", "time": [2, 4, 6, 7, 8, 9]}
] = np.nan
xr.testing.assert_allclose(result, expected)

@pytest.mark.parametrize("ds, expected_exception", kinematic_test_params)
def test_acceleration(
self, ds, expected_exception, expected_dataarray, request
):
"""Test acceleration computation."""
ds = request.getfixturevalue(ds)
with expected_exception:
result = kinematics.compute_acceleration(ds.position)
expected = expected_dataarray("acceleration")
if ds.position.isnull().any():
expected.loc[
{"individuals": "ind1", "time": [1, 3, 5, 6, 7, 8, 9]}
] = np.nan
xr.testing.assert_allclose(result, expected)

@pytest.mark.parametrize("order", [0, -1, 1.0, "1"])
def test_approximate_derivative_with_invalid_order(self, order):
"""Test that an error is raised when the order is non-positive."""
data = np.arange(10)
expected_exception = (
ValueError if isinstance(order, int) else TypeError
# Build expected data array from the expected numpy array
expected_array = xr.DataArray(
np.stack(expected_kinematics, axis=1),
# Stack along the "individuals" axis
dims=["time", "individuals", "space"],
)
if "keypoints" in position.coords:
expected_array = expected_array.expand_dims(
{"keypoints": position.coords["keypoints"].size}
)
with pytest.raises(expected_exception):
kinematics._compute_approximate_derivative(data, order=order)
expected_array = expected_array.transpose(
"time", "individuals", "keypoints", "space"
)

# Compare the values of the kinematic_array against the expected_array
np.testing.assert_allclose(kinematic_array.values, expected_array.values)
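
The expected arrays in the parametrisation follow directly from finite differences on a straight-line trajectory. A quick NumPy check for one individual moving along x = y (a sketch, independent of how ``movement`` computes derivatives internally):

```python
import numpy as np

frames = np.arange(10)
position = np.stack([frames, frames], axis=-1).astype(float)  # x = y line

# displacement: position[t] - position[t-1], defined as zero at t=0
displacement = np.vstack([np.zeros((1, 2)), np.diff(position, axis=0)])
velocity = np.gradient(position, axis=0)      # first time derivative
acceleration = np.gradient(velocity, axis=0)  # second time derivative

print(displacement[:3])  # [[0. 0.] [1. 1.] [1. 1.]]
print(velocity[:3])      # [[1. 1.] [1. 1.] [1. 1.]]
print(acceleration[:3])  # [[0. 0.] [0. 0.] [0. 0.]]
```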


@pytest.mark.parametrize(
"valid_dataset_with_nan",
[
"valid_poses_dataset_with_nan",
"valid_bboxes_dataset_with_nan",
],
)
@pytest.mark.parametrize(
"kinematic_variable, expected_nans_per_individual",
[
("displacement", [5, 0]), # individual 0, individual 1
("velocity", [6, 0]),
("acceleration", [7, 0]),
],
)
def test_kinematics_with_dataset_with_nans(
valid_dataset_with_nan,
kinematic_variable,
expected_nans_per_individual,
helpers,
request,
):
"""Test kinematics computation for a dataset with nans.
We test that the kinematics can be computed and that the number
of nan values in the kinematic array is as expected.
"""
# compute kinematic array
valid_dataset = request.getfixturevalue(valid_dataset_with_nan)
position = valid_dataset.position
kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")(
position
)

# compute n nans in kinematic array per individual
n_nans_kinematics_per_indiv = [
helpers.count_nans(kinematic_array.isel(individuals=i))
for i in range(valid_dataset.sizes["individuals"])
]

# expected nans per individual adjusted for space and keypoints dimensions
expected_nans_adjusted = [
n
* valid_dataset.sizes["space"]
* valid_dataset.sizes.get("keypoints", 1)
for n in expected_nans_per_individual
]
# check number of nans per individual is as expected in kinematic array
np.testing.assert_array_equal(
n_nans_kinematics_per_indiv, expected_nans_adjusted
)
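
The 5, 6, 7 progression in the expected NaN counts reflects that each differentiation step widens every NaN run by roughly one frame. A toy check with three NaN timepoints (3, 7 and 8, as in the ``_with_nan`` fixtures), using NumPy finite differences as a stand-in for the real computation:

```python
import numpy as np

position = np.arange(10, dtype=float)
position[[3, 7, 8]] = np.nan  # three NaNs, as in the _with_nan fixtures

displacement = np.concatenate([[0.0], np.diff(position)])
velocity = np.gradient(position)
acceleration = np.gradient(velocity)

print(np.isnan(displacement).sum())  # 5
print(np.isnan(velocity).sum())      # 6
print(np.isnan(acceleration).sum())  # 7
```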


@pytest.mark.parametrize(
"invalid_dataset, expected_exception",
[
("not_a_dataset", pytest.raises(AttributeError)),
("empty_dataset", pytest.raises(AttributeError)),
("missing_var_poses_dataset", pytest.raises(AttributeError)),
("missing_var_bboxes_dataset", pytest.raises(AttributeError)),
("missing_dim_poses_dataset", pytest.raises(ValueError)),
("missing_dim_bboxes_dataset", pytest.raises(ValueError)),
],
)
@pytest.mark.parametrize(
"kinematic_variable",
[
"displacement",
"velocity",
"acceleration",
],
)
def test_kinematics_with_invalid_dataset(
invalid_dataset,
expected_exception,
kinematic_variable,
request,
):
"""Test kinematics computation with an invalid dataset."""
with expected_exception:
position = request.getfixturevalue(invalid_dataset).position
getattr(kinematics, f"compute_{kinematic_variable}")(position)


@pytest.mark.parametrize("order", [0, -1, 1.0, "1"])
def test_approximate_derivative_with_invalid_order(order):
"""Test that an error is raised when the order is non-positive."""
data = np.arange(10)
expected_exception = ValueError if isinstance(order, int) else TypeError
with pytest.raises(expected_exception):
kinematics._compute_approximate_time_derivative(data, order=order)
51 changes: 51 additions & 0 deletions tests/test_unit/test_load_bboxes.py
Original file line number Diff line number Diff line change
@@ -419,3 +419,54 @@ def test_fps_and_time_coords(
else:
start_frame = 0
assert_time_coordinates(ds, expected_fps, start_frame)


def test_df_from_via_tracks_file(via_tracks_file):
"""Test that the helper function correctly reads the VIA tracks .csv file
as a dataframe.
"""
df = load_bboxes._df_from_via_tracks_file(via_tracks_file)

assert isinstance(df, pd.DataFrame)
assert len(df.frame_number.unique()) == 5
assert (
df.shape[0] == len(df.ID.unique()) * 5
) # all individuals in all frames (even if nan)
assert list(df.columns) == [
"ID",
"frame_number",
"x",
"y",
"w",
"h",
"confidence",
]


def test_position_numpy_array_from_via_tracks_file(via_tracks_file):
"""Test the extracted position array from the VIA tracks .csv file
represents the centroid of the bbox.
"""
# Extract numpy arrays from VIA tracks .csv file
bboxes_arrays = load_bboxes._numpy_arrays_from_via_tracks_file(
via_tracks_file
)

# Read VIA tracks .csv file as a dataframe
df = load_bboxes._df_from_via_tracks_file(via_tracks_file)

# Compute centroid positions from the dataframe
    # (iterating in the same order as the ID array)
list_derived_centroids = []
for id in bboxes_arrays["ID_array"]:
df_one_id = df[df["ID"] == id.item()]
centroid_position = np.array(
[df_one_id.x + df_one_id.w / 2, df_one_id.y + df_one_id.h / 2]
).T # frames, xy
list_derived_centroids.append(centroid_position)

# Compare to extracted position array
assert np.allclose(
bboxes_arrays["position_array"], # frames, individuals, xy
np.stack(list_derived_centroids, axis=1),
)
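
The centroid convention checked here is the usual one for VIA-style boxes, where ``x`` and ``y`` give the top-left corner, so the centre lies at ``(x + w/2, y + h/2)``. A minimal sketch of the conversion, with column names matching the dataframe above:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {"x": [10.0, 20.0], "y": [5.0, 5.0], "w": [4.0, 4.0], "h": [2.0, 6.0]}
)

# centroid of each box from its top-left corner and width/height
centroid = np.column_stack([df.x + df.w / 2, df.y + df.h / 2])
print(centroid)  # [[12.  6.] [22.  8.]]
```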
41 changes: 28 additions & 13 deletions tests/test_unit/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

import pytest
import xarray as xr

from movement.utils.logging import log_error, log_to_attrs, log_warning

@@ -45,27 +46,41 @@ def test_log_warning(caplog):
assert caplog.records[0].levelname == "WARNING"


@pytest.mark.parametrize("input_data", ["dataset", "dataarray"])
def test_log_to_attrs(input_data, valid_poses_dataset):
@pytest.mark.parametrize(
"input_data",
[
"valid_poses_dataset",
"valid_bboxes_dataset",
],
)
@pytest.mark.parametrize(
"selector_fn, expected_selector_type",
[
(lambda ds: ds, xr.Dataset), # take full dataset
(lambda ds: ds.position, xr.DataArray), # take position data array
],
)
def test_log_to_attrs(
input_data, selector_fn, expected_selector_type, request
):
"""Test that the ``log_to_attrs()`` decorator appends
log entries to the output data's ``log`` attribute and
checks that ``attrs`` contains all expected values.
log entries to the dataset's or the data array's ``log``
attribute and check that ``attrs`` contains all the expected values.
"""

# a fake operation on the dataset to log
@log_to_attrs
def fake_func(data, arg, kwarg=None):
return data

input_data = (
valid_poses_dataset
if input_data == "dataset"
else valid_poses_dataset.position
)
# apply operation to dataset or data array
dataset = request.getfixturevalue(input_data)
input_data = selector_fn(dataset)
output_data = fake_func(input_data, "test1", kwarg="test2")

# check the log in the dataset is as expected
assert isinstance(output_data, expected_selector_type)
assert "log" in output_data.attrs
assert output_data.attrs["log"][0]["operation"] == "fake_func"
assert (
output_data.attrs["log"][0]["arg_1"] == "test1"
and output_data.attrs["log"][0]["kwarg"] == "test2"
)
assert output_data.attrs["log"][0]["arg_1"] == "test1"
assert output_data.attrs["log"][0]["kwarg"] == "test2"
138 changes: 118 additions & 20 deletions tests/test_unit/test_reports.py
Original file line number Diff line number Diff line change
@@ -4,28 +4,126 @@


@pytest.mark.parametrize(
"data_selection",
"valid_dataset",
[
lambda ds: ds.position, # Entire dataset
lambda ds: ds.position.sel(
individuals="ind1"
), # Missing "individuals" dim
lambda ds: ds.position.sel(
keypoints="key1"
), # Missing "keypoints" dim
lambda ds: ds.position.sel(
individuals="ind1", keypoints="key1"
), # Missing "individuals" and "keypoints" dims
"valid_poses_dataset",
"valid_bboxes_dataset",
"valid_poses_dataset_with_nan",
"valid_bboxes_dataset_with_nan",
],
)
def test_report_nan_values(
capsys, valid_poses_dataset_with_nan, data_selection
@pytest.mark.parametrize(
"data_selection, list_expected_individuals_indices",
[
(lambda ds: ds.position, [0, 1]), # full position data array
(
lambda ds: ds.position.isel(individuals=0),
[0],
), # position of individual 0 only
],
)
def test_report_nan_values_in_position_selecting_individual(
valid_dataset,
data_selection,
list_expected_individuals_indices,
request,
):
"""Test that the nan-value reporting function handles data
with missing ``individuals`` and/or ``keypoint`` dims, and
that the dataset name is included in the report.
"""Test that the nan-value reporting function handles position data
    with specific ``individuals``, and that the data array name (position)
and only the relevant individuals are included in the report.
"""
data = data_selection(valid_poses_dataset_with_nan)
assert data.name in report_nan_values(
data
), "Dataset name should be in the output"
# extract relevant position data
input_dataset = request.getfixturevalue(valid_dataset)
output_data_array = data_selection(input_dataset)

# produce report
report_str = report_nan_values(output_data_array)

# check report of nan values includes name of data array
assert output_data_array.name in report_str

# check report of nan values includes selected individuals only
list_expected_individuals = [
input_dataset["individuals"][idx].item()
for idx in list_expected_individuals_indices
]
list_not_expected_individuals = [
indiv.item()
for indiv in input_dataset["individuals"]
if indiv.item() not in list_expected_individuals
]
assert all([ind in report_str for ind in list_expected_individuals])
assert all(
[ind not in report_str for ind in list_not_expected_individuals]
)


@pytest.mark.parametrize(
"valid_dataset",
[
"valid_poses_dataset",
"valid_poses_dataset_with_nan",
],
)
@pytest.mark.parametrize(
"data_selection, list_expected_keypoints, list_expected_individuals",
[
(
lambda ds: ds.position,
["key1", "key2"],
["ind1", "ind2"],
), # Report nans in position for all keypoints and individuals
(
lambda ds: ds.position.sel(keypoints="key1"),
[],
["ind1", "ind2"],
), # Report nans in position for keypoint "key1", for all individuals
# Note: if only one keypoint exists, it is not explicitly reported
(
lambda ds: ds.position.sel(individuals="ind1", keypoints="key1"),
[],
["ind1"],
), # Report nans in position for individual "ind1" and keypoint "key1"
# Note: if only one keypoint exists, it is not explicitly reported
],
)
def test_report_nan_values_in_position_selecting_keypoint(
valid_dataset,
data_selection,
list_expected_keypoints,
list_expected_individuals,
request,
):
"""Test that the nan-value reporting function handles position data
    with specific ``keypoints``, and that the data array name (position)
and only the relevant keypoints are included in the report.
"""
# extract relevant position data
input_dataset = request.getfixturevalue(valid_dataset)
output_data_array = data_selection(input_dataset)

# produce report
report_str = report_nan_values(output_data_array)

# check report of nan values includes name of data array
assert output_data_array.name in report_str

# check report of nan values includes only selected keypoints
list_not_expected_keypoints = [
        kpt.item()
        for kpt in input_dataset["keypoints"]
        if kpt.item() not in list_expected_keypoints
]
assert all([kpt in report_str for kpt in list_expected_keypoints])
assert all([kpt not in report_str for kpt in list_not_expected_keypoints])

# check report of nan values includes selected individuals only
list_not_expected_individuals = [
indiv.item()
for indiv in input_dataset["individuals"]
if indiv.item() not in list_expected_individuals
]
assert all([ind in report_str for ind in list_expected_individuals])
assert all(
[ind not in report_str for ind in list_not_expected_individuals]
)
114 changes: 52 additions & 62 deletions tests/test_unit/test_validators/test_datasets_validators.py
Original file line number Diff line number Diff line change
@@ -76,29 +76,14 @@ def position_array_params(request):
), # not an ndarray
(
np.zeros((10, 2, 3)),
f"Expected '{key}' to have 2 spatial " "coordinates, but got 3.",
f"Expected '{key}_array' to have 2 spatial "
"coordinates, but got 3.",
), # last dim not 2
]
for key in ["position_array", "shape_array"]
for key in ["position", "shape"]
}


@pytest.fixture
def valid_bboxes_inputs():
"""Return a dictionary with valid inputs for a ValidBboxesDataset."""
n_frames, n_individuals, n_space = (10, 2, 2)
# valid array for position or shape
valid_bbox_array = np.zeros((n_frames, n_individuals, n_space))

return {
"position_array": valid_bbox_array,
"shape_array": valid_bbox_array,
"individual_names": [
"id_" + str(id) for id in range(valid_bbox_array.shape[1])
],
}


# Tests pose dataset
@pytest.mark.parametrize(
"invalid_position_array, log_message",
@@ -223,7 +208,7 @@ def test_poses_dataset_validator_source_software(
# Tests bboxes dataset
@pytest.mark.parametrize(
"invalid_position_array, log_message",
invalid_bboxes_arrays_and_expected_log["position_array"],
invalid_bboxes_arrays_and_expected_log["position"],
)
def test_bboxes_dataset_validator_with_invalid_position_array(
invalid_position_array, log_message, request
@@ -232,33 +217,33 @@ def test_bboxes_dataset_validator_with_invalid_position_array(
with pytest.raises(ValueError) as excinfo:
ValidBboxesDataset(
position_array=invalid_position_array,
shape_array=request.getfixturevalue("valid_bboxes_inputs")[
"shape_array"
],
individual_names=request.getfixturevalue("valid_bboxes_inputs")[
"individual_names"
],
shape_array=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["shape"],
individual_names=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["individual_names"],
)
assert str(excinfo.value) == log_message


@pytest.mark.parametrize(
"invalid_shape_array, log_message",
invalid_bboxes_arrays_and_expected_log["shape_array"],
invalid_bboxes_arrays_and_expected_log["shape"],
)
def test_bboxes_dataset_validator_with_invalid_shape_array(
invalid_shape_array, log_message, request
):
"""Test that invalid shape arrays raise an error."""
with pytest.raises(ValueError) as excinfo:
ValidBboxesDataset(
position_array=request.getfixturevalue("valid_bboxes_inputs")[
"position_array"
],
position_array=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["position"],
shape_array=invalid_shape_array,
individual_names=request.getfixturevalue("valid_bboxes_inputs")[
"individual_names"
],
individual_names=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["individual_names"],
)
assert str(excinfo.value) == log_message

@@ -274,15 +259,19 @@ def test_bboxes_dataset_validator_with_invalid_shape_array(
(
[1, 2, 3],
pytest.raises(ValueError),
"Expected 'individual_names' to have length 2, but got 3.",
"Expected 'individual_names' to have length 2, "
f"but got {len([1, 2, 3])}.",
), # length doesn't match position_array.shape[1]
# from valid_bboxes_arrays_all_zeros fixture
(
["id_1", "id_1"],
pytest.raises(ValueError),
"individual_names passed to the dataset are not unique. "
"There are 2 elements in the list, but "
"only 1 are unique.",
), # some IDs are not unique
), # some IDs are not unique.
# Note: length of individual_names list should match
# n_individuals in valid_bboxes_arrays_all_zeros fixture
],
)
def test_bboxes_dataset_validator_individual_names(
@@ -291,12 +280,12 @@ def test_bboxes_dataset_validator_individual_names(
"""Test individual_names inputs."""
with expected_exception as excinfo:
ds = ValidBboxesDataset(
position_array=request.getfixturevalue("valid_bboxes_inputs")[
"position_array"
],
shape_array=request.getfixturevalue("valid_bboxes_inputs")[
"shape_array"
],
position_array=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["position"],
shape_array=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["shape"],
individual_names=list_individual_names,
)
if list_individual_names is None:
@@ -313,13 +302,14 @@ def test_bboxes_dataset_validator_individual_names(
(
np.ones((10, 3, 2)),
pytest.raises(ValueError),
"Expected 'confidence_array' to have shape (10, 2), "
"but got (10, 3, 2).",
), # will not match position_array shape
f"Expected 'confidence_array' to have shape (10, 2), "
f"but got {np.ones((10, 3, 2)).shape}.",
), # will not match shape of position_array in
# valid_bboxes_arrays_all_zeros fixture
(
[1, 2, 3],
pytest.raises(ValueError),
f"Expected a numpy array, but got {type(list())}.",
f"Expected a numpy array, but got {type([1, 2, 3])}.",
), # not an ndarray, should raise ValueError
(
None,
@@ -334,15 +324,15 @@ def test_bboxes_dataset_validator_confidence_array(
"""Test that invalid confidence arrays raise the appropriate errors."""
with expected_exception as excinfo:
ds = ValidBboxesDataset(
position_array=request.getfixturevalue("valid_bboxes_inputs")[
"position_array"
],
shape_array=request.getfixturevalue("valid_bboxes_inputs")[
"shape_array"
],
individual_names=request.getfixturevalue("valid_bboxes_inputs")[
"individual_names"
],
position_array=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["position"],
shape_array=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["shape"],
individual_names=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["individual_names"],
confidence_array=confidence_array,
)
if confidence_array is None:
@@ -387,15 +377,15 @@ def test_bboxes_dataset_validator_frame_array(
"""Test that invalid frame arrays raise the appropriate errors."""
with expected_exception as excinfo:
ds = ValidBboxesDataset(
position_array=request.getfixturevalue("valid_bboxes_inputs")[
"position_array"
],
shape_array=request.getfixturevalue("valid_bboxes_inputs")[
"shape_array"
],
individual_names=request.getfixturevalue("valid_bboxes_inputs")[
"individual_names"
],
position_array=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["position"],
shape_array=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["shape"],
individual_names=request.getfixturevalue(
"valid_bboxes_arrays_all_zeros"
)["individual_names"],
frame_array=frame_array,
)
