diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py index e8d8346fa..1d2731942 100644 --- a/movement/analysis/kinematics.py +++ b/movement/analysis/kinematics.py @@ -220,14 +220,8 @@ def cdist( core_dim = "individuals" if dim == "keypoints" else "keypoints" elem1 = getattr(a, dim).item() elem2 = getattr(b, dim).item() - if a.coords.get(core_dim) is None: - a = a.assign_coords({core_dim: "temp"}) - if b.coords.get(core_dim) is None: - b = b.assign_coords({core_dim: "temp"}) - if a.coords[core_dim].ndim == 0: - a = a.expand_dims(core_dim).transpose("time", "space", core_dim) - if b.coords[core_dim].ndim == 0: - b = b.expand_dims(core_dim).transpose("time", "space", core_dim) + a = _validate_core_dimension(a, core_dim) + b = _validate_core_dimension(b, core_dim) result = xr.apply_ufunc( _cdist, a, @@ -610,6 +604,35 @@ def _compute_pairwise_distances( return pairwise_distances +def _validate_core_dimension( + data: xr.DataArray, core_dim: str +) -> xr.DataArray: + """Validate the input data contains the required core dimension. + + This function ensures the input data contains the ``core_dim`` + required when applying :func:`scipy.spatial.distance.cdist` to + the input data, by adding a temporary dimension if necessary. + + Parameters + ---------- + data : xarray.DataArray + The input data to validate. + core_dim : str + The core dimension to validate. + + Returns + ------- + xarray.DataArray + The input data with the core dimension validated. + + """ + if data.coords.get(core_dim) is None: + data = data.assign_coords({core_dim: "temp_dim"}) + if data.coords[core_dim].ndim == 0: + data = data.expand_dims(core_dim).transpose("time", "space", core_dim) + return data + + def _validate_time_dimension(data: xr.DataArray) -> None: """Validate the input data contains a ``time`` dimension. diff --git a/tests/conftest.py b/tests/conftest.py index 0d1acc514..ba66cc585 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -574,40 +574,6 @@ def kinematic_property(request): return request.param -@pytest.fixture -def pairwise_distances_dataset(): - """Return a minimal poses dataset with 3 individuals - and 3 keypoints for pairwise distances computation. - """ - time = np.arange(2) - space = ["x", "y"] - individuals = ["ind1", "ind2", "ind3"] - keypoints = ["key1", "key2", "key3"] - data = np.array( - [ - [ - [[1, 1], [0, 0], [1, 0]], - [[1, 0], [1, 1], [0, 0]], - [[0, 0], [1, 0], [1, 1]], - ], - [ - [[3, 6], [1, 4], [0, 4]], - [[0, 4], [3, 6], [1, 4]], - [[1, 4], [0, 4], [3, 6]], - ], - ] - ) - return xr.Dataset( - data_vars={ - "position": xr.DataArray( - data, - coords=[time, individuals, keypoints, space], - dims=["time", "individuals", "keypoints", "space"], - ) - } - ) - - # ---------------- VIA tracks CSV file fixtures ---------------------------- @pytest.fixture def via_tracks_csv_with_invalid_header(tmp_path): diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index e3e1d4b2c..2b3db4b27 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -1,5 +1,4 @@ import itertools -from contextlib import nullcontext as does_not_raise import numpy as np import pytest @@ -188,52 +187,42 @@ def test_approximate_derivative_with_invalid_order(order): @pytest.mark.parametrize( - "dim, pairs, expected_data", + "dim, expected_data", [ ( "individuals", - ("ind1", "ind2"), np.array( [ [ - [1.0, 0.0, np.sqrt(2)], - [1.0, np.sqrt(2), 0.0], [0.0, 1.0, 1.0], + [1.0, np.sqrt(2), 0.0], + [1.0, 2.0, np.sqrt(2)], ], [ - [np.sqrt(13), 0.0, np.sqrt(8)], - [1.0, np.sqrt(8), 0.0], - [0.0, np.sqrt(13), 1.0], + [2.0, np.sqrt(5), 1.0], + [3.0, np.sqrt(10), 2.0], + [np.sqrt(5), np.sqrt(8), np.sqrt(2)], ], ] ), ), ( "keypoints", - ("key1", "key2"), np.array( - [ - [ - [np.sqrt(2), 0.0, 1.0], - [1.0, 1.0, 0.0], - [0.0, np.sqrt(2), 1.0], - ], - [ - [np.sqrt(8), 0.0, np.sqrt(13)], - [1.0, np.sqrt(13), 0.0], - [0.0, np.sqrt(8), 1.0], - ], - ] + [[[1.0, 1.0], [1.0, 1.0]], [[1.0, np.sqrt(5)], [3.0, 1.0]]] ), ), ], ) def test_cdist_with_known_values( - dim, pairs, expected_data, pairwise_distances_dataset + dim, expected_data, valid_poses_dataset_uniform_linear_motion ): """Test the computation of pairwise distances with known values.""" core_dim = "keypoints" if dim == "individuals" else "individuals" - input_dataarray = pairwise_distances_dataset.position + input_dataarray = valid_poses_dataset_uniform_linear_motion.position.sel( + time=slice(0, 1) + ) # Use only the first two frames for simplicity + pairs = input_dataarray[dim].values[:2] expected = xr.DataArray( expected_data, coords=[ @@ -252,62 +241,76 @@ def test_cdist_with_known_values( ) +@pytest.mark.parametrize( + "valid_dataset", + [ + "valid_poses_dataset_uniform_linear_motion", + "valid_bboxes_dataset", + ], +) @pytest.mark.parametrize( "selection_fn", [ + # individuals dim is scalar, + # poses: multiple keypoints + # bboxes: missing keypoints dim + # e.g. comparing 2 individuals from the same data array lambda position: ( - position.sel(individuals="ind1"), - position.sel(individuals="ind2"), - ), # individuals dim is scalar + position.isel(individuals=0), + position.isel(individuals=1), + ), + # individuals dim is 1D + # poses: multiple keypoints + # bboxes: missing keypoints dim + # e.g. comparing 2 single-individual data arrays lambda position: ( position.where( - position.individuals == "ind1", drop=True + position.individuals == position.individuals[0], drop=True ).squeeze(), position.where( - position.individuals == "ind2", drop=True + position.individuals == position.individuals[1], drop=True ).squeeze(), - ), # individuals dim is 1D - lambda position: ( - position.sel(individuals="ind1", keypoints="key1"), - position.sel(individuals="ind2", keypoints="key1"), - ), # both individuals and keypoints dims are scalar + ), + # both individuals and keypoints dims are scalar (poses only) + # e.g. comparing 2 individuals from the same data array, + # at the same keypoint lambda position: ( - position.where(position.keypoints == "key1", drop=True).sel( - individuals="ind1" - ), - position.where(position.keypoints == "key1", drop=True).sel( - individuals="ind2" - ), - ), # keypoints dim is 1D + position.isel(individuals=0, keypoints=0), + position.isel(individuals=1, keypoints=0), + ), + # individuals dim is scalar, keypoints dim is 1D (poses only) + # e.g. comparing 2 single-individual, single-keypoint data arrays lambda position: ( - position.drop_sel(keypoints=position.keypoints.values[1:]) - .squeeze(drop=True) - .sel(individuals="ind1"), - position.drop_sel(keypoints=position.keypoints.values[1:]) - .squeeze(drop=True) - .sel(individuals="ind2"), - ), # missing core dim + position.where( + position.keypoints == position.keypoints[0], drop=True + ).isel(individuals=0), + position.where( + position.keypoints == position.keypoints[0], drop=True + ).isel(individuals=1), + ), ], ids=[ "dim_has_ndim_0", "dim_has_ndim_1", "core_dim_has_ndim_0", "core_dim_has_ndim_1", - "missing_core_dim", ], ) -def test_cdist_with_single_dim_inputs( - pairwise_distances_dataset, selection_fn -): - """Test that the computation of pairwise distances - works regardless of whether the input DataArrays have - ```dim``` and ```core_dim``` being either scalar (ndim=0) - or 1D (ndim=1), or if ``core_dim`` is missing. +def test_cdist_with_single_dim_inputs(valid_dataset, selection_fn, request): + """Test that the pairwise distances data array is successfully + returned regardless of whether the input DataArrays have + ``dim`` ("individuals") and ``core_dim`` ("keypoints") + being either scalar (ndim=0) or 1D (ndim=1), + or if ``core_dim`` is missing. """ - position = pairwise_distances_dataset.position - a, b = selection_fn(position) - with does_not_raise(): - kinematics.cdist(a, b, "individuals") + if request.node.callspec.id not in [ + "core_dim_has_ndim_0-valid_bboxes_dataset", + "core_dim_has_ndim_1-valid_bboxes_dataset", + ]: # Skip tests with keypoints dim for bboxes + valid_dataset = request.getfixturevalue(valid_dataset) + position = valid_dataset.position + a, b = selection_fn(position) + assert isinstance(kinematics.cdist(a, b, "individuals"), xr.DataArray) def expected_pairwise_distances(pairs, input_ds, dim): @@ -335,27 +338,27 @@ def expected_pairwise_distances(pairs, input_ds, dim): @pytest.mark.parametrize( "dim, pairs", [ - ("individuals", {"ind1": ["ind2"]}), # list input - ("individuals", {"ind1": "ind2"}), # string input - ("individuals", {"ind1": ["ind2", "ind3"], "ind2": "ind3"}), + ("individuals", {"id_1": ["id_2"]}), # list input + ("individuals", {"id_1": "id_2"}), # string input + ("individuals", {"id_1": ["id_2"], "id_2": "id_1"}), ("individuals", None), # all pairs - ("keypoints", {"key1": ["key2"]}), # list input - ("keypoints", {"key1": "key2"}), # string input - ("keypoints", {"key1": ["key2", "key3"], "key2": "key3"}), + ("keypoints", {"centroid": ["left"]}), # list input + ("keypoints", {"centroid": "left"}), # string input + ("keypoints", {"centroid": ["left"], "left": "right"}), ("keypoints", None), # all pairs ], ) def test_compute_pairwise_distances_with_valid_pairs( - pairwise_distances_dataset, dim, pairs + valid_poses_dataset_uniform_linear_motion, dim, pairs ): """Test that the expected pairwise distances are computed for valid ``pairs`` inputs. """ result = getattr(kinematics, f"compute_inter{dim[:-1]}_distances")( - pairwise_distances_dataset.position, pairs=pairs + valid_poses_dataset_uniform_linear_motion.position, pairs=pairs ) expected_data_vars = expected_pairwise_distances( - pairs, pairwise_distances_dataset, dim + pairs, valid_poses_dataset_uniform_linear_motion, dim ) if isinstance(result, dict): assert set(result.keys()) == set(expected_data_vars) @@ -364,10 +367,10 @@ def test_compute_pairwise_distances_with_valid_pairs( def test_compute_pairwise_distances_with_invalid_dim( - pairwise_distances_dataset, + valid_poses_dataset_uniform_linear_motion, ): """Test that an error is raised when an invalid dimension is passed.""" with pytest.raises(ValueError): kinematics._compute_pairwise_distances( - pairwise_distances_dataset.position, "invalid_dim" + valid_poses_dataset_uniform_linear_motion.position, "invalid_dim" )