Update integration tests #295

Merged 17 commits on Oct 17, 2024
Changes from all commits
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -174,6 +174,7 @@
# A list of regular expressions that match URIs that should not be checked
linkcheck_ignore = [
    "https://pubs.acs.org/doi/*",  # Checking dois is forbidden here
    "https://opensource.org/license/bsd-3-clause/",  # to avoid odd 403 error
]

myst_url_schemes = {
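For context, Sphinx compiles each linkcheck_ignore entry as a Python regular expression and matches it against the start of each URI, so the entries above act as prefix patterns. A standalone sketch (not part of this diff; the DOI URL is made up):

import re

linkcheck_ignore = [
    "https://pubs.acs.org/doi/*",
    "https://opensource.org/license/bsd-3-clause/",
]

# Any link under /doi/ is skipped by the linkcheck builder
# (re.match anchors at the start of the string, as Sphinx does).
uri = "https://pubs.acs.org/doi/10.1021/example"
assert any(re.match(pattern, uri) for pattern in linkcheck_ignore)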
71 changes: 30 additions & 41 deletions tests/test_integration/test_filtering.py
@@ -9,70 +9,56 @@

@pytest.fixture
def sample_dataset():
    """Return a single-animal sample dataset, with time unit in frames."""
    ds_path = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")[
        "poses"
    ]
    ds = load_poses.from_dlc_file(ds_path)
    return ds
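For orientation, the fixture returns a dataset in movement's poses layout; a rough sketch of what that looks like (the dimension order is taken from the tests below, while the time_unit attribute is an assumption):

# Hypothetical inspection of the fixture's return value
ds = load_poses.from_dlc_file(ds_path)
assert ds.position.dims == ("time", "individuals", "keypoints", "space")
assert ds.attrs["time_unit"] == "frames"  # no fps passed to from_dlc_file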


@pytest.mark.parametrize("window", [3, 5, 6, 13])
def test_nan_propagation_through_filters(sample_dataset, window, helpers):
"""Test NaN propagation when passing a DataArray through
multiple filters sequentially. For the ``median_filter``
and ``savgol_filter``, the number of NaNs is expected to increase
"""Test NaN propagation is as expected when passing a DataArray through
filter by confidence, Savgol filter and interpolation.
For the ``savgol_filter``, the number of NaNs is expected to increase
at most by the filter's window length minus one (``window - 1``)
multiplied by the number of consecutive NaNs in the input data.
"""
    # Compute number of low-confidence keypoints
    n_low_confidence_kpts = (sample_dataset.confidence.data < 0.6).sum()

    # Check that filtering position by confidence creates
    # the expected number of NaNs
    sample_dataset.update(
        {"position": sample_dataset.move.filter_by_confidence()}
    )
    n_total_nans_input = helpers.count_nans(sample_dataset.position)

    assert (
        n_total_nans_input
        == n_low_confidence_kpts * sample_dataset.dims["space"]
    )

    # Compute maximum expected increase in NaNs due to filtering
    n_consecutive_nans_input = helpers.count_consecutive_nans(
        sample_dataset.position
    )
    max_nans_increase = (window - 1) * n_consecutive_nans_input

    # Apply savgol filter and check that number of NaNs is within threshold
    sample_dataset.update(
        {"position": sample_dataset.move.savgol_filter(window, polyorder=2)}
    )
    n_total_nans_savgol = helpers.count_nans(sample_dataset.position)

    # Check that filtering does not reduce the number of NaNs
    assert n_total_nans_savgol >= n_total_nans_input
    # Check that the increase in NaNs is below the expected threshold
    assert n_total_nans_savgol - n_total_nans_input <= max_nans_increase

    # Interpolate data (without max_gap) and check it eliminates all NaNs
    sample_dataset.update(
        {"position": sample_dataset.move.interpolate_over_time()}
    )
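To see where the (window - 1) bound used above comes from, here is a minimal standalone sketch (not part of this diff) of a centered moving median that, like the filters under test, returns NaN whenever its window touches a NaN:

import numpy as np

def nan_spreading_median(x, window):
    """Centered moving median; NaN wherever the window contains a NaN."""
    half = window // 2
    out = np.full_like(x, np.nan)
    for i in range(half, len(x) - half):
        # np.median propagates NaN, so a gap spreads to its neighbours
        out[i] = np.median(x[i - half : i + half + 1])
    return out

x = np.array([1.0, 2.0, np.nan, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
y = nan_spreading_median(x, window=3)
# One NaN run of length 1 grows to at most 1 + (window - 1) = 3 NaNs
print(np.isnan(y[1:-1]).sum())  # -> 3

With n_consecutive_nans such runs in the input, the total increase is bounded by (window - 1) * n_consecutive_nans, which is exactly what the assertions above check.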
@@ -105,6 +91,9 @@ def test_accessor_filter_method(
    applied, if valid data variables are passed, otherwise
    raise an exception.
    """
    # Compute velocity
    sample_dataset["velocity"] = sample_dataset.move.compute_velocity()

    with expected_exception as expected_type:
        if method in ["median_filter", "savgol_filter"]:
            # supply required "window" argument
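For reference, the accessor calls exercised by these tests look roughly like this (a sketch; the import paths and the default 0.6 confidence threshold are assumptions based on this file):

from movement.io import load_poses
from movement.sample_data import fetch_dataset_paths

ds_path = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")["poses"]
ds = load_poses.from_dlc_file(ds_path)
ds.update({"position": ds.move.filter_by_confidence()})   # NaN out low-confidence points
ds.update({"position": ds.move.median_filter(3)})         # rolling filters need "window"
ds.update({"position": ds.move.savgol_filter(3, polyorder=2)})
ds.update({"position": ds.move.interpolate_over_time()})  # no max_gap: fill every gap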
106 changes: 83 additions & 23 deletions tests/test_integration/test_kinematics_vector_transform.py
@@ -1,33 +1,93 @@
import math

import numpy as np
import pytest
import xarray as xr

from movement.utils import vector

@pytest.mark.parametrize(
    "valid_dataset_uniform_linear_motion",
    [
        "valid_poses_dataset_uniform_linear_motion",
        "valid_bboxes_dataset",
    ],
)
@pytest.mark.parametrize(
    "kinematic_variable, expected_kinematics_polar",
    [
        (
            "displacement",
            [
                np.vstack(
                    [
                        np.zeros((1, 2)),
                        np.tile([math.sqrt(2), math.atan(1)], (9, 1)),
                    ],
                ),  # Individual 0: rho=sqrt(2), phi=45deg
                np.vstack(
                    [
                        np.zeros((1, 2)),
                        np.tile([math.sqrt(2), -math.atan(1)], (9, 1)),
                    ]
                ),  # Individual 1: rho=sqrt(2), phi=-45deg
            ],
        ),
        (
            "velocity",
            [
                np.tile(
                    [math.sqrt(2), math.atan(1)], (10, 1)
                ),  # Individual 0: rho=sqrt(2), phi=45deg
                np.tile(
                    [math.sqrt(2), -math.atan(1)], (10, 1)
                ),  # Individual 1: rho=sqrt(2), phi=-45deg
            ],
        ),
        (
            "acceleration",
            [
                np.zeros((10, 2)),  # Individual 0
                np.zeros((10, 2)),  # Individual 1
            ],
        ),
    ],
)
def test_cart2pol_transform_on_kinematics(
    valid_dataset_uniform_linear_motion,
    kinematic_variable,
    expected_kinematics_polar,
    request,
):
    """Test transformation between Cartesian and polar coordinates
    with various kinematic properties.
    """
    ds = request.getfixturevalue(valid_dataset_uniform_linear_motion)
    kinematic_array_cart = getattr(ds.move, f"compute_{kinematic_variable}")()
    kinematic_array_pol = vector.cart2pol(kinematic_array_cart)

    # Build expected data array
    expected_array_pol = xr.DataArray(
        np.stack(expected_kinematics_polar, axis=1),
        # Stack along the "individuals" axis
        dims=["time", "individuals", "space"],
    )
    if "keypoints" in ds.position.coords:
        expected_array_pol = expected_array_pol.expand_dims(
            {"keypoints": ds.position.coords["keypoints"].size}
        )
        expected_array_pol = expected_array_pol.transpose(
            "time", "individuals", "keypoints", "space"
        )

    # Compare the values of the kinematic_array against the expected_array
    np.testing.assert_allclose(
        kinematic_array_pol.values, expected_array_pol.values
    )

    # Check we can recover the original Cartesian array
    kinematic_array_cart_recover = vector.pol2cart(kinematic_array_pol)
    xr.testing.assert_allclose(
        kinematic_array_cart, kinematic_array_cart_recover
    )
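As a standalone illustration (not part of this diff) of the values asserted above: a displacement of (1, 1) per frame has rho = sqrt(2) and phi = atan(1) = 45 degrees. The sketch assumes cart2pol accepts any DataArray with a two-element "space" dimension labelled x/y, as the tests imply:

import math

import numpy as np
import xarray as xr

from movement.utils import vector

cart = xr.DataArray(
    np.array([[1.0, 1.0]]),  # one step of (dx, dy) = (1, 1)
    dims=["time", "space"],
    coords={"space": ["x", "y"]},
)
pol = vector.cart2pol(cart)
np.testing.assert_allclose(pol.values, [[math.sqrt(2), math.atan(1)]])

# pol2cart should round-trip back to the Cartesian input
xr.testing.assert_allclose(vector.pol2cart(pol), cart)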
47 changes: 47 additions & 0 deletions tests/test_unit/test_filtering.py
@@ -233,6 +233,53 @@ def _assert_n_nans_in_position_per_individual(
)


@pytest.mark.parametrize(
    "valid_dataset_with_nan",
    list_valid_datasets_with_nans,
)
@pytest.mark.parametrize(
    "window",
    [3, 5, 6, 10],  # data is nframes = 10
)
@pytest.mark.parametrize(
    "filter_func",
    [median_filter, savgol_filter],
)
def test_filter_with_nans_on_position_varying_window(
    valid_dataset_with_nan, window, filter_func, helpers, request
):
    """Test that the number of NaNs in the filtered position data
    increases at most by the filter's window length minus one
    multiplied by the number of consecutive NaNs in the input data.
    """
    # Prepare kwargs per filter
    kwargs = {"window": window}
    if filter_func == savgol_filter:
        kwargs["polyorder"] = 2

    # Filter position
    valid_input_dataset = request.getfixturevalue(valid_dataset_with_nan)
    position_filtered = filter_func(
        valid_input_dataset.position,
        **kwargs,
    )

    # Count number of NaNs in the input and filtered position data
    n_total_nans_initial = helpers.count_nans(valid_input_dataset.position)
    n_consecutive_nans_initial = helpers.count_consecutive_nans(
        valid_input_dataset.position
    )
    n_total_nans_filtered = helpers.count_nans(position_filtered)

    max_nans_increase = (window - 1) * n_consecutive_nans_initial

    # Check that filtering does not reduce the number of NaNs
    assert n_total_nans_filtered >= n_total_nans_initial
    # Check that the increase in NaNs is below the expected threshold
    assert n_total_nans_filtered - n_total_nans_initial <= max_nans_increase
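Note that the bound counts NaN runs, not individual NaNs. A hypothetical sketch of what a helper like helpers.count_consecutive_nans computes (the real helper lives in the test suite's conftest):

import numpy as np

def count_nan_runs(isnan):
    """Count groups of consecutive True values in a 1-D boolean mask."""
    padded = np.concatenate(([False], isnan))
    # a run starts wherever a NaN follows a non-NaN
    return int((padded[1:] & ~padded[:-1]).sum())

mask = np.isnan(np.array([1.0, np.nan, np.nan, 4.0, np.nan, 6.0]))
print(count_nan_runs(mask))  # -> 2 runs (of lengths 2 and 1)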


@pytest.mark.parametrize(
"valid_dataset",
list_all_valid_datasets,