diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index c2db89c27..5914dfd24 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -185,126 +185,128 @@ def test_approximate_derivative_with_invalid_order(order): with pytest.raises(expected_exception): kinematics._compute_approximate_time_derivative(data, order=order) - @pytest.mark.parametrize( - "dim, pairs, expected_data", - [ - ( - "individuals", - ("ind1", "ind2"), - np.array( + +@pytest.mark.parametrize( + "dim, pairs, expected_data", + [ + ( + "individuals", + ("ind1", "ind2"), + np.array( + [ + [ + [1.0, 0.0, np.sqrt(2)], + [1.0, np.sqrt(2), 0.0], + [0.0, 1.0, 1.0], + ], [ - [ - [1.0, 0.0, np.sqrt(2)], - [1.0, np.sqrt(2), 0.0], - [0.0, 1.0, 1.0], - ], - [ - [np.sqrt(13), 0.0, np.sqrt(8)], - [1.0, np.sqrt(8), 0.0], - [0.0, np.sqrt(13), 1.0], - ], - ] - ), + [np.sqrt(13), 0.0, np.sqrt(8)], + [1.0, np.sqrt(8), 0.0], + [0.0, np.sqrt(13), 1.0], + ], + ] ), - ( - "keypoints", - ("key1", "key2"), - np.array( + ), + ( + "keypoints", + ("key1", "key2"), + np.array( + [ [ - [ - [np.sqrt(2), 0.0, 1.0], - [1.0, 1.0, 0.0], - [0.0, np.sqrt(2), 1.0], - ], - [ - [np.sqrt(8), 0.0, np.sqrt(13)], - [1.0, np.sqrt(13), 0.0], - [0.0, np.sqrt(8), 1.0], - ], - ] - ), + [np.sqrt(2), 0.0, 1.0], + [1.0, 1.0, 0.0], + [0.0, np.sqrt(2), 1.0], + ], + [ + [np.sqrt(8), 0.0, np.sqrt(13)], + [1.0, np.sqrt(13), 0.0], + [0.0, np.sqrt(8), 1.0], + ], + ] ), + ), + ], +) +def test_cdist(dim, pairs, expected_data, pairwise_distances_dataset): + """Test the computation of pairwise distances with known values.""" + core_dim = "keypoints" if dim == "individuals" else "individuals" + input_dataarray = pairwise_distances_dataset.position + expected = xr.DataArray( + expected_data, + coords=[ + input_dataarray.time.values, + getattr(input_dataarray, core_dim).values, + getattr(input_dataarray, core_dim).values, ], + dims=["time", pairs[0], pairs[1]], + ) + a = input_dataarray.sel({dim: pairs[0]}) + b = input_dataarray.sel({dim: pairs[1]}) + result = kinematics.cdist(a, b, dim) + xr.testing.assert_equal( + result, + expected, ) - def test_cdist( - self, dim, pairs, expected_data, pairwise_distances_dataset - ): - """Test the computation of pairwise distances with known values.""" - core_dim = "keypoints" if dim == "individuals" else "individuals" - input_dataarray = pairwise_distances_dataset.position - expected = xr.DataArray( - expected_data, - coords=[ - input_dataarray.time.values, - getattr(input_dataarray, core_dim).values, - getattr(input_dataarray, core_dim).values, - ], - dims=["time", pairs[0], pairs[1]], - ) - a = input_dataarray.sel({dim: pairs[0]}) - b = input_dataarray.sel({dim: pairs[1]}) - result = kinematics.cdist(a, b, dim) - xr.testing.assert_equal( - result, - expected, - ) - def expected_pairwise_distances(self, pairs, input_ds, dim): - """Return a list of the expected data variable names - for pairwise distances tests. - """ - if pairs is None: - paired_elements = list( - itertools.combinations(getattr(input_ds, dim).values, 2) + +def expected_pairwise_distances(pairs, input_ds, dim): + """Return a list of the expected data variable names + for pairwise distances tests. + """ + if pairs is None: + paired_elements = list( + itertools.combinations(getattr(input_ds, dim).values, 2) + ) + else: + paired_elements = [ + (elem1, elem2) + for elem1, elem2_list in pairs.items() + for elem2 in ( + [elem2_list] if isinstance(elem2_list, str) else elem2_list ) - else: - paired_elements = [ - (elem1, elem2) - for elem1, elem2_list in pairs.items() - for elem2 in ( - [elem2_list] if isinstance(elem2_list, str) else elem2_list - ) - ] - expected_data = [ - f"dist_{elem1}_{elem2}" for elem1, elem2 in paired_elements ] - return expected_data + expected_data = [ + f"dist_{elem1}_{elem2}" for elem1, elem2 in paired_elements + ] + return expected_data - @pytest.mark.parametrize( - "dim, pairs", - [ - ("individuals", {"ind1": ["ind2"]}), # list input - ("individuals", {"ind1": "ind2"}), # string input - ("individuals", {"ind1": ["ind2", "ind3"], "ind2": "ind3"}), - ("individuals", None), # all pairs - ("keypoints", {"key1": ["key2"]}), # list input - ("keypoints", {"key1": "key2"}), # string input - ("keypoints", {"key1": ["key2", "key3"], "key2": "key3"}), - ("keypoints", None), # all pairs - ], + +@pytest.mark.parametrize( + "dim, pairs", + [ + ("individuals", {"ind1": ["ind2"]}), # list input + ("individuals", {"ind1": "ind2"}), # string input + ("individuals", {"ind1": ["ind2", "ind3"], "ind2": "ind3"}), + ("individuals", None), # all pairs + ("keypoints", {"key1": ["key2"]}), # list input + ("keypoints", {"key1": "key2"}), # string input + ("keypoints", {"key1": ["key2", "key3"], "key2": "key3"}), + ("keypoints", None), # all pairs + ], +) +def test_compute_pairwise_distances_with_valid_pairs( + pairwise_distances_dataset, dim, pairs +): + """Test that the expected pairwise distances are computed + for valid ``pairs`` inputs. + """ + result = getattr(kinematics, f"compute_inter{dim[:-1]}_distances")( + pairwise_distances_dataset.position, pairs=pairs ) - def test_compute_pairwise_distances_with_valid_pairs( - self, pairwise_distances_dataset, dim, pairs - ): - """Test that the expected pairwise distances are computed - for valid ``pairs`` inputs. - """ - result = getattr(kinematics, f"compute_inter{dim[:-1]}_distances")( - pairwise_distances_dataset.position, pairs=pairs - ) - expected_data_vars = self.expected_pairwise_distances( - pairs, pairwise_distances_dataset, dim - ) - if isinstance(result, dict): - assert set(result.keys()) == set(expected_data_vars) - else: # expect single DataArray - assert isinstance(result, xr.DataArray) + expected_data_vars = expected_pairwise_distances( + pairs, pairwise_distances_dataset, dim + ) + if isinstance(result, dict): + assert set(result.keys()) == set(expected_data_vars) + else: # expect single DataArray + assert isinstance(result, xr.DataArray) - def test_compute_pairwise_distances_with_invalid_dim( - self, pairwise_distances_dataset - ): - """Test that an error is raised when an invalid dimension is passed.""" - with pytest.raises(ValueError): - kinematics._compute_pairwise_distances( - pairwise_distances_dataset.position, "invalid_dim" - ) + +def test_compute_pairwise_distances_with_invalid_dim( + pairwise_distances_dataset, +): + """Test that an error is raised when an invalid dimension is passed.""" + with pytest.raises(ValueError): + kinematics._compute_pairwise_distances( + pairwise_distances_dataset.position, "invalid_dim" + )