Skip to content

Commit

Permalink
Update test function args + fix indentation
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Sep 9, 2024
1 parent 925b17b commit a20d283
Showing 1 changed file with 113 additions and 111 deletions.
224 changes: 113 additions & 111 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,126 +185,128 @@ def test_approximate_derivative_with_invalid_order(order):
with pytest.raises(expected_exception):
kinematics._compute_approximate_time_derivative(data, order=order)

@pytest.mark.parametrize(
"dim, pairs, expected_data",
[
(
"individuals",
("ind1", "ind2"),
np.array(

@pytest.mark.parametrize(
"dim, pairs, expected_data",
[
(
"individuals",
("ind1", "ind2"),
np.array(
[
[
[1.0, 0.0, np.sqrt(2)],
[1.0, np.sqrt(2), 0.0],
[0.0, 1.0, 1.0],
],
[
[
[1.0, 0.0, np.sqrt(2)],
[1.0, np.sqrt(2), 0.0],
[0.0, 1.0, 1.0],
],
[
[np.sqrt(13), 0.0, np.sqrt(8)],
[1.0, np.sqrt(8), 0.0],
[0.0, np.sqrt(13), 1.0],
],
]
),
[np.sqrt(13), 0.0, np.sqrt(8)],
[1.0, np.sqrt(8), 0.0],
[0.0, np.sqrt(13), 1.0],
],
]
),
(
"keypoints",
("key1", "key2"),
np.array(
),
(
"keypoints",
("key1", "key2"),
np.array(
[
[
[
[np.sqrt(2), 0.0, 1.0],
[1.0, 1.0, 0.0],
[0.0, np.sqrt(2), 1.0],
],
[
[np.sqrt(8), 0.0, np.sqrt(13)],
[1.0, np.sqrt(13), 0.0],
[0.0, np.sqrt(8), 1.0],
],
]
),
[np.sqrt(2), 0.0, 1.0],
[1.0, 1.0, 0.0],
[0.0, np.sqrt(2), 1.0],
],
[
[np.sqrt(8), 0.0, np.sqrt(13)],
[1.0, np.sqrt(13), 0.0],
[0.0, np.sqrt(8), 1.0],
],
]
),
),
],
)
def test_cdist(dim, pairs, expected_data, pairwise_distances_dataset):
"""Test the computation of pairwise distances with known values."""
core_dim = "keypoints" if dim == "individuals" else "individuals"
input_dataarray = pairwise_distances_dataset.position
expected = xr.DataArray(
expected_data,
coords=[
input_dataarray.time.values,
getattr(input_dataarray, core_dim).values,
getattr(input_dataarray, core_dim).values,
],
dims=["time", pairs[0], pairs[1]],
)
a = input_dataarray.sel({dim: pairs[0]})
b = input_dataarray.sel({dim: pairs[1]})
result = kinematics.cdist(a, b, dim)
xr.testing.assert_equal(
result,
expected,
)
def test_cdist(
self, dim, pairs, expected_data, pairwise_distances_dataset
):
"""Test the computation of pairwise distances with known values."""
core_dim = "keypoints" if dim == "individuals" else "individuals"
input_dataarray = pairwise_distances_dataset.position
expected = xr.DataArray(
expected_data,
coords=[
input_dataarray.time.values,
getattr(input_dataarray, core_dim).values,
getattr(input_dataarray, core_dim).values,
],
dims=["time", pairs[0], pairs[1]],
)
a = input_dataarray.sel({dim: pairs[0]})
b = input_dataarray.sel({dim: pairs[1]})
result = kinematics.cdist(a, b, dim)
xr.testing.assert_equal(
result,
expected,
)

def expected_pairwise_distances(self, pairs, input_ds, dim):
"""Return a list of the expected data variable names
for pairwise distances tests.
"""
if pairs is None:
paired_elements = list(
itertools.combinations(getattr(input_ds, dim).values, 2)

def expected_pairwise_distances(pairs, input_ds, dim):
"""Return a list of the expected data variable names
for pairwise distances tests.
"""
if pairs is None:
paired_elements = list(
itertools.combinations(getattr(input_ds, dim).values, 2)
)
else:
paired_elements = [
(elem1, elem2)
for elem1, elem2_list in pairs.items()
for elem2 in (
[elem2_list] if isinstance(elem2_list, str) else elem2_list
)
else:
paired_elements = [
(elem1, elem2)
for elem1, elem2_list in pairs.items()
for elem2 in (
[elem2_list] if isinstance(elem2_list, str) else elem2_list
)
]
expected_data = [
f"dist_{elem1}_{elem2}" for elem1, elem2 in paired_elements
]
return expected_data
expected_data = [
f"dist_{elem1}_{elem2}" for elem1, elem2 in paired_elements
]
return expected_data

@pytest.mark.parametrize(
"dim, pairs",
[
("individuals", {"ind1": ["ind2"]}), # list input
("individuals", {"ind1": "ind2"}), # string input
("individuals", {"ind1": ["ind2", "ind3"], "ind2": "ind3"}),
("individuals", None), # all pairs
("keypoints", {"key1": ["key2"]}), # list input
("keypoints", {"key1": "key2"}), # string input
("keypoints", {"key1": ["key2", "key3"], "key2": "key3"}),
("keypoints", None), # all pairs
],

@pytest.mark.parametrize(
"dim, pairs",
[
("individuals", {"ind1": ["ind2"]}), # list input
("individuals", {"ind1": "ind2"}), # string input
("individuals", {"ind1": ["ind2", "ind3"], "ind2": "ind3"}),
("individuals", None), # all pairs
("keypoints", {"key1": ["key2"]}), # list input
("keypoints", {"key1": "key2"}), # string input
("keypoints", {"key1": ["key2", "key3"], "key2": "key3"}),
("keypoints", None), # all pairs
],
)
def test_compute_pairwise_distances_with_valid_pairs(
pairwise_distances_dataset, dim, pairs
):
"""Test that the expected pairwise distances are computed
for valid ``pairs`` inputs.
"""
result = getattr(kinematics, f"compute_inter{dim[:-1]}_distances")(
pairwise_distances_dataset.position, pairs=pairs
)
def test_compute_pairwise_distances_with_valid_pairs(
self, pairwise_distances_dataset, dim, pairs
):
"""Test that the expected pairwise distances are computed
for valid ``pairs`` inputs.
"""
result = getattr(kinematics, f"compute_inter{dim[:-1]}_distances")(
pairwise_distances_dataset.position, pairs=pairs
)
expected_data_vars = self.expected_pairwise_distances(
pairs, pairwise_distances_dataset, dim
)
if isinstance(result, dict):
assert set(result.keys()) == set(expected_data_vars)
else: # expect single DataArray
assert isinstance(result, xr.DataArray)
expected_data_vars = expected_pairwise_distances(
pairs, pairwise_distances_dataset, dim
)
if isinstance(result, dict):
assert set(result.keys()) == set(expected_data_vars)
else: # expect single DataArray
assert isinstance(result, xr.DataArray)

def test_compute_pairwise_distances_with_invalid_dim(
self, pairwise_distances_dataset
):
"""Test that an error is raised when an invalid dimension is passed."""
with pytest.raises(ValueError):
kinematics._compute_pairwise_distances(
pairwise_distances_dataset.position, "invalid_dim"
)

def test_compute_pairwise_distances_with_invalid_dim(
pairwise_distances_dataset,
):
"""Test that an error is raised when an invalid dimension is passed."""
with pytest.raises(ValueError):
kinematics._compute_pairwise_distances(
pairwise_distances_dataset.position, "invalid_dim"
)

0 comments on commit a20d283

Please sign in to comment.