diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index 009baa633..4b7b4c11c 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -52,7 +52,7 @@ jobs: with: python-version: ${{ matrix.python-version }} auto-update-conda: true - channels: conda-forge + channels: conda-forge,nodefaults activate-environment: movement-env - uses: neuroinformatics-unit/actions/test@v2 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f764b7e35..7443f3984 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,12 +29,12 @@ repos: - id: rst-directive-colons - id: rst-inline-touching-normal - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.6 + rev: v0.6.3 hooks: - id: ruff - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.1 + rev: v1.11.2 hooks: - id: mypy additional_dependencies: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e39412a65..ac9251138 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -12,15 +12,12 @@ development environment for movement. In the following we assume you have First, create and activate a `conda` environment with some prerequisites: ```sh -conda create -n movement-dev -c conda-forge python=3.10 pytables +conda create -n movement-dev -c conda-forge python=3.11 pytables conda activate movement-dev ``` -The above method ensures that you will get packages that often can't be -installed via `pip`, including [hdf5](https://www.hdfgroup.org/solutions/hdf5/). - -To install movement for development, clone the GitHub repository, -and then run from inside the repository: +To install movement for development, clone the [GitHub repository](movement-github:), +and then run from within the repository: ```sh pip install -e .[dev] # works on most shells @@ -162,13 +159,12 @@ The version number is automatically determined from the latest tag on the _main_ The documentation is hosted via [GitHub pages](https://pages.github.com/) at [movement.neuroinformatics.dev](target-movement). Its source files are located in the `docs` folder of this repository. -They are written in either [reStructuredText](https://docutils.sourceforge.io/rst.html) or -[markdown](myst-parser:syntax/typography.html). +They are written in either [Markdown](myst-parser:syntax/typography.html) +or [reStructuredText](https://docutils.sourceforge.io/rst.html). The `index.md` file corresponds to the homepage of the documentation website. -Other `.rst` or `.md` files are linked to the homepage via the `toctree` directive. +Other `.md` or `.rst` files are linked to the homepage via the `toctree` directive. -We use [Sphinx](https://www.sphinx-doc.org/en/master/) and the -[PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html) +We use [Sphinx](sphinx-doc:) and the [PyData Sphinx Theme](https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html) to build the source files into HTML output. This is handled by a GitHub actions workflow (`.github/workflows/docs_build_and_deploy.yml`). The build job is triggered on each PR, ensuring that the documentation build is not broken by new changes. @@ -199,17 +195,16 @@ existing_file my_new_file ``` -#### Adding external links -If you are adding references to an external link (e.g. `https://github.com/neuroinformatics-unit/movement/issues/1`) in a `.md` file, you will need to check if a matching URL scheme (e.g. 
`https://github.com/neuroinformatics-unit/movement/`) is defined in `myst_url_schemes` in `docs/source/conf.py`. If it is, the following `[](scheme:loc)` syntax will be converted to the [full URL](movement-github:issues/1) during the build process: +#### Linking to external URLs +If you are adding references to an external URL (e.g. `https://github.com/neuroinformatics-unit/movement/issues/1`) in a `.md` file, you will need to check if a matching URL scheme (e.g. `https://github.com/neuroinformatics-unit/movement/`) is defined in `myst_url_schemes` in `docs/source/conf.py`. If it is, the following `[](scheme:loc)` syntax will be converted to the [full URL](movement-github:issues/1) during the build process: ```markdown [link text](movement-github:issues/1) ``` -If it is not yet defined and you have multiple external links pointing to the same base URL, you will need to [add the URL scheme](myst-parser:syntax/cross-referencing.html#customising-external-url-resolution) to `myst_url_schemes` in `docs/source/conf.py`. - +If it is not yet defined and you have multiple external URLs pointing to the same base URL, you will need to [add the URL scheme](myst-parser:syntax/cross-referencing.html#customising-external-url-resolution) to `myst_url_schemes` in `docs/source/conf.py`. ### Updating the API reference -The API reference is auto-generated by the `docs/make_api_index.py` script, and the `sphinx-autodoc` and `sphinx-autosummary` plugins. +The [API reference](target-api) is auto-generated by the `docs/make_api_index.py` script, and the [sphinx-autodoc](sphinx-doc:extensions/autodoc.html) and [sphinx-autosummary](sphinx-doc:extensions/autosummary.html) extensions. The script generates the `docs/source/api_index.rst` file containing the list of modules to be included in the [API reference](target-api). The plugins then generate the API reference pages for each module listed in `api_index.rst`, based on the docstrings in the source code. So make sure that all your public functions/classes/methods have valid docstrings following the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) style. @@ -224,7 +219,7 @@ To add new examples, you will need to create a new `.py` file in `examples/`. The file should be structured as specified in the relevant [sphinx-gallery documentation](sphinx-gallery:syntax). -We are using sphinx-gallery's [integration with binder](https://sphinx-gallery.github.io/stable/configuration.html#binder-links) +We are using sphinx-gallery's [integration with binder](sphinx-gallery:configuration#binder-links) to provide interactive versions of the examples. If your examples rely on packages that are not among movement's dependencies, you will need to add them to the `docs/source/environment.yml` file. @@ -232,6 +227,52 @@ That file is used by binder to create the conda environment in which the examples are run. See the relevant section of the [binder documentation](https://mybinder.readthedocs.io/en/latest/using/config_files.html). +### Cross-referencing Python objects +:::{note} +Docstrings in the `.py` files for the [API reference](target-api) and the [examples](target-examples) are converted into `.rst` files, so these should use reStructuredText syntax. +::: + +#### Internal references +::::{tab-set} +:::{tab-item} Markdown +For referencing movement objects in `.md` files, use the `` {role}`target` `` syntax with the appropriate [Python object role](sphinx-doc:domains/python.html#cross-referencing-python-objects). 
+ +For example, to reference the {mod}`movement.io.load_poses` module, use: +```markdown +{mod}`movement.io.load_poses` +``` +::: +:::{tab-item} RestructuredText +For referencing movement objects in `.rst` files, use the `` :role:`target` `` syntax with the appropriate [Python object role](sphinx-doc:domains/python.html#cross-referencing-python-objects). + +For example, to reference the {mod}`movement.io.load_poses` module, use: +```rst +:mod:`movement.io.load_poses` +``` +::: +:::: + +#### External references +For referencing external Python objects using [intersphinx](sphinx-doc:extensions/intersphinx.html), +ensure the mapping between module names and their documentation URLs is defined in [`intersphinx_mapping`](sphinx-doc:extensions/intersphinx.html#confval-intersphinx_mapping) in `docs/source/conf.py`. +Once the module is included in the mapping, use the same syntax as for [internal references](#internal-references). + +::::{tab-set} +:::{tab-item} Markdown +For example, to reference the {meth}`xarray.Dataset.update` method, use: +```markdown +{meth}`xarray.Dataset.update` +``` +::: + +:::{tab-item} RestructuredText +For example, to reference the {meth}`xarray.Dataset.update` method, use: +```rst +:meth:`xarray.Dataset.update` +``` +::: +:::: + ### Building the documentation locally We recommend that you build and view the documentation website locally, before you push it. To do so, first navigate to `docs/`. @@ -256,7 +297,7 @@ The local build can be viewed by opening `docs/build/html/index.html` in a brows :::{tab-item} All platforms ```sh -python make_api_index.py && sphinx-build source build +python make_api_index.py && sphinx-build source build -W --keep-going ``` The local build can be viewed by opening `docs/build/index.html` in a browser. ::: @@ -276,7 +317,7 @@ make clean html :::{tab-item} All platforms ```sh rm -f source/api_index.rst && rm -rf build && rm -rf source/api && rm -rf source/examples -python make_api_index.py && sphinx-build source build +python make_api_index.py && sphinx-build source build -W --keep-going ``` ::: :::: @@ -292,7 +333,7 @@ make linkcheck :::{tab-item} All platforms ```sh -sphinx-build source build -b linkcheck +sphinx-build source build -b linkcheck -W --keep-going ``` ::: :::: @@ -338,7 +379,7 @@ The most important parts of this module are: 1. The `SAMPLE_DATA` download manager object. 2. The `list_datasets()` function, which returns a list of the available poses and bounding boxes datasets (file names of the data files). 3. The `fetch_dataset_paths()` function, which returns a dictionary containing local paths to the files associated with a particular sample dataset: `poses` or `bboxes`, `frame`, `video`. If the relevant files are not already cached locally, they will be downloaded. -4. The `fetch_dataset()` function, which downloads the files associated with a given sample dataset (same as `fetch_dataset_paths()`) and additionally loads the pose or bounding box data into `movement`, returning an `xarray.Dataset` object. If available, the local paths to the associated video and frame files are stored as dataset attributes, with names `video_path` and `frame_path`, respectively. +4. The `fetch_dataset()` function, which downloads the files associated with a given sample dataset (same as `fetch_dataset_paths()`) and additionally loads the pose or bounding box data into movement, returning an `xarray.Dataset` object. 
If available, the local paths to the associated video and frame files are stored as dataset attributes, with names `video_path` and `frame_path`, respectively. By default, the downloaded files are stored in the `~/.movement/data` folder. This can be changed by setting the `DATA_DIR` variable in the `movement.sample_data.py` module. @@ -372,7 +413,7 @@ To add a new file, you will need to: ``` ::: :::: - For convenience, we've included a `get_sha256_hashes.py` script in the [movement data repository](gin:neuroinformatics/movement-test-data). If you run this from the root of the data repository, within a Python environment with `movement` installed, it will calculate the sha256 hashes for all files in the `poses`, `bboxes`, `videos` and `frames` folders and write them to files named `poses_hashes.txt`, `bboxes_hashes.txt`, `videos_hashes.txt`, and `frames_hashes.txt` respectively. + For convenience, we've included a `get_sha256_hashes.py` script in the [movement data repository](gin:neuroinformatics/movement-test-data). If you run this from the root of the data repository, within a Python environment with movement installed, it will calculate the sha256 hashes for all files in the `poses`, `bboxes`, `videos` and `frames` folders and write them to files named `poses_hashes.txt`, `bboxes_hashes.txt`, `videos_hashes.txt`, and `frames_hashes.txt` respectively. 7. Add metadata for your new files to `metadata.yaml`, including the sha256 hashes you've calculated. See the example entry below for guidance. diff --git a/README.md b/README.md index fd7aaba13..5b1dec347 100644 --- a/README.md +++ b/README.md @@ -16,17 +16,12 @@ A Python toolbox for analysing body movements across space and time, to aid the ## Quick install -First, create and activate a conda environment with the required dependencies: +Create and activate a conda environment with movement installed: ``` -conda create -n movement-env -c conda-forge python=3.11 pytables +conda create -n movement-env -c conda-forge movement conda activate movement-env ``` -Then install the `movement` package: -``` -pip install movement -``` - > [!Note] > Read the [documentation](https://movement.neuroinformatics.dev) for more information, including [full installation instructions](https://movement.neuroinformatics.dev/getting_started/installation.html) and [examples](https://movement.neuroinformatics.dev/examples/index.html). @@ -52,7 +47,7 @@ You are welcome to chat with the team on [zulip](https://neuroinformatics.zulipc ## Citation -If you use `movement` in your work, please cite the following Zenodo DOI: +If you use movement in your work, please cite the following Zenodo DOI: > Nikoloz Sirmpilatze, Chang Huan Lo, Sofía Miñano, Brandon D. Peri, Dhruv Sharma, Laura Porta, Iván Varela & Adam L. Tyson (2024). neuroinformatics-unit/movement. Zenodo. https://zenodo.org/doi/10.5281/zenodo.12755724 diff --git a/docs/Makefile b/docs/Makefile index 623d2906a..529f66505 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,9 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= +# -W: if there are warnings, treat them as errors and exit with status 1. +# --keep-going: run sphinx-build to completion and exit with status 1 if there were errors.
+SPHINXOPTS ?= -W --keep-going SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css index 40e09e632..ce6b5ae34 100644 --- a/docs/source/_static/css/custom.css +++ b/docs/source/_static/css/custom.css @@ -30,3 +30,12 @@ display: flex; flex-wrap: wrap; justify-content: space-between; } + +/* Disable decoration for all but movement backrefs */ +a[class^="sphx-glr-backref-module-"], +a[class^="sphx-glr-backref-type-"] { + text-decoration: none; +} +a[class^="sphx-glr-backref-module-movement"] { + text-decoration: underline; +} diff --git a/docs/source/_static/dataset_structure.png b/docs/source/_static/dataset_structure.png index 4e6a17d90..13506196d 100644 Binary files a/docs/source/_static/dataset_structure.png and b/docs/source/_static/dataset_structure.png differ diff --git a/docs/source/_static/movement_overview.png b/docs/source/_static/movement_overview.png index 8af12daa0..33c5927a5 100644 Binary files a/docs/source/_static/movement_overview.png and b/docs/source/_static/movement_overview.png differ diff --git a/docs/source/community/roadmaps.md b/docs/source/community/roadmaps.md index 69b6000e0..78b1bd67b 100644 --- a/docs/source/community/roadmaps.md +++ b/docs/source/community/roadmaps.md @@ -24,5 +24,5 @@ We plan to release version `v0.1` of movement in early 2024, providing a minimal - [x] Ability to compute velocity and acceleration from pose tracks. - [x] Public website with [documentation](target-movement). - [x] Package released on [PyPI](https://pypi.org/project/movement/). -- [ ] Package released on [conda-forge](https://conda-forge.org/). +- [x] Package released on [conda-forge](https://anaconda.org/conda-forge/movement). - [ ] Ability to visualise pose tracks using [napari](napari:). We aim to represent pose tracks via napari's [Points](napari:howtos/layers/points) and [Tracks](napari:howtos/layers/tracks) layers and overlay them on video frames. diff --git a/docs/source/conf.py b/docs/source/conf.py index f63529432..9b051fb04 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -68,7 +68,7 @@ "tasklist", ] # Automatically add anchors to markdown headings -myst_heading_anchors = 3 +myst_heading_anchors = 4 # Add any paths that contain templates here, relative to this directory. 
templates_path = ["_templates"] @@ -108,6 +108,7 @@ "binderhub_url": "https://mybinder.org", "dependencies": ["environment.yml"], }, + "reference_url": {"movement": None}, "remove_config_comments": True, # do not render config params set as # sphinx_gallery_config [= value] } @@ -191,8 +192,14 @@ "napari": "https://napari.org/dev/{{path}}", "setuptools-scm": "https://setuptools-scm.readthedocs.io/en/latest/{{path}}#{{fragment}}", "sleap": "https://sleap.ai/{{path}}#{{fragment}}", - "sphinx-gallery": "https://sphinx-gallery.github.io/stable/{{path}}", + "sphinx-doc": "https://www.sphinx-doc.org/en/master/usage/{{path}}#{{fragment}}", + "sphinx-gallery": "https://sphinx-gallery.github.io/stable/{{path}}#{{fragment}}", "xarray": "https://docs.xarray.dev/en/stable/{{path}}#{{fragment}}", "lp": "https://lightning-pose.readthedocs.io/en/stable/{{path}}#{{fragment}}", "via": "https://www.robots.ox.ac.uk/~vgg/software/via/{{path}}#{{fragment}}", } + +intersphinx_mapping = { + "xarray": ("https://docs.xarray.dev/en/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), +} diff --git a/docs/source/environment.yml b/docs/source/environment.yml index 00c7d126e..b84ac374b 100644 --- a/docs/source/environment.yml +++ b/docs/source/environment.yml @@ -3,7 +3,7 @@ channels: - conda-forge dependencies: - - python=3.10 + - python=3.11 - pytables - pip: - movement diff --git a/docs/source/getting_started/installation.md b/docs/source/getting_started/installation.md index d19379442..404acceb1 100644 --- a/docs/source/getting_started/installation.md +++ b/docs/source/getting_started/installation.md @@ -1,68 +1,62 @@ (target-installation)= # Installation -## Create a conda environment - +## Install the package :::{admonition} Use a conda environment :class: note -We recommend you install movement inside a [conda](conda:) -or [mamba](mamba:) environment, to avoid dependency conflicts with other packages. -In the following we assume you have `conda` installed, -but the same commands will also work with `mamba`/`micromamba`. +To avoid dependency conflicts with other packages, it is best practice to install Python packages within a virtual environment. +We recommend using [conda](conda:) or [mamba](mamba:) to create and manage this environment, as they simplify the installation process. +The following instructions assume that you have conda installed, but the same commands will also work with `mamba`/`micromamba`. ::: -First, create and activate an environment with some prerequisites. -You can call your environment whatever you like, we've used `movement-env`. +### Users +To install movement in a new environment, follow one of the options below. +We will use `movement-env` as the environment name, but you can choose any name you prefer. +::::{tab-set} +:::{tab-item} Conda +Create and activate an environment with movement installed: ```sh -conda create -n movement-env -c conda-forge python=3.11 pytables +conda create -n movement-env -c conda-forge movement conda activate movement-env ``` - -## Install the package - -Then install the `movement` package as described below. 
- -::::{tab-set} - -:::{tab-item} Users -To get the latest release from PyPI: - +::: +:::{tab-item} Pip +Create and activate an environment with some prerequisites: ```sh -pip install movement +conda create -n movement-env -c conda-forge python=3.11 pytables +conda activate movement-env ``` -If you have an older version of `movement` installed in the same environment, -you can update to the latest version with: - +Install the latest movement release from PyPI: ```sh -pip install --upgrade movement +pip install movement ``` ::: +:::: -:::{tab-item} Developers -To get the latest development version, clone the -[GitHub repository](movement-github:) -and then run from inside the repository: +### Developers +If you are a developer looking to contribute to movement, please refer to our [contributing guide](target-contributing) for detailed setup instructions and guidelines. +## Check the installation +To verify that the installation was successful, run (with `movement-env` activated): ```sh -pip install -e .[dev] # works on most shells -pip install -e '.[dev]' # works on zsh (the default shell on macOS) +movement info ``` +You should see a printout including the version numbers of movement +and some of its dependencies. -This will install the package in editable mode, including all `dev` dependencies. -Please see the [contributing guide](target-contributing) for more information. -::: - -:::: - -## Check the installation - -To verify that the installation was successful, you can run the following -command (with the `movement-env` activated): +## Update the package +To update movement to the latest version, we recommend installing it in a new environment, +as this prevents potential compatibility issues caused by changes in dependency versions. +To remove an existing environment named `movement-env`: ```sh -movement info +conda env remove -n movement-env ``` - -You should see a printout including the version numbers of `movement` -and some of its dependencies. +:::{tip} +If you are unsure about the environment name, you can get a list of the environments on your system with: ```sh +conda env list ``` +::: +Once the environment has been removed, you can create a new one following the [installation instructions](#install-the-package) above. diff --git a/docs/source/getting_started/movement_dataset.md b/docs/source/getting_started/movement_dataset.md index 601ccc39c..6b2ef5e9f 100644 --- a/docs/source/getting_started/movement_dataset.md +++ b/docs/source/getting_started/movement_dataset.md @@ -19,7 +19,11 @@ To learn more about `xarray` data structures in general, see the relevant ## Dataset structure -![](../_static/dataset_structure.png) +```{figure} ../_static/dataset_structure.png +:alt: movement dataset structure + +An {class}`xarray.Dataset` is a collection of several data arrays that share some dimensions. The schematic shows the data arrays that make up the `poses` and `bboxes` datasets in `movement`. +``` The structure of a `movement` dataset `ds` can be easily inspected by simply printing it.
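To make the inspection workflow concrete, here is a minimal sketch of loading a sample dataset and printing it (this assumes the `movement.sample_data` API described elsewhere in this diff; the file name below is illustrative — use `sample_data.list_datasets()` to see the names actually available):

```python
from movement import sample_data

# List the available sample datasets (poses and bounding boxes files)
print(sample_data.list_datasets())

# Fetch one sample dataset: the files are downloaded and cached under
# ~/.movement/data, and the data are loaded into an xarray.Dataset
ds = sample_data.fetch_dataset("DLC_single-wasp.predictions.h5")

# Printing the dataset shows its dimensions, coordinates,
# data variables (position, confidence) and attributes
print(ds)

# For a poses dataset, position has dims (time, individuals, keypoints, space)
print(ds.position.shape)
```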
diff --git a/examples/compute_kinematics.py b/examples/compute_kinematics.py index 8cea0912a..b3fefd4a4 100644 --- a/examples/compute_kinematics.py +++ b/examples/compute_kinematics.py @@ -10,14 +10,13 @@ # Imports # ------- -import numpy as np - # For interactive plots: install ipympl with `pip install ipympl` and uncomment # the following line in your notebook # %matplotlib widget from matplotlib import pyplot as plt from movement import sample_data +from movement.utils.vector import compute_norm # %% # Load sample dataset @@ -105,7 +104,7 @@ # %% # We can also easily plot the components of the position vector against time # using ``xarray``'s built-in plotting methods. We use -# :py:meth:`xarray.DataArray.squeeze` to +# :meth:`xarray.DataArray.squeeze` to # remove the dimension of length 1 from the data (the ``keypoints`` dimension). position.squeeze().plot.line(x="time", row="individuals", aspect=2, size=2.5) plt.gcf().show() @@ -131,7 +130,7 @@ # %% # Notice that we could also compute the displacement (and all the other -# kinematic variables) using the :py:mod:`movement.analysis.kinematics` module: +# kinematic variables) using the :mod:`movement.analysis.kinematics` module: # %% import movement.analysis.kinematics as kin @@ -255,13 +254,12 @@ # mouse along its trajectory. # length of each displacement vector -displacement_vectors_lengths = np.linalg.norm( - displacement.sel(individuals=mouse_name, space=["x", "y"]).squeeze(), - axis=1, +displacement_vectors_lengths = compute_norm( + displacement.sel(individuals=mouse_name) ) -# sum of all displacement vectors -total_displacement = np.sum(displacement_vectors_lengths, axis=0) # in pixels +# sum the lengths of all displacement vectors (in pixels) +total_displacement = displacement_vectors_lengths.sum(dim="time").values[0] print( f"The mouse {mouse_name}'s trajectory is {total_displacement:.2f} " @@ -284,7 +282,7 @@ # %% # We can plot the components of the velocity vector against time # using ``xarray``'s built-in plotting methods. We use -# :py:meth:`xarray.DataArray.squeeze` to +# :meth:`xarray.DataArray.squeeze` to # remove the dimension of length 1 from the data (the ``keypoints`` dimension). velocity.squeeze().plot.line(x="time", row="individuals", aspect=2, size=2.5) @@ -299,14 +297,12 @@ # uses second order central differences. # %% -# We can also visualise the speed, as the norm of the velocity vector: +# We can also visualise the speed, as the magnitude (norm) +# of the velocity vector: fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) for mouse_name, ax in zip(velocity.individuals.values, axes, strict=False): - # compute the norm of the velocity vector for one mouse - speed_one_mouse = np.linalg.norm( - velocity.sel(individuals=mouse_name, space=["x", "y"]).squeeze(), - axis=1, - ) + # compute the magnitude of the velocity vector for one mouse + speed_one_mouse = compute_norm(velocity.sel(individuals=mouse_name)) # plot speed against time ax.plot(speed_one_mouse) ax.set_title(mouse_name) @@ -379,16 +375,12 @@ fig.tight_layout() # %% -# The norm of the acceleration vector is the magnitude of the -# acceleration. -# We can also represent this for each individual. 
+# We can also represent the magnitude (norm) of the acceleration vector +# for each individual: fig, axes = plt.subplots(3, 1, sharex=True, sharey=True) for mouse_name, ax in zip(accel.individuals.values, axes, strict=False): - # compute norm of the acceleration vector for one mouse - accel_one_mouse = np.linalg.norm( - accel.sel(individuals=mouse_name, space=["x", "y"]).squeeze(), - axis=1, - ) + # compute magnitude of the acceleration vector for one mouse + accel_one_mouse = compute_norm(accel.sel(individuals=mouse_name)) # plot acceleration against time ax.plot(accel_one_mouse) diff --git a/examples/filter_and_interpolate.py b/examples/filter_and_interpolate.py index dbe33044b..71384ca76 100644 --- a/examples/filter_and_interpolate.py +++ b/examples/filter_and_interpolate.py @@ -27,7 +27,7 @@ # Visualise the pose tracks # ------------------------- # Since the data contains only a single wasp, we use -# :py:meth:`xarray.DataArray.squeeze` to remove +# :meth:`xarray.DataArray.squeeze` to remove # the dimension of length 1 from the data (the ``individuals`` dimension). ds.position.squeeze().plot.line( @@ -51,7 +51,7 @@ # it's always a good idea to inspect the actual confidence values in the data. # # Let's first look at a histogram of the confidence scores. As before, we use -# :py:meth:`xarray.DataArray.squeeze` to remove the ``individuals`` dimension +# :meth:`xarray.DataArray.squeeze` to remove the ``individuals`` dimension # from the data. ds.confidence.squeeze().plot.hist(bins=20) @@ -74,7 +74,7 @@ # Filter out points with low confidence # ------------------------------------- # Using the -# :py:meth:`filter_by_confidence()\ +# :meth:`filter_by_confidence()\ # ` # method of the ``move`` accessor, # we can filter out points with confidence scores below a certain threshold. @@ -82,20 +82,20 @@ # provided. # This method will also report the number of NaN values in the dataset before # and after the filtering operation by default (``print_report=True``). -# We will use :py:meth:`xarray.Dataset.update` to update ``ds`` in-place +# We will use :meth:`xarray.Dataset.update` to update ``ds`` in-place # with the filtered ``position``. ds.update({"position": ds.move.filter_by_confidence()}) # %% # .. note:: -# The ``move`` accessor :py:meth:`filter_by_confidence()\ +# The ``move`` accessor :meth:`filter_by_confidence()\ # ` # method is a convenience method that applies -# :py:func:`movement.filtering.filter_by_confidence`, +# :func:`movement.filtering.filter_by_confidence`, # which takes ``position`` and ``confidence`` as arguments. # The equivalent function call using the -# :py:mod:`movement.filtering` module would be: +# :mod:`movement.filtering` module would be: # # .. code-block:: python # @@ -121,7 +121,7 @@ # Interpolate over missing values # ------------------------------- # Using the -# :py:meth:`interpolate_over_time()\ +# :meth:`interpolate_over_time()\ # ` # method of the ``move`` accessor, # we can interpolate over the gaps we've introduced in the pose tracks. @@ -135,13 +135,13 @@ # %% # .. note:: -# The ``move`` accessor :py:meth:`interpolate_over_time()\ -# ` +# The ``move`` accessor :meth:`interpolate_over_time()\ +# ` # is also a convenience method that applies -# :py:func:`movement.filtering.interpolate_over_time` +# :func:`movement.filtering.interpolate_over_time` # to the ``position`` data variable. # The equivalent function call using the -# :py:mod:`movement.filtering` module would be: +# :mod:`movement.filtering` module would be: # # ..
code-block:: python # @@ -176,7 +176,7 @@ # %% # Filtering multiple data variables # --------------------------------- -# All :py:mod:`movement.filtering` functions are available via the +# All :mod:`movement.filtering` functions are available via the # ``move`` accessor. These ``move`` accessor methods operate on the # ``position`` data variable in the dataset ``ds`` by default. # There is also an additional argument ``data_vars`` that allows us to @@ -192,7 +192,7 @@ # in ``ds``, based on the confidence scores, we can specify # ``data_vars=["position", "velocity"]`` in the method call. # As the filtered data variables are returned as a dictionary, we can once -# again use :py:meth:`xarray.Dataset.update` to update ``ds`` in-place +# again use :meth:`xarray.Dataset.update` to update ``ds`` in-place # with the filtered data variables. ds["velocity"] = ds.move.compute_velocity() diff --git a/examples/smooth.py b/examples/smooth.py index f9b969b54..316d94448 100644 --- a/examples/smooth.py +++ b/examples/smooth.py @@ -109,7 +109,7 @@ def plot_raw_and_smooth_timeseries_and_psd( # Smoothing with a median filter # ------------------------------ # Using the -# :py:meth:`median_filter()\ +# :meth:`median_filter()\ # ` # method of the ``move`` accessor, # we apply a rolling window median filter over a 0.1-second window @@ -125,13 +125,13 @@ def plot_raw_and_smooth_timeseries_and_psd( # %% # .. note:: -# The ``move`` accessor :py:meth:`median_filter()\ +# The ``move`` accessor :meth:`median_filter()\ # ` # method is a convenience method that applies -# :py:func:`movement.filtering.median_filter` +# :func:`movement.filtering.median_filter` # to the ``position`` data variable. # The equivalent function call using the -# :py:mod:`movement.filtering` module would be: +# :mod:`movement.filtering` module would be: # # .. code-block:: python # @@ -249,11 +249,11 @@ def plot_raw_and_smooth_timeseries_and_psd( # Smoothing with a Savitzky-Golay filter # -------------------------------------- # Here we use the -# :py:meth:`savgol_filter()\ +# :meth:`savgol_filter()\ # ` # method of the ``move`` accessor, which is a convenience method that applies -# :py:func:`movement.filtering.savgol_filter` -# (a wrapper around :py:func:`scipy.signal.savgol_filter`), +# :func:`movement.filtering.savgol_filter` +# (a wrapper around :func:`scipy.signal.savgol_filter`), # to the ``position`` data variable. # The Savitzky-Golay filter is a polynomial smoothing filter that can be # applied to time series data on a rolling window basis. diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py index ed826cc1c..ed2b4b30e 100644 --- a/movement/analysis/kinematics.py +++ b/movement/analysis/kinematics.py @@ -1,28 +1,45 @@ """Compute kinematic variables like velocity and acceleration.""" -import numpy as np import xarray as xr from movement.utils.logging import log_error def compute_displacement(data: xr.DataArray) -> xr.DataArray: - """Compute displacement between consecutive positions. + """Compute displacement array in cartesian coordinates. - This is the difference between consecutive positions of each keypoint for - each individual across time. At each time point ``t``, it's defined as a - vector in cartesian ``(x,y)`` coordinates, pointing from the previous - ``(t-1)`` to the current ``(t)`` position. + The displacement array is defined as the difference between the position + array at time point ``t`` and the position array at time point ``t-1``. 
+ + As a result, for a given individual and keypoint, the displacement vector + at time point ``t`` is the vector pointing from the previous + ``(t-1)`` to the current ``(t)`` position, in cartesian coordinates. Parameters ---------- data : xarray.DataArray - The input data containing ``time`` as a dimension. + The input data array containing position vectors in cartesian + coordinates, with ``time`` as a dimension. Returns ------- xarray.DataArray - An xarray DataArray containing the computed displacement. + An xarray DataArray containing displacement vectors in cartesian + coordinates. + + Notes + ----- + For the ``position`` array of a ``poses`` dataset, the ``displacement`` + array will hold the displacement vectors for every keypoint and every + individual. + + For the ``position`` array of a ``bboxes`` dataset, the ``displacement`` + array will hold the displacement vectors for the centroid of every + individual bounding box. + + For the ``shape`` array of a ``bboxes`` dataset, the + ``displacement`` array will hold vectors with the change in width and + height per bounding box, between consecutive time points. """ _validate_time_dimension(data) @@ -32,66 +49,101 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray: - """Compute the velocity in cartesian ``(x,y)`` coordinates. + """Compute velocity array in cartesian coordinates. - Velocity is the first derivative of position for each keypoint - and individual across time. It's computed using numerical differentiation - and assumes equidistant time spacing. + The velocity array is the first time-derivative of the position + array. It is computed by applying the second-order accurate central + differences method on the position array. Parameters ---------- data : xarray.DataArray - The input data containing ``time`` as a dimension. + The input data array containing position vectors in cartesian + coordinates, with ``time`` as a dimension. Returns ------- xarray.DataArray - An xarray DataArray containing the computed velocity. + An xarray DataArray containing velocity vectors in cartesian + coordinates. + + Notes + ----- + For the ``position`` array of a ``poses`` dataset, the ``velocity`` array + will hold the velocity vectors for every keypoint and every individual. + + For the ``position`` array of a ``bboxes`` dataset, the ``velocity`` array + will hold the velocity vectors for the centroid of every individual + bounding box. + + See Also + -------- + :meth:`xarray.DataArray.differentiate` : The underlying method used. """ - return _compute_approximate_derivative(data, order=1) + return _compute_approximate_time_derivative(data, order=1) def compute_acceleration(data: xr.DataArray) -> xr.DataArray: - """Compute acceleration in cartesian ``(x,y)`` coordinates. + """Compute acceleration array in cartesian coordinates. - Acceleration represents the second derivative of position for each keypoint - and individual across time. It's computed using numerical differentiation - and assumes equidistant time spacing. + The acceleration array is the second time-derivative of the + position array. It is computed by applying the second-order accurate + central differences method on the velocity array. Parameters ---------- data : xarray.DataArray - The input data containing ``time`` as a dimension. + The input data array containing position vectors in cartesian + coordinates, with ``time`` as a dimension.
Returns ------- xarray.DataArray - An xarray DataArray containing the computed acceleration. + An xarray DataArray containing acceleration vectors in cartesian + coordinates. + + Notes + ----- + For the ``position`` array of a ``poses`` dataset, the ``acceleration`` + array will hold the acceleration vectors for every keypoint and every + individual. + + For the ``position`` array of a ``bboxes`` dataset, the ``acceleration`` + array will hold the acceleration vectors for the centroid of every + individual bounding box. + + See Also + -------- + :meth:`xarray.DataArray.differentiate` : The underlying method used. """ - return _compute_approximate_derivative(data, order=2) + return _compute_approximate_time_derivative(data, order=2) -def _compute_approximate_derivative( +def _compute_approximate_time_derivative( data: xr.DataArray, order: int ) -> xr.DataArray: - """Compute the derivative using numerical differentiation. + """Compute the time-derivative of an array using numerical differentiation. - This assumes equidistant time spacing. + This function uses :meth:`xarray.DataArray.differentiate`, + which differentiates the array with the second-order + accurate central differences method. Parameters ---------- data : xarray.DataArray - The input data containing ``time`` as a dimension. + The input data array containing ``time`` as a dimension. order : int - The order of the derivative. 1 for velocity, 2 for - acceleration. Value must be a positive integer. + The order of the time-derivative. For an input containing position + data, use 1 to compute velocity, and 2 to compute acceleration. Value + must be a positive integer. Returns ------- xarray.DataArray - An xarray DataArray containing the derived variable. + An xarray DataArray containing the time-derivative of the + input data. """ if not isinstance(order, int): @@ -100,17 +152,12 @@ def _compute_approximate_derivative( ) if order <= 0: raise log_error(ValueError, "Order must be a positive integer.") + _validate_time_dimension(data) + result = data - dt = data["time"].values[1] - data["time"].values[0] for _ in range(order): - result = xr.apply_ufunc( - np.gradient, - result, - dt, - kwargs={"axis": 0}, - ) - result = result.reindex_like(data) + result = result.differentiate("time") return result @@ -124,11 +171,11 @@ def _validate_time_dimension(data: xr.DataArray) -> None: Raises ------ - AttributeError + ValueError If the input data does not contain a ``time`` dimension. """ if "time" not in data.dims: raise log_error( - AttributeError, "Input data must contain 'time' as a dimension." + ValueError, "Input data must contain 'time' as a dimension." ) diff --git a/movement/filtering.py b/movement/filtering.py index 573e8635d..8432c1335 100644 --- a/movement/filtering.py +++ b/movement/filtering.py @@ -1,4 +1,4 @@ -"""Filter and interpolate pose tracks in ``movement`` datasets.""" +"""Filter and interpolate tracks in ``movement`` datasets.""" import xarray as xr from scipy import signal @@ -40,14 +40,14 @@ def filter_by_confidence( Notes ----- - The point-wise confidence values reported by various pose estimation - frameworks are not standardised, and the range of values can vary. - For example, DeepLabCut reports a likelihood value between 0 and 1, whereas - the point confidence reported by SLEAP can range above 1. - Therefore, the default threshold value will not be appropriate for all - datasets and does not have the same meaning across pose estimation - frameworks. 
We advise users to inspect the confidence values - in their dataset and adjust the threshold accordingly. + For the poses dataset case, note that the point-wise confidence values + reported by various pose estimation frameworks are not standardised, and + the range of values can vary. For example, DeepLabCut reports a likelihood + value between 0 and 1, whereas the point confidence reported by SLEAP can + range above 1. Therefore, the default threshold value will not be + appropriate for all datasets and does not have the same meaning across + pose estimation frameworks. We advise users to inspect the confidence + values in their dataset and adjust the threshold accordingly. """ data_filtered = data.where(confidence >= threshold) @@ -66,7 +66,7 @@ def interpolate_over_time( ) -> xr.DataArray: """Fill in NaN values by interpolating over the ``time`` dimension. - This method uses :py:meth:`xarray.DataArray.interpolate_na` under the + This method uses :meth:`xarray.DataArray.interpolate_na` under the hood and passes the ``method`` and ``max_gap`` parameters to it. See the xarray documentation for more details on these parameters. @@ -88,14 +88,14 @@ def interpolate_over_time( Returns ------- - xr.DataArray + xarray.DataArray The data where NaN values have been interpolated over using the parameters provided. Notes ----- The ``max_gap`` parameter differs slightly from that in - :py:meth:`xarray.DataArray.interpolate_na`, in which the gap size + :meth:`xarray.DataArray.interpolate_na`, in which the gap size is defined as the difference between the ``time`` coordinate values at the first data point after a gap and the last value before a gap. @@ -127,17 +127,17 @@ def median_filter( data : xarray.DataArray The input data to be smoothed. window : int - The size of the filter window, representing the fixed number + The size of the smoothing window, representing the fixed number of observations used for each window. min_periods : int Minimum number of observations in the window required to have a value (otherwise result is NaN). The default, None, is equivalent to setting ``min_periods`` equal to the size of the window. This argument is directly passed to the ``min_periods`` parameter of - :py:meth:`xarray.DataArray.rolling`. + :meth:`xarray.DataArray.rolling`. print_report : bool Whether to print a report on the number of NaNs in the dataset - before and after filtering. Default is ``True``. + before and after smoothing. Default is ``True``. Returns ------- @@ -146,7 +146,7 @@ def median_filter( Notes ----- - By default, whenever one or more NaNs are present in the filter window, + By default, whenever one or more NaNs are present in the smoothing window, a NaN is returned to the output array. As a result, any stretch of NaNs present in the input data will be propagated proportionally to the size of the window (specifically, by @@ -194,7 +194,7 @@ def savgol_filter( data : xarray.DataArray The input data to be smoothed. window : int - The size of the filter window, representing the fixed number + The size of the smoothing window, representing the fixed number of observations used for each window. polyorder : int The order of the polynomial used to fit the samples. Must be @@ -202,10 +202,10 @@ def savgol_filter( 2 is used. print_report : bool Whether to print a report on the number of NaNs in the dataset - before and after filtering. Default is ``True``. + before and after smoothing. Default is ``True``. 
**kwargs : dict Additional keyword arguments are passed to - :py:func:`scipy.signal.savgol_filter`. + :func:`scipy.signal.savgol_filter`. Note that the ``axis`` keyword argument may not be overridden. @@ -217,15 +217,15 @@ def savgol_filter( Notes ----- - Uses the :py:func:`scipy.signal.savgol_filter` function to apply a + Uses the :func:`scipy.signal.savgol_filter` function to apply a Savitzky-Golay filter to the input data. See the SciPy documentation for more information on that function. - Whenever one or more NaNs are present in a filter window of the + Whenever one or more NaNs are present in a smoothing window of the input data, a NaN is returned to the output array. As a result, any stretch of NaNs present in the input data will be propagated proportionally to the size of the window (specifically, by ``floor(window/2)``). Note that, unlike - :py:func:`movement.filtering.median_filter()`, there is no ``min_periods`` + :func:`movement.filtering.median_filter`, there is no ``min_periods`` option to control this behaviour. """ diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index 9c1772308..8550a2e89 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -35,19 +35,19 @@ def from_numpy( position_array : np.ndarray Array of shape (n_frames, n_individuals, n_space) containing the tracks of the bounding boxes' centroids. - It will be converted to a :py:class:`xarray.DataArray` object + It will be converted to a :class:`xarray.DataArray` object named "position". shape_array : np.ndarray Array of shape (n_frames, n_individuals, n_space) containing the shape of the bounding boxes. The shape of a bounding box is its width (extent along the x-axis of the image) and height (extent along the y-axis of the image). It will be converted to a - :py:class:`xarray.DataArray` object named "shape". + :class:`xarray.DataArray` object named "shape". confidence_array : np.ndarray, optional Array of shape (n_frames, n_individuals) containing the confidence scores of the bounding boxes. If None (default), the confidence scores are set to an array of NaNs. It will be converted - to a :py:class:`xarray.DataArray` object named "confidence". + to a :class:`xarray.DataArray` object named "confidence". individual_names : list of str, optional List of individual names for the tracked bounding boxes in the video. If None (default), bounding boxes are assigned names based on the size @@ -402,6 +402,11 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict: array_dict[key] = np.stack(list_arrays, axis=1).squeeze() + # Transform position_array to represent centroid of bbox, + # rather than top-left corner + # (top left corner: corner of the bbox with minimum x and y coordinates) + array_dict["position_array"] += array_dict["shape_array"] / 2 + # Add remaining arrays to dict array_dict["ID_array"] = df["ID"].unique().reshape(-1, 1) array_dict["frame_array"] = df["frame_number"].unique().reshape(-1, 1) @@ -415,14 +420,16 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame: Read the VIA tracks .csv file as a pandas dataframe with columns: - ID: the integer ID of the tracked bounding box. - frame_number: the frame number of the tracked bounding box. - - x: the x-coordinate of the tracked bounding box centroid. - - y: the y-coordinate of the tracked bounding box centroid. + - x: the x-coordinate of the tracked bounding box's top-left corner. + - y: the y-coordinate of the tracked bounding box's top-left corner. - w: the width of the tracked bounding box. 
- h: the height of the tracked bounding box. - confidence: the confidence score of the tracked bounding box. The dataframe is sorted by ID and frame number, and for each ID, - empty frames are filled in with NaNs. + empty frames are filled in with NaNs. The coordinates of the bboxes + are assumed to be in the image coordinate system (i.e., the top-left + corner of a bbox is its corner with minimum x and y coordinates). """ # Read VIA tracks .csv file as a pandas dataframe df_file = pd.read_csv(file_path, sep=",", header=0) diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index 6ae2f9e59..2b1a25d87 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -34,11 +34,11 @@ def from_numpy( position_array : np.ndarray Array of shape (n_frames, n_individuals, n_keypoints, n_space) containing the poses. It will be converted to a - :py:class:`xarray.DataArray` object named "position". + :class:`xarray.DataArray` object named "position". confidence_array : np.ndarray, optional Array of shape (n_frames, n_individuals, n_keypoints) containing the point-wise confidence scores. It will be converted to a - :py:class:`xarray.DataArray` object named "confidence". + :class:`xarray.DataArray` object named "confidence". If None (default), the scores will be set to an array of NaNs. individual_names : list of str, optional List of unique names for the individuals in the video. If None diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py index 07ba0bd0e..bc2c0e1cd 100644 --- a/movement/io/save_poses.py +++ b/movement/io/save_poses.py @@ -241,7 +241,7 @@ def to_lp_file( ----- LightningPose saves pose estimation outputs as .csv files, using the same format as single-animal DeepLabCut projects. Therefore, under the hood, - this function calls :py:func:`movement.io.save_poses.to_dlc_file` + this function calls :func:`movement.io.save_poses.to_dlc_file` with ``split_individuals=True``. This setting means that each individual is saved to a separate file, with the individual's name appended to the file path, just before the file extension, diff --git a/movement/move_accessor.py b/movement/move_accessor.py index 933884174..64b176519 100644 --- a/movement/move_accessor.py +++ b/movement/move_accessor.py @@ -1,4 +1,4 @@ -"""Accessor for extending :py:class:`xarray.Dataset` objects.""" +"""Accessor for extending :class:`xarray.Dataset` objects.""" import logging from typing import ClassVar @@ -18,9 +18,9 @@ @xr.register_dataset_accessor("move") class MovementDataset: - """An :py:class:`xarray.Dataset` accessor for ``movement`` data. + """An :class:`xarray.Dataset` accessor for ``movement`` data. - A ``movement`` dataset is an :py:class:`xarray.Dataset` with a specific + A ``movement`` dataset is an :class:`xarray.Dataset` with a specific structure to represent pose tracks or bounding boxes data, associated confidence scores and relevant metadata. @@ -66,8 +66,8 @@ def __getattr__(self, name: str) -> xr.DataArray: This method currently only forwards kinematic property computation and filtering operations to the respective functions in - :py:mod:`movement.analysis.kinematics` and - :py:mod:`movement.filtering`. + :mod:`movement.analysis.kinematics` and + :mod:`movement.filtering`. Parameters ---------- @@ -106,7 +106,7 @@ def kinematics_wrapper( """Provide convenience method for computing kinematic properties. This method forwards kinematic property computation - to the respective functions in :py:mod:`movement.analysis.kinematics`. 
+ to the respective functions in :mod:`movement.analysis.kinematics`. Parameters ---------- @@ -161,7 +161,7 @@ def filtering_wrapper( """Provide convenience method for filtering data variables. This method forwards filtering and/or smoothing to the respective - functions in :py:mod:`movement.filtering`. The data variables to + functions in :mod:`movement.filtering`. The data variables to filter can be specified in ``data_vars``. If ``data_vars`` is not specified, the ``position`` data variable is selected by default. diff --git a/movement/utils/logging.py b/movement/utils/logging.py index 14add0a44..0174e5fff 100644 --- a/movement/utils/logging.py +++ b/movement/utils/logging.py @@ -113,8 +113,8 @@ def log_to_attrs(func): """Log the operation performed by the wrapped function. This decorator appends log entries to the data's ``log`` - attribute. The wrapped function must accept an :py:class:`xarray.Dataset` - or :py:class:`xarray.DataArray` as its first argument and return an + attribute. The wrapped function must accept an :class:`xarray.Dataset` + or :class:`xarray.DataArray` as its first argument and return an object of the same type. """ diff --git a/movement/utils/vector.py b/movement/utils/vector.py index c35990ebe..0d5d88c83 100644 --- a/movement/utils/vector.py +++ b/movement/utils/vector.py @@ -6,6 +6,93 @@ from movement.utils.logging import log_error +def compute_norm(data: xr.DataArray) -> xr.DataArray: + """Compute the norm of the vectors along the spatial dimension. + + The norm of a vector is its magnitude, also called Euclidean norm, 2-norm + or Euclidean length. Note that if the input data is expressed in polar + coordinates, the magnitude of a vector is the same as its radial coordinate + ``rho``. + + Parameters + ---------- + data : xarray.DataArray + The input data array containing either ``space`` or ``space_pol`` + as a dimension. + + Returns + ------- + xarray.DataArray + A data array holding the norm of the input vectors. + Note that this output array has no spatial dimension but preserves + all other dimensions of the input data array (see Notes). + + Notes + ----- + If the input data array is a ``position`` array, this function will compute + the magnitude of the position vectors, for every individual and keypoint, + at every timestep. If the input data array is a ``shape`` array of a + bounding boxes dataset, it will compute the magnitude of the shape + vectors (i.e., the diagonal of the bounding box), + for every individual and at every timestep. + + + """ + if "space" in data.dims: + _validate_dimension_coordinates(data, {"space": ["x", "y"]}) + return xr.apply_ufunc( + np.linalg.norm, + data, + input_core_dims=[["space"]], + kwargs={"axis": -1}, + ) + elif "space_pol" in data.dims: + _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + return data.sel(space_pol="rho", drop=True) + else: + _raise_error_for_missing_spatial_dim() + + +def convert_to_unit(data: xr.DataArray) -> xr.DataArray: + """Convert the vectors along the spatial dimension into unit vectors. + + A unit vector is a vector pointing in the same direction as the original + vector but with norm = 1. + + Parameters + ---------- + data : xarray.DataArray + The input data array containing either ``space`` or ``space_pol`` + as a dimension. + + Returns + ------- + xarray.DataArray + A data array holding the unit vectors of the input data array + (all input dimensions are preserved). 
+ + Notes + ----- + Note that the unit vector for the null vector is undefined, since the null + vector has 0 norm and no direction associated with it. + + """ + if "space" in data.dims: + _validate_dimension_coordinates(data, {"space": ["x", "y"]}) + return data / compute_norm(data) + elif "space_pol" in data.dims: + _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + # Set both rho and phi values to NaN at null vectors (where rho = 0) + new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data) + # Set the rho values to 1 for non-null vectors (phi is preserved) + new_data.loc[{"space_pol": "rho"}] = xr.where( + new_data.sel(space_pol="rho").isnull(), np.nan, 1 + ) + return new_data + else: + _raise_error_for_missing_spatial_dim() + + def cart2pol(data: xr.DataArray) -> xr.DataArray: """Transform Cartesian coordinates to polar. @@ -25,12 +112,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray: """ _validate_dimension_coordinates(data, {"space": ["x", "y"]}) - rho = xr.apply_ufunc( - np.linalg.norm, - data, - input_core_dims=[["space"]], - kwargs={"axis": -1}, - ) + rho = compute_norm(data) phi = xr.apply_ufunc( np.arctan2, data.sel(space="y"), @@ -122,3 +204,11 @@ ) if error_message: raise log_error(ValueError, error_message) + + +def _raise_error_for_missing_spatial_dim() -> None: + raise log_error( + ValueError, + "Input data array must contain either 'space' or 'space_pol' " + "as dimensions.", + ) diff --git a/pyproject.toml b/pyproject.toml index b552fd420..27348c291 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ license = { text = "BSD-3-Clause" } dependencies = [ "numpy", - "pandas<2.2.2;python_version>='3.12'", "pandas", "h5py", "attrs", diff --git a/tests/conftest.py b/tests/conftest.py index 77c8c73ec..272e5eaa8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,6 +39,7 @@ def setup_logging(tmp_path): ) +# --------- File validator fixtures --------------------------------- @pytest.fixture def unreadable_file(tmp_path): """Return a dictionary containing the file path and @@ -213,13 +214,42 @@ def sleap_file(request): return pytest.DATA_PATHS.get(request.param) +# ------------ Dataset validator fixtures --------------------------------- + + @pytest.fixture -def valid_bboxes_array(): - """Return a dictionary of valid non-zero arrays for a +def valid_bboxes_arrays_all_zeros(): + """Return a dictionary of valid zero arrays (in terms of shape) for a ValidBboxesDataset. + """ + # define the shape of the arrays + n_frames, n_individuals, n_space = (10, 2, 2) + + # build a valid array for position or shape with all zeros + valid_bbox_array_all_zeros = np.zeros((n_frames, n_individuals, n_space)) + + # return as a dict + return { + "position": valid_bbox_array_all_zeros, + "shape": valid_bbox_array_all_zeros, + "individual_names": ["id_" + str(id) for id in range(n_individuals)], + } + + +# --------------------- Bboxes dataset fixtures ---------------------------- +@pytest.fixture +def valid_bboxes_arrays(): + """Return a dictionary of valid arrays for a + ValidBboxesDataset representing a uniform linear motion. + + It represents 2 individuals for 10 frames, in 2D space. + - Individual 0 moves along the x=y line from the origin. + - Individual 1 moves along the x=-y line from the origin. - Contains realistic data for 10 frames, 2 individuals, in 2D - with 5 low confidence bounding boxes.
All confidence values are set to 0.9 except the following, which are set + to 0.1: + - Individual 0 at frames 2, 3, 4 + - Individual 1 at frames 2, 3 """ # define the shape of the arrays n_frames, n_individuals, n_space = (10, 2, 2) @@ -252,22 +282,21 @@ "position": position, "shape": shape, "confidence": confidence, - "individual_names": ["id_" + str(id) for id in range(n_individuals)], } @pytest.fixture def valid_bboxes_dataset( - valid_bboxes_array, + valid_bboxes_arrays, ): - """Return a valid bboxes dataset with low confidence values and - time in frames. + """Return a valid bboxes dataset for two individuals moving in uniform + linear motion, with low confidence values on 5 frames, and time in frames. """ dim_names = MovementDataset.dim_names["bboxes"] - position_array = valid_bboxes_array["position"] - shape_array = valid_bboxes_array["shape"] - confidence_array = valid_bboxes_array["confidence"] + position_array = valid_bboxes_arrays["position"] + shape_array = valid_bboxes_arrays["shape"] + confidence_array = valid_bboxes_arrays["confidence"] n_frames, n_individuals, _ = position_array.shape @@ -315,6 +344,7 @@ def valid_bboxes_dataset_with_nan(valid_bboxes_dataset): return valid_bboxes_dataset +# --------------------- Poses dataset fixtures ---------------------------- @pytest.fixture def valid_position_array(): """Return a function that generates different kinds @@ -384,12 +414,111 @@ def valid_poses_dataset(valid_position_array, request): @pytest.fixture def valid_poses_dataset_with_nan(valid_poses_dataset): """Return a valid pose tracks dataset with NaN values.""" + # Sets position for all keypoints in individual ind1 to NaN + # at timepoints 3, 7, 8 valid_poses_dataset.position.loc[ {"individuals": "ind1", "time": [3, 7, 8]} ] = np.nan return valid_poses_dataset +@pytest.fixture +def valid_poses_array_uniform_linear_motion(): + """Return a dictionary of valid arrays for a + ValidPosesDataset representing a uniform linear motion. + + It represents 2 individuals with 3 keypoints, for 10 frames, in 2D space. + - Individual 0 moves along the x=y line from the origin. + - Individual 1 moves along the x=-y line from the origin. + + All confidence values for all keypoints are set to 0.9 except + for the keypoints at the following frames, which are set to 0.1: + - Individual 0 at frames 2, 3, 4 + - Individual 1 at frames 2, 3 + """ + # define the shape of the arrays + n_frames, n_individuals, n_keypoints, n_space = (10, 2, 3, 2) + + # define centroid (index=0) trajectory in position array + # for each individual, the centroid moves along + # the x=+/-y line, starting from the origin.
+    # - individual 0 moves along the x = y line
+    # - individual 1 moves along the x = -y line
+    # They move one unit along the x and y axes in each frame
+    frames = np.arange(n_frames)
+    position = np.empty((n_frames, n_individuals, n_keypoints, n_space))
+    position[:, :, 0, 0] = frames[:, None]  # reshape to (n_frames, 1)
+    position[:, 0, 0, 1] = frames
+    position[:, 1, 0, 1] = -frames
+
+    # define trajectory of left and right keypoints
+    # for individual 0, at each timepoint:
+    # - the left keypoint (index=1) is at x_centroid, y_centroid + 1
+    # - the right keypoint (index=2) is at x_centroid + 1, y_centroid
+    # for individual 1, at each timepoint:
+    # - the left keypoint (index=1) is at x_centroid - 1, y_centroid
+    # - the right keypoint (index=2) is at x_centroid, y_centroid + 1
+    offsets = [
+        [(0, 1), (1, 0)],  # individual 0: left, right keypoints (x,y) offsets
+        [(-1, 0), (0, 1)],  # individual 1: left, right keypoints (x,y) offsets
+    ]
+    for i in range(n_individuals):
+        for kpt in range(1, n_keypoints):
+            position[:, i, kpt, 0] = (
+                position[:, i, 0, 0] + offsets[i][kpt - 1][0]
+            )
+            position[:, i, kpt, 1] = (
+                position[:, i, 0, 1] + offsets[i][kpt - 1][1]
+            )
+
+    # build an array of confidence values, all 0.9
+    confidence = np.full((n_frames, n_individuals, n_keypoints), 0.9)
+    # set 5 low-confidence values
+    # - set 3 confidence values for individual id_0's centroid to 0.1
+    # - set 2 confidence values for individual id_1's centroid to 0.1
+    idx_start = 2
+    confidence[idx_start : idx_start + 3, 0, 0] = 0.1
+    confidence[idx_start : idx_start + 2, 1, 0] = 0.1
+
+    return {"position": position, "confidence": confidence}
+
+
+@pytest.fixture
+def valid_poses_dataset_uniform_linear_motion(
+    valid_poses_array_uniform_linear_motion,
+):
+    """Return a valid poses dataset for two individuals moving in uniform
+    linear motion, with 5 low-confidence values, and time in frames.
+    """
+    dim_names = MovementDataset.dim_names["poses"]
+
+    position_array = valid_poses_array_uniform_linear_motion["position"]
+    confidence_array = valid_poses_array_uniform_linear_motion["confidence"]
+
+    n_frames, n_individuals, _, _ = position_array.shape
+
+    return xr.Dataset(
+        data_vars={
+            "position": xr.DataArray(position_array, dims=dim_names),
+            "confidence": xr.DataArray(confidence_array, dims=dim_names[:-1]),
+        },
+        coords={
+            dim_names[0]: np.arange(n_frames),
+            dim_names[1]: [f"id_{i}" for i in range(n_individuals)],
+            dim_names[2]: ["centroid", "left", "right"],
+            dim_names[3]: ["x", "y"],
+        },
+        attrs={
+            "fps": None,
+            "time_unit": "frames",
+            "source_software": "test",
+            "source_file": "test_poses.h5",
+            "ds_type": "poses",
+        },
+    )
+
+
+# -------------------- Invalid datasets fixtures ------------------------------
 @pytest.fixture
 def not_a_dataset():
     """Return data that is not a pose tracks dataset."""
@@ -444,7 +573,7 @@ def kinematic_property(request):
     return request.param
 
 
-# VIA tracks CSV fixtures
+# ---------------- VIA tracks CSV file fixtures ----------------------------
 @pytest.fixture
 def via_tracks_csv_with_invalid_header(tmp_path):
     """Return the file path for a file with invalid header."""
@@ -705,6 +834,9 @@ def count_consecutive_nans(da):
     return (da.isnull().astype(int).diff("time") == 1).sum().item()
 
 
+# ----------------- Helper fixture -----------------
+
+
 @pytest.fixture
 def helpers():
     """Return an instance of the ``Helpers`` class."""
diff --git a/tests/test_unit/test_filtering.py b/tests/test_unit/test_filtering.py
index 0336f0c14..4b4002874 100644
--- a/tests/test_unit/test_filtering.py
+++ b/tests/test_unit/test_filtering.py
@@ -10,115 +10,233 @@
     savgol_filter,
 )
 
+# Dataset fixtures
+list_valid_datasets_without_nans = [
+    "valid_poses_dataset",
+    "valid_bboxes_dataset",
+]
+list_valid_datasets_with_nans = [
+    f"{dataset}_with_nan" for dataset in list_valid_datasets_without_nans
+]
+list_all_valid_datasets = (
+    list_valid_datasets_without_nans + list_valid_datasets_with_nans
+)
+
 
 @pytest.mark.parametrize(
-    "max_gap, expected_n_nans", [(None, 0), (1, 8), (2, 0)]
+    "valid_dataset_with_nan",
+    list_valid_datasets_with_nans,
+)
+@pytest.mark.parametrize(
+    "max_gap, expected_n_nans_in_position", [(None, 0), (0, 3), (1, 2), (2, 0)]
 )
-def test_interpolate_over_time(
-    valid_poses_dataset_with_nan, helpers, max_gap, expected_n_nans
+def test_interpolate_over_time_on_position(
+    valid_dataset_with_nan,
+    max_gap,
+    expected_n_nans_in_position,
+    helpers,
+    request,
 ):
-    """Test that the number of NaNs decreases after interpolating
+    """Test that the number of NaNs decreases after linearly interpolating
     over time and that the resulting number of NaNs is as expected
     for different values of ``max_gap``.
""" - # First dataset with time unit in frames - data_in_frames = valid_poses_dataset_with_nan.position - # Create second dataset with time unit in seconds - data_in_seconds = data_in_frames.copy() - data_in_seconds["time"] = data_in_seconds["time"] * 0.1 - data_interp_frames = interpolate_over_time(data_in_frames, max_gap=max_gap) - data_interp_seconds = interpolate_over_time( - data_in_seconds, max_gap=max_gap + valid_dataset_in_frames = request.getfixturevalue(valid_dataset_with_nan) + + # Get position array with time unit in frames & seconds + # assuming 10 fps = 0.1 s per frame + valid_dataset_in_seconds = valid_dataset_in_frames.copy() + valid_dataset_in_seconds.coords["time"] = ( + valid_dataset_in_seconds.coords["time"] * 0.1 ) - n_nans_before = helpers.count_nans(data_in_frames) - n_nans_after_frames = helpers.count_nans(data_interp_frames) - n_nans_after_seconds = helpers.count_nans(data_interp_seconds) + position = { + "frames": valid_dataset_in_frames.position, + "seconds": valid_dataset_in_seconds.position, + } + + # Count number of NaNs before and after interpolating position + n_nans_before = helpers.count_nans(position["frames"]) + n_nans_after_per_time_unit = {} + for time_unit in ["frames", "seconds"]: + # interpolate + position_interp = interpolate_over_time( + position[time_unit], method="linear", max_gap=max_gap + ) + # count nans + n_nans_after_per_time_unit[time_unit] = helpers.count_nans( + position_interp + ) + # The number of NaNs should be the same for both datasets # as max_gap is based on number of missing observations (NaNs) - assert n_nans_after_frames == n_nans_after_seconds - assert n_nans_after_frames < n_nans_before - assert n_nans_after_frames == expected_n_nans + assert ( + n_nans_after_per_time_unit["frames"] + == n_nans_after_per_time_unit["seconds"] + ) + + # The number of NaNs should decrease after interpolation + n_nans_after = n_nans_after_per_time_unit["frames"] + if max_gap == 0: + assert n_nans_after == n_nans_before + else: + assert n_nans_after < n_nans_before + + # The number of NaNs after interpolating should be as expected + assert n_nans_after == ( + valid_dataset_in_frames.sizes["space"] + * valid_dataset_in_frames.sizes.get("keypoints", 1) + # in bboxes dataset there is no keypoints dimension + * expected_n_nans_in_position + ) -def test_filter_by_confidence(valid_poses_dataset, helpers): +@pytest.mark.parametrize( + "valid_dataset_no_nans, n_low_confidence_kpts", + [ + ("valid_poses_dataset", 20), + ("valid_bboxes_dataset", 5), + ], +) +def test_filter_by_confidence_on_position( + valid_dataset_no_nans, n_low_confidence_kpts, helpers, request +): """Test that points below the default 0.6 confidence threshold are converted to NaN. """ - data = valid_poses_dataset.position - confidence = valid_poses_dataset.confidence - data_filtered = filter_by_confidence(data, confidence) - n_nans = helpers.count_nans(data_filtered) - assert isinstance(data_filtered, xr.DataArray) - # 5 timepoints * 2 individuals * 2 keypoints * 2 space dimensions - # have confidence below 0.6 - assert n_nans == 40 - - -@pytest.mark.parametrize("window_size", [2, 4]) -def test_median_filter(valid_poses_dataset_with_nan, window_size): - """Test that applying the median filter returns - a different xr.DataArray than the input data. 
- """ - data = valid_poses_dataset_with_nan.position - data_smoothed = median_filter(data, window_size) - del data_smoothed.attrs["log"] - assert isinstance(data_smoothed, xr.DataArray) and not ( - data_smoothed.equals(data) + # Filter position by confidence + valid_input_dataset = request.getfixturevalue(valid_dataset_no_nans) + position_filtered = filter_by_confidence( + valid_input_dataset.position, + confidence=valid_input_dataset.confidence, + threshold=0.6, ) + # Count number of NaNs in the full array + n_nans = helpers.count_nans(position_filtered) -def test_median_filter_with_nans(valid_poses_dataset_with_nan, helpers): - """Test NaN behaviour of the median filter. The input data - contains NaNs in all keypoints of the first individual at timepoints - 3, 7, and 8 (0-indexed, 10 total timepoints). The median filter - should propagate NaNs within the windows of the filter, - but it should not introduce any NaNs for the second individual. - """ - data = valid_poses_dataset_with_nan.position - data_smoothed = median_filter(data, window=3) - # All points of the first individual are converted to NaNs except - # at timepoints 0, 1, and 5. - assert not ( - data_smoothed.isel(individuals=0, time=[0, 1, 5]).isnull().any() - ) - # 7 timepoints * 1 individual * 2 keypoints * 2 space dimensions - assert helpers.count_nans(data_smoothed) == 28 - # No NaNs should be introduced for the second individual - assert not data_smoothed.isel(individuals=1).isnull().any() + # expected number of nans for poses: + # 5 timepoints * 2 individuals * 2 keypoints + # Note: we count the number of nans in the array, so we multiply + # the number of low confidence keypoints by the number of + # space dimensions + assert isinstance(position_filtered, xr.DataArray) + assert n_nans == valid_input_dataset.sizes["space"] * n_low_confidence_kpts -@pytest.mark.parametrize("window, polyorder", [(2, 1), (4, 2)]) -def test_savgol_filter(valid_poses_dataset_with_nan, window, polyorder): - """Test that applying the Savitzky-Golay filter returns - a different xr.DataArray than the input data. +@pytest.mark.parametrize( + "valid_dataset", + list_all_valid_datasets, +) +@pytest.mark.parametrize( + ("filter_func, filter_kwargs"), + [ + (median_filter, {"window": 2}), + (median_filter, {"window": 4}), + (savgol_filter, {"window": 2, "polyorder": 1}), + (savgol_filter, {"window": 4, "polyorder": 2}), + ], +) +def test_filter_on_position( + filter_func, filter_kwargs, valid_dataset, request +): + """Test that applying a filter to the position data returns + a different xr.DataArray than the input position data. """ - data = valid_poses_dataset_with_nan.position - data_smoothed = savgol_filter(data, window, polyorder=polyorder) - del data_smoothed.attrs["log"] - assert isinstance(data_smoothed, xr.DataArray) and not ( - data_smoothed.equals(data) + # Filter position + valid_input_dataset = request.getfixturevalue(valid_dataset) + position_filtered = filter_func( + valid_input_dataset.position, **filter_kwargs ) + del position_filtered.attrs["log"] + + # filtered array is an xr.DataArray + assert isinstance(position_filtered, xr.DataArray) + + # filtered data should not be equal to the original data + assert not position_filtered.equals(valid_input_dataset.position) + -def test_savgol_filter_with_nans(valid_poses_dataset_with_nan, helpers): - """Test NaN behaviour of the Savitzky-Golay filter. The input data - contains NaN values in all keypoints of the first individual at times - 3, 7, and 8 (0-indexed, 10 total timepoints). 
-    The Savitzky-Golay filter should propagate NaNs within the windows of
-    the filter, but it should not introduce any NaNs for the second individual.
+# Expected number of nans in the position array per
+# individual, after applying a filter with window size 3
+@pytest.mark.parametrize(
+    ("valid_dataset, expected_nans_in_filtered_position_per_indiv"),
+    [
+        (
+            "valid_poses_dataset",
+            {0: 0, 1: 0},
+        ),  # filtering should not introduce nans if input has no nans
+        ("valid_bboxes_dataset", {0: 0, 1: 0}),
+        ("valid_poses_dataset_with_nan", {0: 7, 1: 0}),
+        ("valid_bboxes_dataset_with_nan", {0: 7, 1: 0}),
+    ],
+)
+@pytest.mark.parametrize(
+    ("filter_func, filter_kwargs"),
+    [
+        (median_filter, {"window": 3}),
+        (savgol_filter, {"window": 3, "polyorder": 2}),
+    ],
+)
+def test_filter_with_nans_on_position(
+    filter_func,
+    filter_kwargs,
+    valid_dataset,
+    expected_nans_in_filtered_position_per_indiv,
+    helpers,
+    request,
+):
+    """Test NaN behaviour of the selected filter. The median and
+    Savitzky-Golay filters should output NaN whenever the sliding
+    window contains at least one NaN.
     """
-    data = valid_poses_dataset_with_nan.position
-    data_smoothed = savgol_filter(data, window=3, polyorder=2)
-    # There should be 28 NaNs in total for the first individual, i.e.
-    # at 7 timepoints, 2 keypoints, 2 space dimensions
-    # all except for timepoints 0, 1 and 5
-    assert helpers.count_nans(data_smoothed) == 28
-    assert not (
-        data_smoothed.isel(individuals=0, time=[0, 1, 5]).isnull().any()
+
+    def _assert_n_nans_in_position_per_individual(
+        valid_input_dataset,
+        position_filtered,
+        expected_nans_in_filt_position_per_indiv,
+    ):
+        # compute n nans in position after filtering per individual
+        n_nans_after_filtering_per_indiv = {
+            i: helpers.count_nans(position_filtered.isel(individuals=i))
+            for i in range(valid_input_dataset.sizes["individuals"])
+        }
+
+        # check number of nans per indiv is as expected
+        for i in range(valid_input_dataset.sizes["individuals"]):
+            assert n_nans_after_filtering_per_indiv[i] == (
+                expected_nans_in_filt_position_per_indiv[i]
+                * valid_input_dataset.sizes["space"]
+                * valid_input_dataset.sizes.get("keypoints", 1)
+            )
+
+    # Filter position
+    valid_input_dataset = request.getfixturevalue(valid_dataset)
+    position_filtered = filter_func(
+        valid_input_dataset.position, **filter_kwargs
     )
-    assert not data_smoothed.isel(individuals=1).isnull().any()
 
+    # check number of nans per indiv is as expected
+    _assert_n_nans_in_position_per_individual(
+        valid_input_dataset,
+        position_filtered,
+        expected_nans_in_filtered_position_per_indiv,
+    )
+
+    # if the input had NaNs, the position of the first individual
+    # (the one with NaNs) should still be valid at timepoints 0, 1 and 5
+    n_nans_input = helpers.count_nans(valid_input_dataset.position)
+    if n_nans_input != 0:
+        assert not (
+            position_filtered.isel(individuals=0, time=[0, 1, 5])
+            .isnull()
+            .any()
+        )
+
 
+@pytest.mark.parametrize(
+    "valid_dataset",
+    list_all_valid_datasets,
+)
 @pytest.mark.parametrize(
     "override_kwargs",
     [
@@ -128,7 +246,7 @@ def test_savgol_filter_with_nans(valid_poses_dataset_with_nan, helpers):
     ],
 )
 def test_savgol_filter_kwargs_override(
-    valid_poses_dataset_with_nan, override_kwargs
+    valid_dataset, override_kwargs, request
 ):
     """Test that overriding keyword arguments in the Savitzky-Golay filter
     works, except for the ``axis`` argument, which should raise a ValueError.
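For context on the NaN assertions above: the propagation behaviour can be reproduced with a plain xarray rolling window. The following is a minimal sketch using only numpy and xarray, not movement's own `median_filter`, whose implementation may differ in details such as edge handling:

```python
import numpy as np
import xarray as xr

# A toy 1D track with a single NaN at index 3
da = xr.DataArray(np.arange(10.0), dims="time")
da[3] = np.nan

# A centred rolling median with window=3: every window containing the
# NaN yields NaN (indices 2, 3, 4). With xarray's default min_periods
# (equal to the window size), the incomplete edge windows (indices 0
# and 9) are NaN as well.
smoothed = da.rolling(time=3, center=True).median()
print(smoothed.isnull().sum().item())  # 5
```
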
@@ -140,7 +258,7 @@
     )
     with expected_exception:
         savgol_filter(
-            valid_poses_dataset_with_nan.position,
+            request.getfixturevalue(valid_dataset).position,
             window=3,
             **override_kwargs,
         )
diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py
index e5241bc88..7641aeeb6 100644
--- a/tests/test_unit/test_kinematics.py
+++ b/tests/test_unit/test_kinematics.py
@@ -1,5 +1,3 @@
-from contextlib import nullcontext as does_not_raise
-
 import numpy as np
 import pytest
 import xarray as xr
@@ -7,106 +5,180 @@
 from movement.analysis import kinematics
 
 
-class TestKinematics:
-    """Test suite for the kinematics module."""
-
-    @pytest.fixture
-    def expected_dataarray(self, valid_poses_dataset):
-        """Return a function to generate the expected dataarray
-        for different kinematic properties.
-        """
-
-        def _expected_dataarray(property):
-            """Return an xarray.DataArray with default values and
-            the expected dimensions and coordinates.
-            """
-            # Expected x,y values for velocity
-            x_vals = np.array(
-                [1.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 17.0]
-            )
-            y_vals = np.full((10, 2, 2, 1), 4.0)
-            if property == "acceleration":
-                x_vals = np.array(
-                    [1.0, 1.5, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.5, 1.0]
-                )
-                y_vals = np.full((10, 2, 2, 1), 0)
-            elif property == "displacement":
-                x_vals = np.array(
-                    [0.0, 1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0]
-                )
-                y_vals[0] = 0
-
-            x_vals = x_vals.reshape(-1, 1, 1, 1)
-            # Repeat the x_vals to match the shape of the position
-            x_vals = np.tile(x_vals, (1, 2, 2, 1))
-            return xr.DataArray(
-                np.concatenate(
-                    [x_vals, y_vals],
-                    axis=-1,
-                ),
-                dims=valid_poses_dataset.dims,
-                coords=valid_poses_dataset.coords,
-            )
-
-        return _expected_dataarray
-
-    kinematic_test_params = [
-        ("valid_poses_dataset", does_not_raise()),
-        ("valid_poses_dataset_with_nan", does_not_raise()),
-        ("missing_dim_poses_dataset", pytest.raises(AttributeError)),
-    ]
+@pytest.mark.parametrize(
+    "valid_dataset_uniform_linear_motion",
+    [
+        "valid_poses_dataset_uniform_linear_motion",
+        "valid_bboxes_dataset",
+    ],
+)
+@pytest.mark.parametrize(
+    "kinematic_variable, expected_kinematics",
+    [
+        (
+            "displacement",
+            [
+                np.vstack([np.zeros((1, 2)), np.ones((9, 2))]),  # Individual 0
+                np.multiply(
+                    np.vstack([np.zeros((1, 2)), np.ones((9, 2))]),
+                    np.array([1, -1]),
+                ),  # Individual 1
+            ],
+        ),
+        (
+            "velocity",
+            [
+                np.ones((10, 2)),  # Individual 0
+                np.multiply(
+                    np.ones((10, 2)), np.array([1, -1])
+                ),  # Individual 1
+            ],
+        ),
+        (
+            "acceleration",
+            [
+                np.zeros((10, 2)),  # Individual 0
+                np.zeros((10, 2)),  # Individual 1
+            ],
+        ),
+    ],
+)
+def test_kinematics_uniform_linear_motion(
+    valid_dataset_uniform_linear_motion,
+    kinematic_variable,
+    expected_kinematics,  # 2D: n_frames, n_space_dims
+    request,
+):
+    """Test computed kinematics for a uniform linear motion case.
+
+    Uniform linear motion means the individuals move along a line
+    at constant velocity.
+
+    We consider 2 individuals ("id_0" and "id_1"),
+    tracked for 10 frames, along x and y:
+    - id_0 moves along the x=y line from the origin
+    - id_1 moves along the x=-y line from the origin
+    - they both move one unit (pixel) along each axis in each frame
+
+    If the dataset is a poses dataset, we consider 3 keypoints per individual
+    (centroid, left, right); the left and right keypoints are offset by
+    one unit from the centroid, at 45deg from the trajectory.
+ """ + # Compute kinematic array from input dataset + position = request.getfixturevalue( + valid_dataset_uniform_linear_motion + ).position + kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")( + position + ) - @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) - def test_displacement( - self, ds, expected_exception, expected_dataarray, request - ): - """Test displacement computation.""" - ds = request.getfixturevalue(ds) - with expected_exception: - result = kinematics.compute_displacement(ds.position) - expected = expected_dataarray("displacement") - if ds.position.isnull().any(): - expected.loc[ - {"individuals": "ind1", "time": [3, 4, 7, 8, 9]} - ] = np.nan - xr.testing.assert_allclose(result, expected) - - @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) - def test_velocity( - self, ds, expected_exception, expected_dataarray, request - ): - """Test velocity computation.""" - ds = request.getfixturevalue(ds) - with expected_exception: - result = kinematics.compute_velocity(ds.position) - expected = expected_dataarray("velocity") - if ds.position.isnull().any(): - expected.loc[ - {"individuals": "ind1", "time": [2, 4, 6, 7, 8, 9]} - ] = np.nan - xr.testing.assert_allclose(result, expected) - - @pytest.mark.parametrize("ds, expected_exception", kinematic_test_params) - def test_acceleration( - self, ds, expected_exception, expected_dataarray, request - ): - """Test acceleration computation.""" - ds = request.getfixturevalue(ds) - with expected_exception: - result = kinematics.compute_acceleration(ds.position) - expected = expected_dataarray("acceleration") - if ds.position.isnull().any(): - expected.loc[ - {"individuals": "ind1", "time": [1, 3, 5, 6, 7, 8, 9]} - ] = np.nan - xr.testing.assert_allclose(result, expected) - - @pytest.mark.parametrize("order", [0, -1, 1.0, "1"]) - def test_approximate_derivative_with_invalid_order(self, order): - """Test that an error is raised when the order is non-positive.""" - data = np.arange(10) - expected_exception = ( - ValueError if isinstance(order, int) else TypeError + # Build expected data array from the expected numpy array + expected_array = xr.DataArray( + np.stack(expected_kinematics, axis=1), + # Stack along the "individuals" axis + dims=["time", "individuals", "space"], + ) + if "keypoints" in position.coords: + expected_array = expected_array.expand_dims( + {"keypoints": position.coords["keypoints"].size} ) - with pytest.raises(expected_exception): - kinematics._compute_approximate_derivative(data, order=order) + expected_array = expected_array.transpose( + "time", "individuals", "keypoints", "space" + ) + + # Compare the values of the kinematic_array against the expected_array + np.testing.assert_allclose(kinematic_array.values, expected_array.values) + + +@pytest.mark.parametrize( + "valid_dataset_with_nan", + [ + "valid_poses_dataset_with_nan", + "valid_bboxes_dataset_with_nan", + ], +) +@pytest.mark.parametrize( + "kinematic_variable, expected_nans_per_individual", + [ + ("displacement", [5, 0]), # individual 0, individual 1 + ("velocity", [6, 0]), + ("acceleration", [7, 0]), + ], +) +def test_kinematics_with_dataset_with_nans( + valid_dataset_with_nan, + kinematic_variable, + expected_nans_per_individual, + helpers, + request, +): + """Test kinematics computation for a dataset with nans. + + We test that the kinematics can be computed and that the number + of nan values in the kinematic array is as expected. 
+
+    """
+    # compute kinematic array
+    valid_dataset = request.getfixturevalue(valid_dataset_with_nan)
+    position = valid_dataset.position
+    kinematic_array = getattr(kinematics, f"compute_{kinematic_variable}")(
+        position
+    )
+
+    # compute n nans in kinematic array per individual
+    n_nans_kinematics_per_indiv = [
+        helpers.count_nans(kinematic_array.isel(individuals=i))
+        for i in range(valid_dataset.sizes["individuals"])
+    ]
+
+    # expected nans per individual adjusted for space and keypoints dimensions
+    expected_nans_adjusted = [
+        n
+        * valid_dataset.sizes["space"]
+        * valid_dataset.sizes.get("keypoints", 1)
+        for n in expected_nans_per_individual
+    ]
+    # check number of nans per individual is as expected in kinematic array
+    np.testing.assert_array_equal(
+        n_nans_kinematics_per_indiv, expected_nans_adjusted
+    )
+
+
+@pytest.mark.parametrize(
+    "invalid_dataset, expected_exception",
+    [
+        ("not_a_dataset", pytest.raises(AttributeError)),
+        ("empty_dataset", pytest.raises(AttributeError)),
+        ("missing_var_poses_dataset", pytest.raises(AttributeError)),
+        ("missing_var_bboxes_dataset", pytest.raises(AttributeError)),
+        ("missing_dim_poses_dataset", pytest.raises(ValueError)),
+        ("missing_dim_bboxes_dataset", pytest.raises(ValueError)),
+    ],
+)
+@pytest.mark.parametrize(
+    "kinematic_variable",
+    [
+        "displacement",
+        "velocity",
+        "acceleration",
+    ],
+)
+def test_kinematics_with_invalid_dataset(
+    invalid_dataset,
+    expected_exception,
+    kinematic_variable,
+    request,
+):
+    """Test kinematics computation with an invalid dataset."""
+    with expected_exception:
+        position = request.getfixturevalue(invalid_dataset).position
+        getattr(kinematics, f"compute_{kinematic_variable}")(position)
+
+
+@pytest.mark.parametrize("order", [0, -1, 1.0, "1"])
+def test_approximate_derivative_with_invalid_order(order):
+    """Test that an error is raised when the order is non-positive
+    or not an integer.
+    """
+    data = np.arange(10)
+    expected_exception = ValueError if isinstance(order, int) else TypeError
+    with pytest.raises(expected_exception):
+        kinematics._compute_approximate_time_derivative(data, order=order)
diff --git a/tests/test_unit/test_load_bboxes.py b/tests/test_unit/test_load_bboxes.py
index dbed362cf..474e61183 100644
--- a/tests/test_unit/test_load_bboxes.py
+++ b/tests/test_unit/test_load_bboxes.py
@@ -419,3 +419,54 @@ def test_fps_and_time_coords(
     else:
         start_frame = 0
     assert_time_coordinates(ds, expected_fps, start_frame)
+
+
+def test_df_from_via_tracks_file(via_tracks_file):
+    """Test that the helper function correctly reads the VIA tracks .csv file
+    as a dataframe.
+    """
+    df = load_bboxes._df_from_via_tracks_file(via_tracks_file)
+
+    assert isinstance(df, pd.DataFrame)
+    assert len(df.frame_number.unique()) == 5
+    assert (
+        df.shape[0] == len(df.ID.unique()) * 5
+    )  # all individuals are present in all frames (even if nan)
+    assert list(df.columns) == [
+        "ID",
+        "frame_number",
+        "x",
+        "y",
+        "w",
+        "h",
+        "confidence",
+    ]
+
+
+def test_position_numpy_array_from_via_tracks_file(via_tracks_file):
+    """Test that the position array extracted from the VIA tracks .csv
+    file represents the centroid of the bbox.
+    """
+    # Extract numpy arrays from VIA tracks .csv file
+    bboxes_arrays = load_bboxes._numpy_arrays_from_via_tracks_file(
+        via_tracks_file
+    )
+
+    # Read VIA tracks .csv file as a dataframe
+    df = load_bboxes._df_from_via_tracks_file(via_tracks_file)
+
+    # Compute centroid positions from the dataframe
+    # (iterating through IDs in the same order as in the ID array)
+    list_derived_centroids = []
+    for id in bboxes_arrays["ID_array"]:
+        df_one_id = df[df["ID"] == id.item()]
+        centroid_position = np.array(
+            [df_one_id.x + df_one_id.w / 2, df_one_id.y + df_one_id.h / 2]
+        ).T  # frames, xy
+        list_derived_centroids.append(centroid_position)
+
+    # Compare to extracted position array
+    assert np.allclose(
+        bboxes_arrays["position_array"],  # frames, individuals, xy
+        np.stack(list_derived_centroids, axis=1),
+    )
diff --git a/tests/test_unit/test_logging.py b/tests/test_unit/test_logging.py
index d0a8c3bf5..348a36872 100644
--- a/tests/test_unit/test_logging.py
+++ b/tests/test_unit/test_logging.py
@@ -1,6 +1,7 @@
 import logging
 
 import pytest
+import xarray as xr
 
 from movement.utils.logging import log_error, log_to_attrs, log_warning
 
@@ -45,27 +46,41 @@ def test_log_warning(caplog):
     assert caplog.records[0].levelname == "WARNING"
 
 
-@pytest.mark.parametrize("input_data", ["dataset", "dataarray"])
-def test_log_to_attrs(input_data, valid_poses_dataset):
+@pytest.mark.parametrize(
+    "input_data",
+    [
+        "valid_poses_dataset",
+        "valid_bboxes_dataset",
+    ],
+)
+@pytest.mark.parametrize(
+    "selector_fn, expected_selector_type",
+    [
+        (lambda ds: ds, xr.Dataset),  # take full dataset
+        (lambda ds: ds.position, xr.DataArray),  # take position data array
+    ],
+)
+def test_log_to_attrs(
+    input_data, selector_fn, expected_selector_type, request
+):
     """Test that the ``log_to_attrs()`` decorator appends
-    log entries to the output data's ``log`` attribute and
-    checks that ``attrs`` contains all expected values.
+    log entries to the dataset's or the data array's ``log``
+    attribute and checks that ``attrs`` contains all the expected values.
""" + # a fake operation on the dataset to log @log_to_attrs def fake_func(data, arg, kwarg=None): return data - input_data = ( - valid_poses_dataset - if input_data == "dataset" - else valid_poses_dataset.position - ) + # apply operation to dataset or data array + dataset = request.getfixturevalue(input_data) + input_data = selector_fn(dataset) output_data = fake_func(input_data, "test1", kwarg="test2") + # check the log in the dataset is as expected + assert isinstance(output_data, expected_selector_type) assert "log" in output_data.attrs assert output_data.attrs["log"][0]["operation"] == "fake_func" - assert ( - output_data.attrs["log"][0]["arg_1"] == "test1" - and output_data.attrs["log"][0]["kwarg"] == "test2" - ) + assert output_data.attrs["log"][0]["arg_1"] == "test1" + assert output_data.attrs["log"][0]["kwarg"] == "test2" diff --git a/tests/test_unit/test_reports.py b/tests/test_unit/test_reports.py index 51d441ea6..79c3bc892 100644 --- a/tests/test_unit/test_reports.py +++ b/tests/test_unit/test_reports.py @@ -4,28 +4,126 @@ @pytest.mark.parametrize( - "data_selection", + "valid_dataset", [ - lambda ds: ds.position, # Entire dataset - lambda ds: ds.position.sel( - individuals="ind1" - ), # Missing "individuals" dim - lambda ds: ds.position.sel( - keypoints="key1" - ), # Missing "keypoints" dim - lambda ds: ds.position.sel( - individuals="ind1", keypoints="key1" - ), # Missing "individuals" and "keypoints" dims + "valid_poses_dataset", + "valid_bboxes_dataset", + "valid_poses_dataset_with_nan", + "valid_bboxes_dataset_with_nan", ], ) -def test_report_nan_values( - capsys, valid_poses_dataset_with_nan, data_selection +@pytest.mark.parametrize( + "data_selection, list_expected_individuals_indices", + [ + (lambda ds: ds.position, [0, 1]), # full position data array + ( + lambda ds: ds.position.isel(individuals=0), + [0], + ), # position of individual 0 only + ], +) +def test_report_nan_values_in_position_selecting_individual( + valid_dataset, + data_selection, + list_expected_individuals_indices, + request, ): - """Test that the nan-value reporting function handles data - with missing ``individuals`` and/or ``keypoint`` dims, and - that the dataset name is included in the report. + """Test that the nan-value reporting function handles position data + with specific ``individuals`` , and that the data array name (position) + and only the relevant individuals are included in the report. 
""" - data = data_selection(valid_poses_dataset_with_nan) - assert data.name in report_nan_values( - data - ), "Dataset name should be in the output" + # extract relevant position data + input_dataset = request.getfixturevalue(valid_dataset) + output_data_array = data_selection(input_dataset) + + # produce report + report_str = report_nan_values(output_data_array) + + # check report of nan values includes name of data array + assert output_data_array.name in report_str + + # check report of nan values includes selected individuals only + list_expected_individuals = [ + input_dataset["individuals"][idx].item() + for idx in list_expected_individuals_indices + ] + list_not_expected_individuals = [ + indiv.item() + for indiv in input_dataset["individuals"] + if indiv.item() not in list_expected_individuals + ] + assert all([ind in report_str for ind in list_expected_individuals]) + assert all( + [ind not in report_str for ind in list_not_expected_individuals] + ) + + +@pytest.mark.parametrize( + "valid_dataset", + [ + "valid_poses_dataset", + "valid_poses_dataset_with_nan", + ], +) +@pytest.mark.parametrize( + "data_selection, list_expected_keypoints, list_expected_individuals", + [ + ( + lambda ds: ds.position, + ["key1", "key2"], + ["ind1", "ind2"], + ), # Report nans in position for all keypoints and individuals + ( + lambda ds: ds.position.sel(keypoints="key1"), + [], + ["ind1", "ind2"], + ), # Report nans in position for keypoint "key1", for all individuals + # Note: if only one keypoint exists, it is not explicitly reported + ( + lambda ds: ds.position.sel(individuals="ind1", keypoints="key1"), + [], + ["ind1"], + ), # Report nans in position for individual "ind1" and keypoint "key1" + # Note: if only one keypoint exists, it is not explicitly reported + ], +) +def test_report_nan_values_in_position_selecting_keypoint( + valid_dataset, + data_selection, + list_expected_keypoints, + list_expected_individuals, + request, +): + """Test that the nan-value reporting function handles position data + with specific ``keypoints`` , and that the data array name (position) + and only the relevant keypoints are included in the report. 
+    """
+    # extract relevant position data
+    input_dataset = request.getfixturevalue(valid_dataset)
+    output_data_array = data_selection(input_dataset)
+
+    # produce report
+    report_str = report_nan_values(output_data_array)
+
+    # check that the report includes the name of the data array
+    assert output_data_array.name in report_str
+
+    # check that the report includes only the selected keypoints
+    list_not_expected_keypoints = [
+        kpt.item()
+        for kpt in input_dataset["keypoints"]
+        if kpt.item() not in list_expected_keypoints
+    ]
+    assert all(kpt in report_str for kpt in list_expected_keypoints)
+    assert all(kpt not in report_str for kpt in list_not_expected_keypoints)
+
+    # check that the report includes the selected individuals only
+    list_not_expected_individuals = [
+        indiv.item()
+        for indiv in input_dataset["individuals"]
+        if indiv.item() not in list_expected_individuals
+    ]
+    assert all(ind in report_str for ind in list_expected_individuals)
+    assert all(
+        ind not in report_str for ind in list_not_expected_individuals
+    )
diff --git a/tests/test_unit/test_validators/test_datasets_validators.py b/tests/test_unit/test_validators/test_datasets_validators.py
index a882162ea..493f1d460 100644
--- a/tests/test_unit/test_validators/test_datasets_validators.py
+++ b/tests/test_unit/test_validators/test_datasets_validators.py
@@ -76,29 +76,14 @@ def position_array_params(request):
         ),  # not an ndarray
         (
             np.zeros((10, 2, 3)),
-            f"Expected '{key}' to have 2 spatial " "coordinates, but got 3.",
+            f"Expected '{key}_array' to have 2 spatial "
+            "coordinates, but got 3.",
         ),  # last dim not 2
     ]
-    for key in ["position_array", "shape_array"]
+    for key in ["position", "shape"]
 }
 
 
-@pytest.fixture
-def valid_bboxes_inputs():
-    """Return a dictionary with valid inputs for a ValidBboxesDataset."""
-    n_frames, n_individuals, n_space = (10, 2, 2)
-    # valid array for position or shape
-    valid_bbox_array = np.zeros((n_frames, n_individuals, n_space))
-
-    return {
-        "position_array": valid_bbox_array,
-        "shape_array": valid_bbox_array,
-        "individual_names": [
-            "id_" + str(id) for id in range(valid_bbox_array.shape[1])
-        ],
-    }
-
-
 # Tests pose dataset
 @pytest.mark.parametrize(
     "invalid_position_array, log_message",
@@ -223,7 +208,7 @@ def test_poses_dataset_validator_source_software(
 # Tests bboxes dataset
 @pytest.mark.parametrize(
     "invalid_position_array, log_message",
-    invalid_bboxes_arrays_and_expected_log["position_array"],
+    invalid_bboxes_arrays_and_expected_log["position"],
 )
 def test_bboxes_dataset_validator_with_invalid_position_array(
     invalid_position_array, log_message, request
@@ -232,19 +217,19 @@ def test_bboxes_dataset_validator_with_invalid_position_array(
     with pytest.raises(ValueError) as excinfo:
         ValidBboxesDataset(
             position_array=invalid_position_array,
-            shape_array=request.getfixturevalue("valid_bboxes_inputs")[
-                "shape_array"
-            ],
-            individual_names=request.getfixturevalue("valid_bboxes_inputs")[
-                "individual_names"
-            ],
+            shape_array=request.getfixturevalue(
+                "valid_bboxes_arrays_all_zeros"
+            )["shape"],
+            individual_names=request.getfixturevalue(
+                "valid_bboxes_arrays_all_zeros"
+            )["individual_names"],
         )
     assert str(excinfo.value) == log_message
 
 
 @pytest.mark.parametrize(
     "invalid_shape_array, log_message",
-    invalid_bboxes_arrays_and_expected_log["shape_array"],
+    invalid_bboxes_arrays_and_expected_log["shape"],
 )
 def test_bboxes_dataset_validator_with_invalid_shape_array(
     invalid_shape_array, log_message, request
@@ -252,13 +237,13 @@ def 
test_bboxes_dataset_validator_with_invalid_shape_array( """Test that invalid shape arrays raise an error.""" with pytest.raises(ValueError) as excinfo: ValidBboxesDataset( - position_array=request.getfixturevalue("valid_bboxes_inputs")[ - "position_array" - ], + position_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["position"], shape_array=invalid_shape_array, - individual_names=request.getfixturevalue("valid_bboxes_inputs")[ - "individual_names" - ], + individual_names=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["individual_names"], ) assert str(excinfo.value) == log_message @@ -274,15 +259,19 @@ def test_bboxes_dataset_validator_with_invalid_shape_array( ( [1, 2, 3], pytest.raises(ValueError), - "Expected 'individual_names' to have length 2, but got 3.", + "Expected 'individual_names' to have length 2, " + f"but got {len([1, 2, 3])}.", ), # length doesn't match position_array.shape[1] + # from valid_bboxes_arrays_all_zeros fixture ( ["id_1", "id_1"], pytest.raises(ValueError), "individual_names passed to the dataset are not unique. " "There are 2 elements in the list, but " "only 1 are unique.", - ), # some IDs are not unique + ), # some IDs are not unique. + # Note: length of individual_names list should match + # n_individuals in valid_bboxes_arrays_all_zeros fixture ], ) def test_bboxes_dataset_validator_individual_names( @@ -291,12 +280,12 @@ def test_bboxes_dataset_validator_individual_names( """Test individual_names inputs.""" with expected_exception as excinfo: ds = ValidBboxesDataset( - position_array=request.getfixturevalue("valid_bboxes_inputs")[ - "position_array" - ], - shape_array=request.getfixturevalue("valid_bboxes_inputs")[ - "shape_array" - ], + position_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["position"], + shape_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["shape"], individual_names=list_individual_names, ) if list_individual_names is None: @@ -313,13 +302,14 @@ def test_bboxes_dataset_validator_individual_names( ( np.ones((10, 3, 2)), pytest.raises(ValueError), - "Expected 'confidence_array' to have shape (10, 2), " - "but got (10, 3, 2).", - ), # will not match position_array shape + f"Expected 'confidence_array' to have shape (10, 2), " + f"but got {np.ones((10, 3, 2)).shape}.", + ), # will not match shape of position_array in + # valid_bboxes_arrays_all_zeros fixture ( [1, 2, 3], pytest.raises(ValueError), - f"Expected a numpy array, but got {type(list())}.", + f"Expected a numpy array, but got {type([1, 2, 3])}.", ), # not an ndarray, should raise ValueError ( None, @@ -334,15 +324,15 @@ def test_bboxes_dataset_validator_confidence_array( """Test that invalid confidence arrays raise the appropriate errors.""" with expected_exception as excinfo: ds = ValidBboxesDataset( - position_array=request.getfixturevalue("valid_bboxes_inputs")[ - "position_array" - ], - shape_array=request.getfixturevalue("valid_bboxes_inputs")[ - "shape_array" - ], - individual_names=request.getfixturevalue("valid_bboxes_inputs")[ - "individual_names" - ], + position_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["position"], + shape_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["shape"], + individual_names=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["individual_names"], confidence_array=confidence_array, ) if confidence_array is None: @@ -387,15 +377,15 @@ def test_bboxes_dataset_validator_frame_array( """Test that invalid frame 
arrays raise the appropriate errors.""" with expected_exception as excinfo: ds = ValidBboxesDataset( - position_array=request.getfixturevalue("valid_bboxes_inputs")[ - "position_array" - ], - shape_array=request.getfixturevalue("valid_bboxes_inputs")[ - "shape_array" - ], - individual_names=request.getfixturevalue("valid_bboxes_inputs")[ - "individual_names" - ], + position_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["position"], + shape_array=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["shape"], + individual_names=request.getfixturevalue( + "valid_bboxes_arrays_all_zeros" + )["individual_names"], frame_array=frame_array, )
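For reference, the constructor pattern exercised by these validator tests looks as follows. This is a minimal sketch that assumes `ValidBboxesDataset` is importable from `movement.validators.datasets`; adjust the import to wherever the class actually lives:

```python
import numpy as np

# assumed import path; adjust to the actual module defining the class
from movement.validators.datasets import ValidBboxesDataset

n_frames, n_individuals, n_space = 10, 2, 2
zeros = np.zeros((n_frames, n_individuals, n_space))

# Valid construction: position and shape arrays agree in shape,
# and there is one unique name per individual
ValidBboxesDataset(
    position_array=zeros,
    shape_array=zeros,
    individual_names=[f"id_{i}" for i in range(n_individuals)],
)

# A confidence array whose shape does not match (n_frames, n_individuals)
# raises a ValueError, as asserted in the tests above
try:
    ValidBboxesDataset(
        position_array=zeros,
        shape_array=zeros,
        individual_names=[f"id_{i}" for i in range(n_individuals)],
        confidence_array=np.ones((10, 3, 2)),
    )
except ValueError as error:
    print(error)  # Expected 'confidence_array' to have shape (10, 2), ...
```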