Skip to content

Commit

Permalink
Spell out expected pairs in test
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Oct 25, 2024
1 parent 57a4ad6 commit 29a9f0a
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools
import re

import numpy as np
Expand Down Expand Up @@ -479,54 +478,44 @@ def test_cdist_with_single_dim_inputs(valid_dataset, selection_fn, request):
assert isinstance(kinematics._cdist(a, b, "individuals"), xr.DataArray)


def expected_pairwise_distances(pairs, input_ds, dim):
"""Return a list of the expected data variable names
for pairwise distances tests.
"""
if pairs == "all":
paired_elements = list(
itertools.combinations(getattr(input_ds, dim).values, 2)
)
else:
paired_elements = [
(elem1, elem2)
for elem1, elem2_list in pairs.items()
for elem2 in (
[elem2_list] if isinstance(elem2_list, str) else elem2_list
)
]
expected_data = [
f"dist_{elem1}_{elem2}" for elem1, elem2 in paired_elements
]
return expected_data


@pytest.mark.parametrize(
"dim, pairs",
"dim, pairs, expected_data_vars",
[
("individuals", {"id_1": ["id_2"]}), # list input
("individuals", {"id_1": "id_2"}), # string input
("individuals", {"id_1": ["id_2"], "id_2": "id_1"}),
("individuals", "all"), # all pairs
("keypoints", {"centroid": ["left"]}), # list input
("keypoints", {"centroid": "left"}), # string input
("keypoints", {"centroid": ["left"], "left": "right"}),
("keypoints", "all"), # all pairs
("individuals", {"id_1": ["id_2"]}, None), # list input
("individuals", {"id_1": "id_2"}, None), # string input
(
"individuals",
{"id_1": ["id_2"], "id_2": "id_1"},
[("id_1", "id_2"), ("id_2", "id_1")],
),
("individuals", "all", None), # all pairs
("keypoints", {"centroid": ["left"]}, None), # list input
("keypoints", {"centroid": "left"}, None), # string input
(
"keypoints",
{"centroid": ["left"], "left": "right"},
[("centroid", "left"), ("left", "right")],
),
(
"keypoints",
"all",
[("centroid", "left"), ("centroid", "right"), ("left", "right")],
), # all pairs
],
)
def test_compute_pairwise_distances_with_valid_pairs(
valid_poses_dataset_uniform_linear_motion, dim, pairs
valid_poses_dataset_uniform_linear_motion, dim, pairs, expected_data_vars
):
"""Test that the expected pairwise distances are computed
for valid ``pairs`` inputs.
"""
result = kinematics.compute_pairwise_distances(
valid_poses_dataset_uniform_linear_motion.position, dim, pairs
)
expected_data_vars = expected_pairwise_distances(
pairs, valid_poses_dataset_uniform_linear_motion, dim
)
if isinstance(result, dict):
expected_data_vars = [
f"dist_{pair[0]}_{pair[1]}" for pair in expected_data_vars
]
assert set(result.keys()) == set(expected_data_vars)
else: # expect single DataArray
assert isinstance(result, xr.DataArray)
Expand Down

0 comments on commit 29a9f0a

Please sign in to comment.