diff --git a/docs/source/conf.py b/docs/source/conf.py index 9b051fb0..fda3e86f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -174,6 +174,7 @@ # A list of regular expressions that match URIs that should not be checked linkcheck_ignore = [ "https://pubs.acs.org/doi/*", # Checking dois is forbidden here + "https://opensource.org/license/bsd-3-clause/", # to avoid odd 403 error ] myst_url_schemes = { diff --git a/tests/test_integration/test_filtering.py b/tests/test_integration/test_filtering.py index 90efce6c..cba430f0 100644 --- a/tests/test_integration/test_filtering.py +++ b/tests/test_integration/test_filtering.py @@ -9,70 +9,56 @@ @pytest.fixture def sample_dataset(): - """Return a single-animal sample dataset, with time unit in frames. - This allows us to better control the expected number of NaNs in the tests. - """ + """Return a single-animal sample dataset, with time unit in frames.""" ds_path = fetch_dataset_paths("DLC_single-mouse_EPM.predictions.h5")[ "poses" ] ds = load_poses.from_dlc_file(ds_path) - ds["velocity"] = ds.move.compute_velocity() + return ds @pytest.mark.parametrize("window", [3, 5, 6, 13]) def test_nan_propagation_through_filters(sample_dataset, window, helpers): - """Test NaN propagation when passing a DataArray through - multiple filters sequentially. For the ``median_filter`` - and ``savgol_filter``, the number of NaNs is expected to increase + """Test NaN propagation is as expected when passing a DataArray through + filter by confidence, Savgol filter and interpolation. + For the ``savgol_filter``, the number of NaNs is expected to increase at most by the filter's window length minus one (``window - 1``) multiplied by the number of consecutive NaNs in the input data. """ - # Introduce nans via filter_by_confidence + # Compute number of low confidence keypoints + n_low_confidence_kpts = (sample_dataset.confidence.data < 0.6).sum() + + # Check filter position by confidence creates correct number of NaNs sample_dataset.update( {"position": sample_dataset.move.filter_by_confidence()} ) - expected_n_nans = 13136 - n_nans_confilt = helpers.count_nans(sample_dataset.position) - assert n_nans_confilt == expected_n_nans, ( - f"Expected {expected_n_nans} NaNs in filtered data, " - f"got: {n_nans_confilt}" - ) - n_consecutive_nans = helpers.count_consecutive_nans( - sample_dataset.position - ) - # Apply median filter and check that - # it doesn't introduce too many or too few NaNs - sample_dataset.update( - {"position": sample_dataset.move.median_filter(window)} - ) - n_nans_medfilt = helpers.count_nans(sample_dataset.position) - max_nans_increase = (window - 1) * n_consecutive_nans - assert ( - n_nans_medfilt <= n_nans_confilt + max_nans_increase - ), "Median filter introduced more NaNs than expected." + n_total_nans_input = helpers.count_nans(sample_dataset.position) + assert ( - n_nans_medfilt >= n_nans_confilt - ), "Median filter mysteriously removed NaNs." - n_consecutive_nans = helpers.count_consecutive_nans( + n_total_nans_input + == n_low_confidence_kpts * sample_dataset.dims["space"] + ) + + # Compute maximum expected increase in NaNs due to filtering + n_consecutive_nans_input = helpers.count_consecutive_nans( sample_dataset.position ) + max_nans_increase = (window - 1) * n_consecutive_nans_input - # Apply savgol filter and check that - # it doesn't introduce too many or too few NaNs + # Apply savgol filter and check that number of NaNs is within threshold sample_dataset.update( {"position": sample_dataset.move.savgol_filter(window, polyorder=2)} ) - n_nans_savgol = helpers.count_nans(sample_dataset.position) - max_nans_increase = (window - 1) * n_consecutive_nans - assert ( - n_nans_savgol <= n_nans_medfilt + max_nans_increase - ), "Savgol filter introduced more NaNs than expected." - assert ( - n_nans_savgol >= n_nans_medfilt - ), "Savgol filter mysteriously removed NaNs." - # Interpolate data (without max_gap) to eliminate all NaNs + n_total_nans_savgol = helpers.count_nans(sample_dataset.position) + + # Check that filtering does not reduce number of nans + assert n_total_nans_savgol >= n_total_nans_input + # Check that the increase in nans is below the expected threshold + assert n_total_nans_savgol - n_total_nans_input <= max_nans_increase + + # Interpolate data (without max_gap) and check it eliminates all NaNs sample_dataset.update( {"position": sample_dataset.move.interpolate_over_time()} ) @@ -105,6 +91,9 @@ def test_accessor_filter_method( applied, if valid data variables are passed, otherwise raise an exception. """ + # Compute velocity + sample_dataset["velocity"] = sample_dataset.move.compute_velocity() + with expected_exception as expected_type: if method in ["median_filter", "savgol_filter"]: # supply required "window" argument diff --git a/tests/test_integration/test_kinematics_vector_transform.py b/tests/test_integration/test_kinematics_vector_transform.py index 65318a08..63ecc2e4 100644 --- a/tests/test_integration/test_kinematics_vector_transform.py +++ b/tests/test_integration/test_kinematics_vector_transform.py @@ -1,33 +1,93 @@ -from contextlib import nullcontext as does_not_raise +import math +import numpy as np import pytest import xarray as xr from movement.utils import vector -class TestKinematicsVectorTransform: - """Test the vector transformation functionality with - various kinematic properties. +@pytest.mark.parametrize( + "valid_dataset_uniform_linear_motion", + [ + "valid_poses_dataset_uniform_linear_motion", + "valid_bboxes_dataset", + ], +) +@pytest.mark.parametrize( + "kinematic_variable, expected_kinematics_polar", + [ + ( + "displacement", + [ + np.vstack( + [ + np.zeros((1, 2)), + np.tile([math.sqrt(2), math.atan(1)], (9, 1)), + ], + ), # Individual 0, rho=sqrt(2), phi=45deg + np.vstack( + [ + np.zeros((1, 2)), + np.tile([math.sqrt(2), -math.atan(1)], (9, 1)), + ] + ), # Individual 1, rho=sqrt(2), phi=-45deg + ], + ), + ( + "velocity", + [ + np.tile( + [math.sqrt(2), math.atan(1)], (10, 1) + ), # Individual O, rho, phi=45deg + np.tile( + [math.sqrt(2), -math.atan(1)], (10, 1) + ), # Individual 1, rho, phi=-45deg + ], + ), + ( + "acceleration", + [ + np.zeros((10, 2)), # Individual 0 + np.zeros((10, 2)), # Individual 1 + ], + ), + ], +) +def test_cart2pol_transform_on_kinematics( + valid_dataset_uniform_linear_motion, + kinematic_variable, + expected_kinematics_polar, + request, +): + """Test transformation between Cartesian and polar coordinates + with various kinematic properties. """ + ds = request.getfixturevalue(valid_dataset_uniform_linear_motion) + kinematic_array_cart = getattr(ds.move, f"compute_{kinematic_variable}")() + kinematic_array_pol = vector.cart2pol(kinematic_array_cart) - @pytest.mark.parametrize( - "ds, expected_exception", - [ - ("valid_poses_dataset", does_not_raise()), - ("valid_poses_dataset_with_nan", does_not_raise()), - ("missing_dim_poses_dataset", pytest.raises(RuntimeError)), - ], + # Build expected data array + expected_array_pol = xr.DataArray( + np.stack(expected_kinematics_polar, axis=1), + # Stack along the "individuals" axis + dims=["time", "individuals", "space"], + ) + if "keypoints" in ds.position.coords: + expected_array_pol = expected_array_pol.expand_dims( + {"keypoints": ds.position.coords["keypoints"].size} + ) + expected_array_pol = expected_array_pol.transpose( + "time", "individuals", "keypoints", "space" + ) + + # Compare the values of the kinematic_array against the expected_array + np.testing.assert_allclose( + kinematic_array_pol.values, expected_array_pol.values + ) + + # Check we can recover the original Cartesian array + kinematic_array_cart_recover = vector.pol2cart(kinematic_array_pol) + xr.testing.assert_allclose( + kinematic_array_cart, kinematic_array_cart_recover ) - def test_cart_and_pol_transform( - self, ds, expected_exception, kinematic_property, request - ): - """Test transformation between Cartesian and polar coordinates - with various kinematic properties. - """ - ds = request.getfixturevalue(ds) - with expected_exception: - data = getattr(ds.move, f"compute_{kinematic_property}")() - pol_data = vector.cart2pol(data) - cart_data = vector.pol2cart(pol_data) - xr.testing.assert_allclose(cart_data, data) diff --git a/tests/test_unit/test_filtering.py b/tests/test_unit/test_filtering.py index 4b400287..d51af1be 100644 --- a/tests/test_unit/test_filtering.py +++ b/tests/test_unit/test_filtering.py @@ -233,6 +233,53 @@ def _assert_n_nans_in_position_per_individual( ) +@pytest.mark.parametrize( + "valid_dataset_with_nan", + list_valid_datasets_with_nans, +) +@pytest.mark.parametrize( + "window", + [3, 5, 6, 10], # data is nframes = 10 +) +@pytest.mark.parametrize( + "filter_func", + [median_filter, savgol_filter], +) +def test_filter_with_nans_on_position_varying_window( + valid_dataset_with_nan, window, filter_func, helpers, request +): + """Test that the number of NaNs in the filtered position data + increases at most by the filter's window length minus one + multiplied by the number of consecutive NaNs in the input data. + """ + # Prepare kwargs per filter + kwargs = {"window": window} + if filter_func == savgol_filter: + kwargs["polyorder"] = 2 + + # Filter position + valid_input_dataset = request.getfixturevalue(valid_dataset_with_nan) + position_filtered = filter_func( + valid_input_dataset.position, + **kwargs, + ) + + # Count number of NaNs in the input and filtered position data + n_total_nans_initial = helpers.count_nans(valid_input_dataset.position) + n_consecutive_nans_initial = helpers.count_consecutive_nans( + valid_input_dataset.position + ) + + n_total_nans_filtered = helpers.count_nans(position_filtered) + + max_nans_increase = (window - 1) * n_consecutive_nans_initial + + # Check that filtering does not reduce number of nans + assert n_total_nans_filtered >= n_total_nans_initial + # Check that the increase in nans is below the expected threshold + assert n_total_nans_filtered - n_total_nans_initial <= max_nans_increase + + @pytest.mark.parametrize( "valid_dataset", list_all_valid_datasets,