Skip to content

Commit

Permalink
Update integration tests (#295)
Browse files Browse the repository at this point in the history
* Add filter with nan under threshold and varying window

* Get kinematics tests

* Adapt integration tests for kinematics+polar

* Update integration tests for filtering

* Fix factor 2 difference

* Update conftest

* Remove redundant comment in conftest

* Apply feedback from kinematic tests

* Cosmetic changes

* Spoof user-agent to avoid 403 error

* Check different URL

* Ignore link to license temporarily

* Try fake-useragent

* Revert "Try fake-useragent"

This reverts commit d67de0e.
  • Loading branch information
sfmig authored Oct 17, 2024
1 parent bfb20d2 commit a42838d
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 64 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@
# A list of regular expressions that match URIs that should not be checked
linkcheck_ignore = [
"https://pubs.acs.org/doi/*", # Checking dois is forbidden here
"https://opensource.org/license/bsd-3-clause/", # to avoid odd 403 error
]

myst_url_schemes = {
Expand Down
71 changes: 30 additions & 41 deletions tests/test_integration/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,70 +9,56 @@

@pytest.fixture
def sample_dataset():
"""Return a single-animal sample dataset, with time unit in frames.
This allows us to better control the expected number of NaNs in the tests.
"""
"""Return a single-animal sample dataset, with time unit in frames."""
ds_path = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")[
"poses"
]
ds = load_poses.from_dlc_file(ds_path)
ds["velocity"] = ds.move.compute_velocity()

return ds


@pytest.mark.parametrize("window", [3, 5, 6, 13])
def test_nan_propagation_through_filters(sample_dataset, window, helpers):
"""Test NaN propagation when passing a DataArray through
multiple filters sequentially. For the ``median_filter``
and ``savgol_filter``, the number of NaNs is expected to increase
"""Test NaN propagation is as expected when passing a DataArray through
filter by confidence, Savgol filter and interpolation.
For the ``savgol_filter``, the number of NaNs is expected to increase
at most by the filter's window length minus one (``window - 1``)
multiplied by the number of consecutive NaNs in the input data.
"""
# Introduce nans via filter_by_confidence
# Compute number of low confidence keypoints
n_low_confidence_kpts = (sample_dataset.confidence.data < 0.6).sum()

# Check filter position by confidence creates correct number of NaNs
sample_dataset.update(
{"position": sample_dataset.move.filter_by_confidence()}
)
expected_n_nans = 13136
n_nans_confilt = helpers.count_nans(sample_dataset.position)
assert n_nans_confilt == expected_n_nans, (
f"Expected {expected_n_nans} NaNs in filtered data, "
f"got: {n_nans_confilt}"
)
n_consecutive_nans = helpers.count_consecutive_nans(
sample_dataset.position
)
# Apply median filter and check that
# it doesn't introduce too many or too few NaNs
sample_dataset.update(
{"position": sample_dataset.move.median_filter(window)}
)
n_nans_medfilt = helpers.count_nans(sample_dataset.position)
max_nans_increase = (window - 1) * n_consecutive_nans
assert (
n_nans_medfilt <= n_nans_confilt + max_nans_increase
), "Median filter introduced more NaNs than expected."
n_total_nans_input = helpers.count_nans(sample_dataset.position)

assert (
n_nans_medfilt >= n_nans_confilt
), "Median filter mysteriously removed NaNs."
n_consecutive_nans = helpers.count_consecutive_nans(
n_total_nans_input
== n_low_confidence_kpts * sample_dataset.dims["space"]
)

# Compute maximum expected increase in NaNs due to filtering
n_consecutive_nans_input = helpers.count_consecutive_nans(
sample_dataset.position
)
max_nans_increase = (window - 1) * n_consecutive_nans_input

# Apply savgol filter and check that
# it doesn't introduce too many or too few NaNs
# Apply savgol filter and check that number of NaNs is within threshold
sample_dataset.update(
{"position": sample_dataset.move.savgol_filter(window, polyorder=2)}
)
n_nans_savgol = helpers.count_nans(sample_dataset.position)
max_nans_increase = (window - 1) * n_consecutive_nans
assert (
n_nans_savgol <= n_nans_medfilt + max_nans_increase
), "Savgol filter introduced more NaNs than expected."
assert (
n_nans_savgol >= n_nans_medfilt
), "Savgol filter mysteriously removed NaNs."

# Interpolate data (without max_gap) to eliminate all NaNs
n_total_nans_savgol = helpers.count_nans(sample_dataset.position)

# Check that filtering does not reduce number of nans
assert n_total_nans_savgol >= n_total_nans_input
# Check that the increase in nans is below the expected threshold
assert n_total_nans_savgol - n_total_nans_input <= max_nans_increase

# Interpolate data (without max_gap) and check it eliminates all NaNs
sample_dataset.update(
{"position": sample_dataset.move.interpolate_over_time()}
)
Expand Down Expand Up @@ -105,6 +91,9 @@ def test_accessor_filter_method(
applied, if valid data variables are passed, otherwise
raise an exception.
"""
# Compute velocity
sample_dataset["velocity"] = sample_dataset.move.compute_velocity()

with expected_exception as expected_type:
if method in ["median_filter", "savgol_filter"]:
# supply required "window" argument
Expand Down
106 changes: 83 additions & 23 deletions tests/test_integration/test_kinematics_vector_transform.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,93 @@
from contextlib import nullcontext as does_not_raise
import math

import numpy as np
import pytest
import xarray as xr

from movement.utils import vector


class TestKinematicsVectorTransform:
"""Test the vector transformation functionality with
various kinematic properties.
@pytest.mark.parametrize(
"valid_dataset_uniform_linear_motion",
[
"valid_poses_dataset_uniform_linear_motion",
"valid_bboxes_dataset",
],
)
@pytest.mark.parametrize(
"kinematic_variable, expected_kinematics_polar",
[
(
"displacement",
[
np.vstack(
[
np.zeros((1, 2)),
np.tile([math.sqrt(2), math.atan(1)], (9, 1)),
],
), # Individual 0, rho=sqrt(2), phi=45deg
np.vstack(
[
np.zeros((1, 2)),
np.tile([math.sqrt(2), -math.atan(1)], (9, 1)),
]
), # Individual 1, rho=sqrt(2), phi=-45deg
],
),
(
"velocity",
[
np.tile(
[math.sqrt(2), math.atan(1)], (10, 1)
), # Individual O, rho, phi=45deg
np.tile(
[math.sqrt(2), -math.atan(1)], (10, 1)
), # Individual 1, rho, phi=-45deg
],
),
(
"acceleration",
[
np.zeros((10, 2)), # Individual 0
np.zeros((10, 2)), # Individual 1
],
),
],
)
def test_cart2pol_transform_on_kinematics(
valid_dataset_uniform_linear_motion,
kinematic_variable,
expected_kinematics_polar,
request,
):
"""Test transformation between Cartesian and polar coordinates
with various kinematic properties.
"""
ds = request.getfixturevalue(valid_dataset_uniform_linear_motion)
kinematic_array_cart = getattr(ds.move, f"compute_{kinematic_variable}")()
kinematic_array_pol = vector.cart2pol(kinematic_array_cart)

@pytest.mark.parametrize(
"ds, expected_exception",
[
("valid_poses_dataset", does_not_raise()),
("valid_poses_dataset_with_nan", does_not_raise()),
("missing_dim_poses_dataset", pytest.raises(RuntimeError)),
],
# Build expected data array
expected_array_pol = xr.DataArray(
np.stack(expected_kinematics_polar, axis=1),
# Stack along the "individuals" axis
dims=["time", "individuals", "space"],
)
if "keypoints" in ds.position.coords:
expected_array_pol = expected_array_pol.expand_dims(
{"keypoints": ds.position.coords["keypoints"].size}
)
expected_array_pol = expected_array_pol.transpose(
"time", "individuals", "keypoints", "space"
)

# Compare the values of the kinematic_array against the expected_array
np.testing.assert_allclose(
kinematic_array_pol.values, expected_array_pol.values
)

# Check we can recover the original Cartesian array
kinematic_array_cart_recover = vector.pol2cart(kinematic_array_pol)
xr.testing.assert_allclose(
kinematic_array_cart, kinematic_array_cart_recover
)
def test_cart_and_pol_transform(
self, ds, expected_exception, kinematic_property, request
):
"""Test transformation between Cartesian and polar coordinates
with various kinematic properties.
"""
ds = request.getfixturevalue(ds)
with expected_exception:
data = getattr(ds.move, f"compute_{kinematic_property}")()
pol_data = vector.cart2pol(data)
cart_data = vector.pol2cart(pol_data)
xr.testing.assert_allclose(cart_data, data)
47 changes: 47 additions & 0 deletions tests/test_unit/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,53 @@ def _assert_n_nans_in_position_per_individual(
)


@pytest.mark.parametrize(
"valid_dataset_with_nan",
list_valid_datasets_with_nans,
)
@pytest.mark.parametrize(
"window",
[3, 5, 6, 10], # data is nframes = 10
)
@pytest.mark.parametrize(
"filter_func",
[median_filter, savgol_filter],
)
def test_filter_with_nans_on_position_varying_window(
valid_dataset_with_nan, window, filter_func, helpers, request
):
"""Test that the number of NaNs in the filtered position data
increases at most by the filter's window length minus one
multiplied by the number of consecutive NaNs in the input data.
"""
# Prepare kwargs per filter
kwargs = {"window": window}
if filter_func == savgol_filter:
kwargs["polyorder"] = 2

# Filter position
valid_input_dataset = request.getfixturevalue(valid_dataset_with_nan)
position_filtered = filter_func(
valid_input_dataset.position,
**kwargs,
)

# Count number of NaNs in the input and filtered position data
n_total_nans_initial = helpers.count_nans(valid_input_dataset.position)
n_consecutive_nans_initial = helpers.count_consecutive_nans(
valid_input_dataset.position
)

n_total_nans_filtered = helpers.count_nans(position_filtered)

max_nans_increase = (window - 1) * n_consecutive_nans_initial

# Check that filtering does not reduce number of nans
assert n_total_nans_filtered >= n_total_nans_initial
# Check that the increase in nans is below the expected threshold
assert n_total_nans_filtered - n_total_nans_initial <= max_nans_increase


@pytest.mark.parametrize(
"valid_dataset",
list_all_valid_datasets,
Expand Down

0 comments on commit a42838d

Please sign in to comment.