Skip to content

Commit

Permalink
General validator function for checking dimensions and coordinates (#294)
Browse files Browse the repository at this point in the history

* replace time dim validator with more generic validator

* constrain kinematic functions to cartesian coordinates

* renamed new validator to validate_dims_coords

* add examples in validator docstring

* unit tests for the new validator

* Apply suggestions from code review

do not validate x,y space coordinates specifically.

Co-authored-by: Chang Huan Lo <[email protected]>

* reuse fixture valid_poses_dataset_uniform_linear_motion

* combine two unit tests into one

* expose public `compute_time_derivative` function

* Refactor test

* Update refs to `compute_time_derivative`

---------

Co-authored-by: Chang Huan Lo <[email protected]>
  • Loading branch information
niksirbi and lochhh authored Sep 17, 2024
1 parent 644c1b1 commit 9c80786
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 89 deletions.
66 changes: 26 additions & 40 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import xarray as xr

from movement.utils.logging import log_error
from movement.validators.arrays import validate_dims_coords


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Expand All @@ -18,8 +19,8 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Parameters
----------
data : xarray.DataArray
The input data array containing position vectors in cartesian
coordinates, with ``time`` as a dimension.
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
Returns
-------
Expand All @@ -42,7 +43,7 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:
height per bounding box, between consecutive time points.
"""
_validate_time_dimension(data)
validate_dims_coords(data, {"time": [], "space": []})
result = data.diff(dim="time")
result = result.reindex(data.coords, fill_value=0)
return result
Expand All @@ -58,8 +59,8 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray:
Parameters
----------
data : xarray.DataArray
The input data array containing position vectors in cartesian
coordinates, with ``time`` as a dimension.
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
Returns
-------
Expand All @@ -78,10 +79,13 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray:
See Also
--------
:meth:`xarray.DataArray.differentiate` : The underlying method used.
compute_time_derivative : The underlying function used.
"""
return _compute_approximate_time_derivative(data, order=1)
# validate only presence of Cartesian space dimension
# (presence of time dimension will be checked in compute_time_derivative)
validate_dims_coords(data, {"space": []})
return compute_time_derivative(data, order=1)


def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
Expand All @@ -94,8 +98,8 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
Parameters
----------
data : xarray.DataArray
The input data array containing position vectors in cartesian
coordinates, with``time`` as a dimension.
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
Returns
-------
Expand All @@ -115,15 +119,16 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
See Also
--------
:meth:`xarray.DataArray.differentiate` : The underlying method used.
compute_time_derivative : The underlying function used.
"""
return _compute_approximate_time_derivative(data, order=2)
# validate only presence of Cartesian space dimension
# (presence of time dimension will be checked in compute_time_derivative)
validate_dims_coords(data, {"space": []})
return compute_time_derivative(data, order=2)


def _compute_approximate_time_derivative(
data: xr.DataArray, order: int
) -> xr.DataArray:
def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray:
"""Compute the time-derivative of an array using numerical differentiation.
This function uses :meth:`xarray.DataArray.differentiate`,
Expand All @@ -133,7 +138,7 @@ def _compute_approximate_time_derivative(
Parameters
----------
data : xarray.DataArray
The input data array containing ``time`` as a dimension.
The input data containing ``time`` as a required dimension.
order : int
The order of the time-derivative. For an input containing position
data, use 1 to compute velocity, and 2 to compute acceleration. Value
Expand All @@ -142,8 +147,11 @@ def _compute_approximate_time_derivative(
Returns
-------
xarray.DataArray
An xarray DataArray containing the time-derivative of the
input data.
An xarray DataArray containing the time-derivative of the input data.
See Also
--------
:meth:`xarray.DataArray.differentiate` : The underlying method used.
"""
if not isinstance(order, int):
Expand All @@ -152,30 +160,8 @@ def _compute_approximate_time_derivative(
)
if order <= 0:
raise log_error(ValueError, "Order must be a positive integer.")

_validate_time_dimension(data)

validate_dims_coords(data, {"time": []})
result = data
for _ in range(order):
result = result.differentiate("time")
return result


def _validate_time_dimension(data: xr.DataArray) -> None:
"""Validate the input data contains a ``time`` dimension.
Parameters
----------
data : xarray.DataArray
The input data to validate.
Raises
------
ValueError
If the input data does not contain a ``time`` dimension.
"""
if "time" not in data.dims:
raise log_error(
ValueError, "Input data must contain 'time' as a dimension."
)
55 changes: 7 additions & 48 deletions movement/utils/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import xarray as xr

from movement.utils.logging import log_error
from movement.validators.arrays import validate_dims_coords


def compute_norm(data: xr.DataArray) -> xr.DataArray:
Expand Down Expand Up @@ -39,15 +40,15 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray:
"""
if "space" in data.dims:
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dims_coords(data, {"space": ["x", "y"]})
return xr.apply_ufunc(
np.linalg.norm,
data,
input_core_dims=[["space"]],
kwargs={"axis": -1},
)
elif "space_pol" in data.dims:
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
return data.sel(space_pol="rho", drop=True)
else:
_raise_error_for_missing_spatial_dim()
Expand Down Expand Up @@ -78,10 +79,10 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray:
"""
if "space" in data.dims:
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dims_coords(data, {"space": ["x", "y"]})
return data / compute_norm(data)
elif "space_pol" in data.dims:
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
# Set both rho and phi values to NaN at null vectors (where rho = 0)
new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data)
# Set the rho values to 1 for non-null vectors (phi is preserved)
Expand Down Expand Up @@ -111,7 +112,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray:
``phi`` returned are in radians, in the range ``[-pi, pi]``.
"""
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dims_coords(data, {"space": ["x", "y"]})
rho = compute_norm(data)
phi = xr.apply_ufunc(
np.arctan2,
Expand Down Expand Up @@ -147,7 +148,7 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray:
in the dimension coordinate.
"""
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
rho = data.sel(space_pol="rho")
phi = data.sel(space_pol="phi")
x = rho * np.cos(phi)
Expand All @@ -164,48 +165,6 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray:
).transpose(*dims)


def _validate_dimension_coordinates(
data: xr.DataArray, required_dim_coords: dict
) -> None:
"""Validate the input data array.
Ensure that it contains the required dimensions and coordinates.
Parameters
----------
data : xarray.DataArray
The input data to validate.
required_dim_coords : dict
A dictionary of required dimensions and their corresponding
coordinate values.
Raises
------
ValueError
If the input data does not contain the required dimension(s)
and/or the required coordinate(s).
"""
missing_dims = [dim for dim in required_dim_coords if dim not in data.dims]
error_message = ""
if missing_dims:
error_message += (
f"Input data must contain {missing_dims} as dimensions.\n"
)
missing_coords = []
for dim, coords in required_dim_coords.items():
missing_coords = [
coord for coord in coords if coord not in data.coords.get(dim, [])
]
if missing_coords:
error_message += (
"Input data must contain "
f"{missing_coords} in the '{dim}' coordinates."
)
if error_message:
raise log_error(ValueError, error_message)


def _raise_error_for_missing_spatial_dim() -> None:
raise log_error(
ValueError,
Expand Down
61 changes: 61 additions & 0 deletions movement/validators/arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Validators for data arrays."""

import xarray as xr

from movement.utils.logging import log_error


def validate_dims_coords(
    data: xr.DataArray, required_dim_coords: dict
) -> None:
    """Validate dimensions and coordinates in a data array.

    This function raises a ValueError if the specified dimensions and
    coordinates are not present in the input data array.

    Parameters
    ----------
    data : xarray.DataArray
        The input data array to validate.
    required_dim_coords : dict
        A dictionary of required dimensions and their corresponding
        coordinate values. If you don't need to specify any
        coordinate values, you can pass an empty list.

    Raises
    ------
    ValueError
        If the input data does not contain the required dimension(s)
        and/or the required coordinate(s). Each problem found is reported
        on its own line of the error message.

    Examples
    --------
    Validate that a data array contains the dimension 'time'. No specific
    coordinates are required.

    >>> validate_dims_coords(data, {"time": []})

    Validate that a data array contains the dimensions 'time' and 'space',
    and that the 'space' dimension contains the coordinates 'x' and 'y'.

    >>> validate_dims_coords(data, {"time": [], "space": ["x", "y"]})

    """
    # Collect every problem first so that a single error reports them all.
    error_lines = []
    missing_dims = [dim for dim in required_dim_coords if dim not in data.dims]
    if missing_dims:
        error_lines.append(
            f"Input data must contain {missing_dims} as dimensions."
        )
    for dim, coords in required_dim_coords.items():
        # Look the coordinate up once per dimension. When ``dim`` has no
        # coordinate the fallback is an empty list, so every requested
        # coordinate value is reported as missing.
        available = data.coords.get(dim, [])
        missing_coords = [coord for coord in coords if coord not in available]
        if missing_coords:
            error_lines.append(
                "Input data must contain "
                f"{missing_coords} in the '{dim}' coordinates."
            )
    if error_lines:
        # One problem per line; no trailing newline and no run-together text.
        raise log_error(ValueError, "\n".join(error_lines))
2 changes: 1 addition & 1 deletion tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,4 @@ def test_approximate_derivative_with_invalid_order(order):
data = np.arange(10)
expected_exception = ValueError if isinstance(order, int) else TypeError
with pytest.raises(expected_exception):
kinematics._compute_approximate_time_derivative(data, order=order)
kinematics.compute_time_derivative(data, order=order)
56 changes: 56 additions & 0 deletions tests/test_unit/test_validators/test_array_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import re
from contextlib import nullcontext as does_not_raise

import pytest

from movement.validators.arrays import validate_dims_coords


def expect_value_error_with_message(error_msg):
    """Build the ``pytest.raises`` context for a ValueError with a message.

    The message is escaped before being used as the ``match`` pattern, so
    characters such as brackets and quotes are matched literally.
    """
    literal_pattern = re.escape(error_msg)
    return pytest.raises(ValueError, match=literal_pattern)


# Each case is a tuple of (required_dim_coords, context manager): the
# dictionary passed to ``validate_dims_coords`` and the context under which
# the call is expected to run (no error, or a ValueError with a message).
valid_cases = [
    ({"time": []}, does_not_raise()),
    ({"time": [0, 1]}, does_not_raise()),
    ({"space": ["x", "y"]}, does_not_raise()),
    ({"time": [], "space": []}, does_not_raise()),
    ({"time": [], "space": ["x", "y"]}, does_not_raise()),
]  # Valid cases (no error)

invalid_cases = [
    (
        # A dimension name that is not among the array's dimensions.
        {"spacetime": []},
        expect_value_error_with_message(
            "Input data must contain ['spacetime'] as dimensions."
        ),
    ),
    (
        # Only the absent coordinate value (100) is expected in the message.
        {"time": [0, 100], "space": ["x", "y"]},
        expect_value_error_with_message(
            "Input data must contain [100] in the 'time' coordinates."
        ),
    ),
    (
        # 'x' and 'y' are present; only 'z' is expected to be reported.
        {"space": ["x", "y", "z"]},
        expect_value_error_with_message(
            "Input data must contain ['z'] in the 'space' coordinates."
        ),
    ),
]  # Invalid cases (raise ValueError)


@pytest.mark.parametrize(
    "required_dims_coords, expected_exception",
    valid_cases + invalid_cases,
)
def test_validate_dims_coords(
    valid_poses_dataset_uniform_linear_motion,  # fixture from conftest.py
    required_dims_coords,
    expected_exception,
):
    """Test validate_dims_coords for both valid and invalid inputs.

    Each parametrized case supplies the required dimensions/coordinates and
    the context manager under which the validator is called: either
    ``does_not_raise()`` or an expected ``ValueError`` with a given message.
    """
    # The validator is exercised on the "position" variable of the fixture.
    position_array = valid_poses_dataset_uniform_linear_motion["position"]
    with expected_exception:
        validate_dims_coords(position_array, required_dims_coords)

0 comments on commit 9c80786

Please sign in to comment.