Skip to content

Commit

Permalink
More fixes for test_kinematics.py
Browse files Browse the repository at this point in the history
  • Loading branch information
b-peri committed Sep 24, 2024
1 parent df6bed9 commit b00edcd
Showing 1 changed file with 38 additions and 38 deletions.
76 changes: 38 additions & 38 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_approximate_derivative_with_invalid_order(order):


@pytest.fixture
def valid_data_array_for_head_vector():
def valid_data_array_for_forward_vector():
"""Return a position data array for an individual with 3 keypoints
(left ear, right ear and nose), tracked for 4 frames, in x-y space.
"""
Expand Down Expand Up @@ -215,120 +215,120 @@ def valid_data_array_for_head_vector():


@pytest.fixture
def invalid_input_type_for_head_vector(valid_data_array_for_head_vector):
def invalid_input_type_for_forward_vector(valid_data_array_for_forward_vector):
"""Return a numpy array of position values by individual, per keypoint,
over time.
"""
return valid_data_array_for_head_vector.values
return valid_data_array_for_forward_vector.values


@pytest.fixture
def invalid_dimensions_for_head_vector(valid_data_array_for_head_vector):
def invalid_dimensions_for_forward_vector(valid_data_array_for_forward_vector):
"""Return a position DataArray in which the ``keypoints`` dimension has
been dropped.
"""
return valid_data_array_for_head_vector.sel(keypoints="nose", drop=True)
return valid_data_array_for_forward_vector.sel(keypoints="nose", drop=True)


@pytest.fixture
def invalid_spatial_dimensions_for_head_vector(
valid_data_array_for_head_vector,
def invalid_spatial_dimensions_for_forward_vector(
valid_data_array_for_forward_vector,
):
"""Return a position DataArray containing three spatial dimensions."""
dataarray_3d = valid_data_array_for_head_vector.pad(
dataarray_3d = valid_data_array_for_forward_vector.pad(
space=(0, 1), constant_values=0
)
return dataarray_3d.assign_coords(space=["x", "y", "z"])


@pytest.fixture
def valid_data_array_for_head_vector_with_nans(
valid_data_array_for_head_vector,
def valid_data_array_for_forward_vector_with_nans(
valid_data_array_for_forward_vector,
):
"""Return a position DataArray where position values are NaN for the
``left_ear`` keypoint at time ``1``.
"""
nan_dataarray = valid_data_array_for_head_vector.where(
(valid_data_array_for_head_vector.time != 1)
| (valid_data_array_for_head_vector.keypoints != "left_ear")
nan_dataarray = valid_data_array_for_forward_vector.where(
(valid_data_array_for_forward_vector.time != 1)
| (valid_data_array_for_forward_vector.keypoints != "left_ear")
)
return nan_dataarray


def test_compute_2d_head_direction_vector(valid_data_array_for_head_vector):
"""Test that the correct output head direction vectors
def test_compute_forward_vector(valid_data_array_for_forward_vector):
"""Test that the correct output forward direction vectors
are computed from a valid mock dataset.
"""
head_vector = kinematics.compute_2d_head_direction_vector(
valid_data_array_for_head_vector, "left_ear", "right_ear"
forward_vector = kinematics.compute_forward_vector(
valid_data_array_for_forward_vector, "left_ear", "right_ear"
)
known_vectors = np.array([[[0, 1]], [[-1, 0]], [[0, -1]], [[1, 0]]])
known_vectors = np.array([[[0, -1]], [[1, 0]], [[0, 1]], [[-1, 0]]])

assert (
isinstance(head_vector, xr.DataArray)
and ("space" in head_vector.dims)
and ("keypoints" not in head_vector.dims)
isinstance(forward_vector, xr.DataArray)
and ("space" in forward_vector.dims)
and ("keypoints" not in forward_vector.dims)
)
assert np.equal(head_vector.values, known_vectors).all()
assert np.equal(forward_vector.values, known_vectors).all()


@pytest.mark.parametrize(
"input_data, expected_error, expected_match_str, keypoints",
[
(
"invalid_input_type_for_head_vector",
"invalid_input_type_for_forward_vector",
TypeError,
"must be an xarray.DataArray",
["left_ear", "right_ear"],
),
(
"invalid_dimensions_for_head_vector",
AttributeError,
"'time', 'space', and 'keypoints'",
"invalid_dimensions_for_forward_vector",
ValueError,
"Input data must contain ['keypoints']",
["left_ear", "right_ear"],
),
(
"invalid_spatial_dimensions_for_head_vector",
"invalid_spatial_dimensions_for_forward_vector",
ValueError,
"must have 2 (and only 2) spatial dimensions",
["left_ear", "right_ear"],
),
(
"valid_data_array_for_head_vector",
"valid_data_array_for_forward_vector",
ValueError,
"keypoints may not be identical",
["left_ear", "left_ear"],
),
],
)
def test_compute_2d_head_direction_vector_with_invalid_input(
def test_compute_forward_vector_with_invalid_input(
input_data, keypoints, expected_error, expected_match_str, request
):
"""Test that ``compute_2d_head_direction_vector`` catches errors
"""Test that ``compute_forward_vector`` catches errors
correctly when passed invalid inputs.
"""
# Get fixture
input_data = request.getfixturevalue(input_data)

# Catch error
with pytest.raises(expected_error, match=re.escape(expected_match_str)):
kinematics.compute_2d_head_direction_vector(
kinematics.compute_forward_vector(
input_data, keypoints[0], keypoints[1]
)


def test_nan_behavior_2d_head_vector(
valid_data_array_for_head_vector_with_nans,
def test_nan_behavior_forward_vector(
valid_data_array_for_forward_vector_with_nans,
):
"""Test that ``compute_head_direction_vector()`` generates the
"""Test that ``compute_forward_vector()`` generates the
expected output for a valid input DataArray containing ``NaN``
position values at a single time (``1``) and keypoint
(``left_ear``).
"""
head_vector = kinematics.compute_2d_head_direction_vector(
valid_data_array_for_head_vector_with_nans, "left_ear", "right_ear"
forward_vector = kinematics.compute_forward_vector(
valid_data_array_for_forward_vector_with_nans, "left_ear", "right_ear"
)
assert (
np.isnan(head_vector.values[1, 0, :]).all()
and not np.isnan(head_vector.values[[0, 2, 3], 0, :]).any()
np.isnan(forward_vector.values[1, 0, :]).all()
and not np.isnan(forward_vector.values[[0, 2, 3], 0, :]).any()
)

0 comments on commit b00edcd

Please sign in to comment.