Skip to content

Commit

Permalink
Fix dims and coords returned by compute_forward_vector (#382)
Browse files Browse the repository at this point in the history
* Implement the expected fix

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.8.1 → v0.8.6](astral-sh/ruff-pre-commit@v0.8.1...v0.8.6)
- [github.com/pre-commit/mirrors-mypy: v1.13.0 → v1.14.1](pre-commit/mirrors-mypy@v1.13.0...v1.14.1)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Revert "Implement the expected fix"

This reverts commit edbaa9c.

* Update method to explicitly construct and drop spatial z dimensions

* Force nan test to check for preserved coordinates

* Force explicit coordinate preservation checks in input/output test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
willGraham01 and pre-commit-ci[bot] authored Jan 21, 2025
1 parent fc30b5a commit 15b3f41
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
7 changes: 6 additions & 1 deletion movement/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,17 @@ def compute_forward_vector(
upward_vector = xr.DataArray(
np.tile(upward_vector.reshape(1, -1), [len(data.time), 1]),
dims=["time", "space"],
coords={
"space": ["x", "y", "z"],
},
)
# Compute forward direction as the cross product
# (right-to-left) cross (forward) = up
forward_vector = xr.cross(
right_to_left_vector, upward_vector, dim="space"
)[:, :, :-1] # keep only the first 2 dimensions of the result
).drop_sel(
space="z"
) # keep only the first 2 spatal dimensions of the result
# Return unit vector
return forward_vector / compute_norm(forward_vector)

Expand Down
43 changes: 33 additions & 10 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,14 @@ def test_compute_forward_vector(valid_data_array_for_forward_vector):
)
known_vectors = np.array([[[0, -1]], [[1, 0]], [[0, 1]], [[-1, 0]]])

assert (
isinstance(forward_vector, xr.DataArray)
and ("space" in forward_vector.dims)
and ("keypoints" not in forward_vector.dims)
)
for output_array in [forward_vector, forward_vector_flipped, head_vector]:
assert isinstance(output_array, xr.DataArray)
for preserved_coord in ["time", "space", "individuals"]:
assert np.all(
output_array[preserved_coord]
== valid_data_array_for_forward_vector[preserved_coord]
)
assert set(output_array["space"].values) == {"x", "y"}
assert np.equal(forward_vector.values, known_vectors).all()
assert np.equal(forward_vector_flipped.values, known_vectors * -1).all()
assert head_vector.equals(forward_vector)
Expand Down Expand Up @@ -518,20 +521,40 @@ def test_compute_forward_vector_with_invalid_input(


def test_nan_behavior_forward_vector(
valid_data_array_for_forward_vector_with_nans,
valid_data_array_for_forward_vector_with_nans: xr.DataArray,
):
"""Test that ``compute_forward_vector()`` generates the
expected output for a valid input DataArray containing ``NaN``
position values at a single time (``1``) and keypoint
(``left_ear``).
"""
nan_time = 1
forward_vector = kinematics.compute_forward_vector(
valid_data_array_for_forward_vector_with_nans, "left_ear", "right_ear"
)
assert (
np.isnan(forward_vector.values[1, 0, :]).all()
and not np.isnan(forward_vector.values[[0, 2, 3], 0, :]).any()
)
# Check coord preservation
for preserved_coord in ["time", "space", "individuals"]:
assert np.all(
forward_vector[preserved_coord]
== valid_data_array_for_forward_vector_with_nans[preserved_coord]
)
assert set(forward_vector["space"].values) == {"x", "y"}
# Should have NaN values in the forward vector at time 1 and left_ear
nan_values = forward_vector.sel(time=nan_time)
assert nan_values.shape == (1, 2)
assert np.isnan(
nan_values
).all(), "NaN values not returned where expected!"
# Should have no NaN values in the forward vector in other positions
assert not np.isnan(
forward_vector.sel(
time=[
t
for t in valid_data_array_for_forward_vector_with_nans.time
if t != nan_time
]
)
).any()


@pytest.mark.parametrize(
Expand Down

0 comments on commit 15b3f41

Please sign in to comment.