Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General validator function for checking dimensions and coordinates #294

Merged
merged 11 commits into from
Sep 17, 2024
66 changes: 26 additions & 40 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import xarray as xr

from movement.utils.logging import log_error
from movement.validators.arrays import validate_dims_coords


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Expand All @@ -18,8 +19,8 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Parameters
----------
data : xarray.DataArray
The input data array containing position vectors in cartesian
coordinates, with ``time`` as a dimension.
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
lochhh marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand All @@ -42,7 +43,7 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:
height per bounding box, between consecutive time points.

"""
_validate_time_dimension(data)
validate_dims_coords(data, {"time": [], "space": []})
result = data.diff(dim="time")
result = result.reindex(data.coords, fill_value=0)
return result
Expand All @@ -58,8 +59,8 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray:
Parameters
----------
data : xarray.DataArray
The input data array containing position vectors in cartesian
coordinates, with ``time`` as a dimension.
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
lochhh marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand All @@ -78,10 +79,13 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray:

See Also
--------
:meth:`xarray.DataArray.differentiate` : The underlying method used.
compute_time_derivative : The underlying function used.

"""
return _compute_approximate_time_derivative(data, order=1)
# validate only presence of Cartesian space dimension
# (presence of time dimension will be checked in compute_time_derivative)
validate_dims_coords(data, {"space": []})
return compute_time_derivative(data, order=1)


def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
Expand All @@ -94,8 +98,8 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
Parameters
----------
data : xarray.DataArray
The input data array containing position vectors in cartesian
coordinates, with``time`` as a dimension.
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
lochhh marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand All @@ -115,15 +119,16 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray:

See Also
--------
:meth:`xarray.DataArray.differentiate` : The underlying method used.
compute_time_derivative : The underlying function used.

"""
return _compute_approximate_time_derivative(data, order=2)
# validate only presence of Cartesian space dimension
# (presence of time dimension will be checked in compute_time_derivative)
validate_dims_coords(data, {"space": []})
return compute_time_derivative(data, order=2)


def _compute_approximate_time_derivative(
data: xr.DataArray, order: int
) -> xr.DataArray:
def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray:
"""Compute the time-derivative of an array using numerical differentiation.

