Skip to content

Commit

Permalink
Added unit test for compute_head_direction_vector()
Browse files Browse the repository at this point in the history
  • Loading branch information
b-peri committed Aug 22, 2024
1 parent c079611 commit 7b61234
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 8 deletions.
12 changes: 4 additions & 8 deletions movement/analysis/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
def compute_head_direction_vector(
data: xr.DataArray, left_keypoint: str, right_keypoint: str
):
"""Compute the head direction vector given two keypoints on the head.
"""Compute the 2D head direction vector given two keypoints on the head.
The head direction vector is computed as a vector perpendicular to the
line connecting two keypoints on either side of the head, pointing
forwards (in a rostral direction).
forwards (in the rostral direction).
Parameters
----------
Expand Down Expand Up @@ -46,13 +46,9 @@ def compute_head_direction_vector(
"Input data must contain 'time', 'space', and 'keypoints' as "
"dimensions.",
)
if not all(
keypoint in data.keypoints
for keypoint in [left_keypoint, right_keypoint]
):
if left_keypoint == right_keypoint:
raise log_error(
AttributeError,
"The selected keypoints could not be found in the input dataset",
ValueError, "The left and right keypoints may not be identical."
)

# Select the right and left keypoints
Expand Down
82 changes: 82 additions & 0 deletions tests/test_unit/test_navigation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import numpy as np
import pytest
import xarray as xr

from movement.analysis import navigation


@pytest.fixture
def mock_dataset():
"""Return a mock DataArray containing four known head orientations."""
time = np.array([0, 1, 2, 3])
individuals = np.array(["individual_0"])
keypoints = np.array(["left_ear", "right_ear", "nose"])
space = np.array(["x", "y"])

ds = xr.DataArray(
[
[[[1, 0], [-1, 0], [0, -1]]], # time 0
[[[0, 1], [0, -1], [1, 0]]], # time 1
[[[-1, 0], [1, 0], [0, 1]]], # time 2
[[[0, -1], [0, 1], [-1, 0]]], # time 3
],
dims=["time", "individuals", "keypoints", "space"],
coords={
"time": time,
"individuals": individuals,
"keypoints": keypoints,
"space": space,
},
)
return ds


def test_compute_head_direction_vector(mock_dataset):
"""Test that the correct head direction vectors
are computed from a basic mock dataset.
"""
# Test that validators work
with pytest.raises(
TypeError,
match="Input data must be an xarray.DataArray, but got <class "
"'numpy.ndarray'>.",
):
np_array = [
[[[1, 0], [-1, 0], [0, -1]]],
[[[0, 1], [0, -1], [1, 0]]],
[[[-1, 0], [1, 0], [0, 1]]],
[[[0, -1], [0, 1], [-1, 0]]],
]
navigation.compute_head_direction_vector(
np_array, "left_ear", "right_ear"
)

with pytest.raises(
ValueError,
match="Input data must contain 'time', 'space', and 'keypoints'"
" as dimensions.",
):
mock_dataset_keypoint = mock_dataset.sel(keypoints="nose", drop=True)
navigation.compute_head_direction_vector(
mock_dataset_keypoint, "left_ear", "right_ear"
)

with pytest.raises(
ValueError, match="The left and right keypoints may not be identical."
):
navigation.compute_head_direction_vector(
mock_dataset, "left_ear", "left_ear"
)

# Test that output contains correct datatype, dimensions, and values
head_vector = navigation.compute_head_direction_vector(
mock_dataset, "left_ear", "right_ear"
)
known_vectors = np.array([[0, 2], [-2, 0], [0, -2], [2, 0]])

assert (
isinstance(head_vector, xr.DataArray)
and ("space" in head_vector.dims)
and ("keypoints" not in head_vector.dims)
)
assert np.equal(head_vector.values, known_vectors).all()

0 comments on commit 7b61234

Please sign in to comment.