From 29a9f0a1ba72ef929f7d6d5dc6d2a6827b281bd7 Mon Sep 17 00:00:00 2001 From: lochhh Date: Fri, 25 Oct 2024 17:33:43 +0100 Subject: [PATCH] Spell out expected pairs in test --- tests/test_unit/test_kinematics.py | 61 ++++++++++++------------------ 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index 012b12fd8..250082df1 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -1,4 +1,3 @@ -import itertools import re import numpy as np @@ -479,43 +478,33 @@ def test_cdist_with_single_dim_inputs(valid_dataset, selection_fn, request): assert isinstance(kinematics._cdist(a, b, "individuals"), xr.DataArray) -def expected_pairwise_distances(pairs, input_ds, dim): - """Return a list of the expected data variable names - for pairwise distances tests. - """ - if pairs == "all": - paired_elements = list( - itertools.combinations(getattr(input_ds, dim).values, 2) - ) - else: - paired_elements = [ - (elem1, elem2) - for elem1, elem2_list in pairs.items() - for elem2 in ( - [elem2_list] if isinstance(elem2_list, str) else elem2_list - ) - ] - expected_data = [ - f"dist_{elem1}_{elem2}" for elem1, elem2 in paired_elements - ] - return expected_data - - @pytest.mark.parametrize( - "dim, pairs", + "dim, pairs, expected_data_vars", [ - ("individuals", {"id_1": ["id_2"]}), # list input - ("individuals", {"id_1": "id_2"}), # string input - ("individuals", {"id_1": ["id_2"], "id_2": "id_1"}), - ("individuals", "all"), # all pairs - ("keypoints", {"centroid": ["left"]}), # list input - ("keypoints", {"centroid": "left"}), # string input - ("keypoints", {"centroid": ["left"], "left": "right"}), - ("keypoints", "all"), # all pairs + ("individuals", {"id_1": ["id_2"]}, None), # list input + ("individuals", {"id_1": "id_2"}, None), # string input + ( + "individuals", + {"id_1": ["id_2"], "id_2": "id_1"}, + [("id_1", "id_2"), ("id_2", "id_1")], + ), + ("individuals", "all", None), # all pairs + ("keypoints", {"centroid": ["left"]}, None), # list input + ("keypoints", {"centroid": "left"}, None), # string input + ( + "keypoints", + {"centroid": ["left"], "left": "right"}, + [("centroid", "left"), ("left", "right")], + ), + ( + "keypoints", + "all", + [("centroid", "left"), ("centroid", "right"), ("left", "right")], + ), # all pairs ], ) def test_compute_pairwise_distances_with_valid_pairs( - valid_poses_dataset_uniform_linear_motion, dim, pairs + valid_poses_dataset_uniform_linear_motion, dim, pairs, expected_data_vars ): """Test that the expected pairwise distances are computed for valid ``pairs`` inputs. @@ -523,10 +512,10 @@ def test_compute_pairwise_distances_with_valid_pairs( result = kinematics.compute_pairwise_distances( valid_poses_dataset_uniform_linear_motion.position, dim, pairs ) - expected_data_vars = expected_pairwise_distances( - pairs, valid_poses_dataset_uniform_linear_motion, dim - ) if isinstance(result, dict): + expected_data_vars = [ + f"dist_{pair[0]}_{pair[1]}" for pair in expected_data_vars + ] assert set(result.keys()) == set(expected_data_vars) else: # expect single DataArray assert isinstance(result, xr.DataArray)