This function uses :meth:`xarray.DataArray.differentiate`,
Expand All @@ -133,7 +138,7 @@ def _compute_approximate_time_derivative(
Parameters
----------
data : xarray.DataArray
The input data array containing ``time`` as a dimension.
The input data containing ``time`` as a required dimension.
order : int
The order of the time-derivative. For an input containing position
data, use 1 to compute velocity, and 2 to compute acceleration. Value
Expand All @@ -142,8 +147,11 @@ def _compute_approximate_time_derivative(
Returns
-------
xarray.DataArray
An xarray DataArray containing the time-derivative of the
input data.
An xarray DataArray containing the time-derivative of the input data.

See Also
--------
:meth:`xarray.DataArray.differentiate` : The underlying method used.

"""
if not isinstance(order, int):
Expand All @@ -152,30 +160,8 @@ def _compute_approximate_time_derivative(
)
if order <= 0:
raise log_error(ValueError, "Order must be a positive integer.")

_validate_time_dimension(data)

validate_dims_coords(data, {"time": []})
result = data
for _ in range(order):
result = result.differentiate("time")
return result


def _validate_time_dimension(data: xr.DataArray) -> None:
"""Validate the input data contains a ``time`` dimension.

Parameters
----------
data : xarray.DataArray
The input data to validate.

Raises
------
ValueError
If the input data does not contain a ``time`` dimension.

"""
if "time" not in data.dims:
raise log_error(
ValueError, "Input data must contain 'time' as a dimension."
)
55 changes: 7 additions & 48 deletions movement/utils/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import xarray as xr

from movement.utils.logging import log_error
from movement.validators.arrays import validate_dims_coords


def compute_norm(data: xr.DataArray) -> xr.DataArray:
Expand Down Expand Up @@ -39,15 +40,15 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray:

"""
if "space" in data.dims:
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dims_coords(data, {"space": ["x", "y"]})
return xr.apply_ufunc(
np.linalg.norm,
data,
input_core_dims=[["space"]],
kwargs={"axis": -1},
)
elif "space_pol" in data.dims:
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
return data.sel(space_pol="rho", drop=True)
else:
_raise_error_for_missing_spatial_dim()
Expand Down Expand Up @@ -78,10 +79,10 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray:

"""
if "space" in data.dims:
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dims_coords(data, {"space": ["x", "y"]})
return data / compute_norm(data)
elif "space_pol" in data.dims:
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
# Set both rho and phi values to NaN at null vectors (where rho = 0)
new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data)
# Set the rho values to 1 for non-null vectors (phi is preserved)
Expand Down Expand Up @@ -111,7 +112,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray:
``phi`` returned are in radians, in the range ``[-pi, pi]``.

"""
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dims_coords(data, {"space": ["x", "y"]})
rho = compute_norm(data)
phi = xr.apply_ufunc(
np.arctan2,
Expand Down Expand Up @@ -147,7 +148,7 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray:
in the dimension coordinate.

"""
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dims_coords(data, {"space_pol": ["rho", "phi"]})
rho = data.sel(space_pol="rho")
phi = data.sel(space_pol="phi")
x = rho * np.cos(phi)
Expand All @@ -164,48 +165,6 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray:
).transpose(*dims)


def _validate_dimension_coordinates(
data: xr.DataArray, required_dim_coords: dict
) -> None:
"""Validate the input data array.

Ensure that it contains the required dimensions and coordinates.

Parameters
----------
data : xarray.DataArray
The input data to validate.
required_dim_coords : dict
A dictionary of required dimensions and their corresponding
coordinate values.

Raises
------
ValueError
If the input data does not contain the required dimension(s)
and/or the required coordinate(s).

"""
missing_dims = [dim for dim in required_dim_coords if dim not in data.dims]
error_message = ""
if missing_dims:
error_message += (
f"Input data must contain {missing_dims} as dimensions.\n"
)
missing_coords = []
for dim, coords in required_dim_coords.items():
missing_coords = [
coord for coord in coords if coord not in data.coords.get(dim, [])
]
if missing_coords:
error_message += (
"Input data must contain "
f"{missing_coords} in the '{dim}' coordinates."
)
if error_message:
raise log_error(ValueError, error_message)


def _raise_error_for_missing_spatial_dim() -> None:
raise log_error(
ValueError,
Expand Down
61 changes: 61 additions & 0 deletions movement/validators/arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Validators for data arrays."""

import xarray as xr

from movement.utils.logging import log_error


def validate_dims_coords(
data: xr.DataArray, required_dim_coords: dict
) -> None:
"""Validate dimensions and coordinates in a data array.

This function raises a ValueError if the specified dimensions and
coordinates are not present in the input data array.

Parameters
----------
data : xarray.DataArray
The input data array to validate.
required_dim_coords : dict
A dictionary of required dimensions and their corresponding
coordinate values. If you don't need to specify any
coordinate values, you can pass an empty list.

Examples
--------
Validate that a data array contains the dimension 'time'. No specific
coordinates are required.

>>> validate_dims_coords(data, {"time": []})

Validate that a data array contains the dimensions 'time' and 'space',
and that the 'space' dimension contains the coordinates 'x' and 'y'.

>>> validate_dims_coords(data, {"time": [], "space": ["x", "y"]})

Raises
------
ValueError
If the input data does not contain the required dimension(s)
and/or the required coordinate(s).

"""
missing_dims = [dim for dim in required_dim_coords if dim not in data.dims]
error_message = ""
if missing_dims:
error_message += (
f"Input data must contain {missing_dims} as dimensions.\n"
)
missing_coords = []
for dim, coords in required_dim_coords.items():
missing_coords = [
coord for coord in coords if coord not in data.coords.get(dim, [])
]
if missing_coords:
error_message += (
"Input data must contain "
f"{missing_coords} in the '{dim}' coordinates."
)
if error_message:
raise log_error(ValueError, error_message)
2 changes: 1 addition & 1 deletion tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,4 @@ def test_approximate_derivative_with_invalid_order(order):
data = np.arange(10)
expected_exception = ValueError if isinstance(order, int) else TypeError
with pytest.raises(expected_exception):
kinematics._compute_approximate_time_derivative(data, order=order)
kinematics.compute_time_derivative(data, order=order)
56 changes: 56 additions & 0 deletions tests/test_unit/test_validators/test_array_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import re
from contextlib import nullcontext as does_not_raise

import pytest

from movement.validators.arrays import validate_dims_coords


def expect_value_error_with_message(error_msg):
"""Expect a ValueError with the specified error message."""
return pytest.raises(ValueError, match=re.escape(error_msg))


valid_cases = [
({"time": []}, does_not_raise()),
({"time": [0, 1]}, does_not_raise()),
({"space": ["x", "y"]}, does_not_raise()),
({"time": [], "space": []}, does_not_raise()),
({"time": [], "space": ["x", "y"]}, does_not_raise()),
] # Valid cases (no error)

invalid_cases = [
(
{"spacetime": []},
expect_value_error_with_message(
"Input data must contain ['spacetime'] as dimensions."
),
),
(
{"time": [0, 100], "space": ["x", "y"]},
expect_value_error_with_message(
"Input data must contain [100] in the 'time' coordinates."
),
),
(
{"space": ["x", "y", "z"]},
expect_value_error_with_message(
"Input data must contain ['z'] in the 'space' coordinates."
),
),
] # Invalid cases (raise ValueError)


@pytest.mark.parametrize(
"required_dims_coords, expected_exception",
valid_cases + invalid_cases,
)
def test_validate_dims_coords(
valid_poses_dataset_uniform_linear_motion, # fixture from conftest.py
required_dims_coords,
expected_exception,
):
"""Test validate_dims_coords for both valid and invalid inputs."""
position_array = valid_poses_dataset_uniform_linear_motion["position"]
with expected_exception:
validate_dims_coords(position_array, required_dims_coords)
Loading