From 9c80786455fe0375d689430102eebe3bfa645ad2 Mon Sep 17 00:00:00 2001 From: Niko Sirmpilatze Date: Tue, 17 Sep 2024 14:17:42 +0100 Subject: [PATCH] General validator function for checking dimensions and coordinates (#294) * replace time dim validator with more generic validator * constrain kinematic functions to cartesian coordinates * renamed new validator to validate_dims_coords * add examples in validator docstring * unit tests for the new validator * Apply suggestions from code review do note validate x,y space coordinates specifically. Co-authored-by: Chang Huan Lo * reuse fixture valid_poses_dataset_uniform_linear_motion * combine two unit tests into one * expose public `compute_time_derivative` function * Refactor test * Update refs to `compute_time_derivative` --------- Co-authored-by: Chang Huan Lo --- movement/analysis/kinematics.py | 66 ++++++++----------- movement/utils/vector.py | 55 ++-------------- movement/validators/arrays.py | 61 +++++++++++++++++ tests/test_unit/test_kinematics.py | 2 +- .../test_validators/test_array_validators.py | 56 ++++++++++++++++ 5 files changed, 151 insertions(+), 89 deletions(-) create mode 100644 movement/validators/arrays.py create mode 100644 tests/test_unit/test_validators/test_array_validators.py diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py index ed2b4b30..b2bbbf9b 100644 --- a/movement/analysis/kinematics.py +++ b/movement/analysis/kinematics.py @@ -3,6 +3,7 @@ import xarray as xr from movement.utils.logging import log_error +from movement.validators.arrays import validate_dims_coords def compute_displacement(data: xr.DataArray) -> xr.DataArray: @@ -18,8 +19,8 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray: Parameters ---------- data : xarray.DataArray - The input data array containing position vectors in cartesian - coordinates, with ``time`` as a dimension. + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. Returns ------- @@ -42,7 +43,7 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray: height per bounding box, between consecutive time points. """ - _validate_time_dimension(data) + validate_dims_coords(data, {"time": [], "space": []}) result = data.diff(dim="time") result = result.reindex(data.coords, fill_value=0) return result @@ -58,8 +59,8 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray: Parameters ---------- data : xarray.DataArray - The input data array containing position vectors in cartesian - coordinates, with ``time`` as a dimension. + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. Returns ------- @@ -78,10 +79,13 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray: See Also -------- - :meth:`xarray.DataArray.differentiate` : The underlying method used. + compute_time_derivative : The underlying function used. """ - return _compute_approximate_time_derivative(data, order=1) + # validate only presence of Cartesian space dimension + # (presence of time dimension will be checked in compute_time_derivative) + validate_dims_coords(data, {"space": []}) + return compute_time_derivative(data, order=1) def compute_acceleration(data: xr.DataArray) -> xr.DataArray: @@ -94,8 +98,8 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray: Parameters ---------- data : xarray.DataArray - The input data array containing position vectors in cartesian - coordinates, with``time`` as a dimension. + The input data containing position information, with ``time`` + and ``space`` (in Cartesian coordinates) as required dimensions. Returns ------- @@ -115,15 +119,16 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray: See Also -------- - :meth:`xarray.DataArray.differentiate` : The underlying method used. + compute_time_derivative : The underlying function used. """ - return _compute_approximate_time_derivative(data, order=2) + # validate only presence of Cartesian space dimension + # (presence of time dimension will be checked in compute_time_derivative) + validate_dims_coords(data, {"space": []}) + return compute_time_derivative(data, order=2) -def _compute_approximate_time_derivative( - data: xr.DataArray, order: int -) -> xr.DataArray: +def compute_time_derivative(data: xr.DataArray, order: int) -> xr.DataArray: """Compute the time-derivative of an array using numerical differentiation. This function uses :meth:`xarray.DataArray.differentiate`, @@ -133,7 +138,7 @@ def _compute_approximate_time_derivative( Parameters ---------- data : xarray.DataArray - The input data array containing ``time`` as a dimension. + The input data containing ``time`` as a required dimension. order : int The order of the time-derivative. For an input containing position data, use 1 to compute velocity, and 2 to compute acceleration. Value @@ -142,8 +147,11 @@ def _compute_approximate_time_derivative( Returns ------- xarray.DataArray - An xarray DataArray containing the time-derivative of the - input data. + An xarray DataArray containing the time-derivative of the input data. + + See Also + -------- + :meth:`xarray.DataArray.differentiate` : The underlying method used. """ if not isinstance(order, int): @@ -152,30 +160,8 @@ def _compute_approximate_time_derivative( ) if order <= 0: raise log_error(ValueError, "Order must be a positive integer.") - - _validate_time_dimension(data) - + validate_dims_coords(data, {"time": []}) result = data for _ in range(order): result = result.differentiate("time") return result - - -def _validate_time_dimension(data: xr.DataArray) -> None: - """Validate the input data contains a ``time`` dimension. - - Parameters - ---------- - data : xarray.DataArray - The input data to validate. - - Raises - ------ - ValueError - If the input data does not contain a ``time`` dimension. - - """ - if "time" not in data.dims: - raise log_error( - ValueError, "Input data must contain 'time' as a dimension." - ) diff --git a/movement/utils/vector.py b/movement/utils/vector.py index 0d5d88c8..c91e43ec 100644 --- a/movement/utils/vector.py +++ b/movement/utils/vector.py @@ -4,6 +4,7 @@ import xarray as xr from movement.utils.logging import log_error +from movement.validators.arrays import validate_dims_coords def compute_norm(data: xr.DataArray) -> xr.DataArray: @@ -39,7 +40,7 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray: """ if "space" in data.dims: - _validate_dimension_coordinates(data, {"space": ["x", "y"]}) + validate_dims_coords(data, {"space": ["x", "y"]}) return xr.apply_ufunc( np.linalg.norm, data, @@ -47,7 +48,7 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray: kwargs={"axis": -1}, ) elif "space_pol" in data.dims: - _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) return data.sel(space_pol="rho", drop=True) else: _raise_error_for_missing_spatial_dim() @@ -78,10 +79,10 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray: """ if "space" in data.dims: - _validate_dimension_coordinates(data, {"space": ["x", "y"]}) + validate_dims_coords(data, {"space": ["x", "y"]}) return data / compute_norm(data) elif "space_pol" in data.dims: - _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) # Set both rho and phi values to NaN at null vectors (where rho = 0) new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data) # Set the rho values to 1 for non-null vectors (phi is preserved) @@ -111,7 +112,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray: ``phi`` returned are in radians, in the range ``[-pi, pi]``. """ - _validate_dimension_coordinates(data, {"space": ["x", "y"]}) + validate_dims_coords(data, {"space": ["x", "y"]}) rho = compute_norm(data) phi = xr.apply_ufunc( np.arctan2, @@ -147,7 +148,7 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray: in the dimension coordinate. """ - _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + validate_dims_coords(data, {"space_pol": ["rho", "phi"]}) rho = data.sel(space_pol="rho") phi = data.sel(space_pol="phi") x = rho * np.cos(phi) @@ -164,48 +165,6 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray: ).transpose(*dims) -def _validate_dimension_coordinates( - data: xr.DataArray, required_dim_coords: dict -) -> None: - """Validate the input data array. - - Ensure that it contains the required dimensions and coordinates. - - Parameters - ---------- - data : xarray.DataArray - The input data to validate. - required_dim_coords : dict - A dictionary of required dimensions and their corresponding - coordinate values. - - Raises - ------ - ValueError - If the input data does not contain the required dimension(s) - and/or the required coordinate(s). - - """ - missing_dims = [dim for dim in required_dim_coords if dim not in data.dims] - error_message = "" - if missing_dims: - error_message += ( - f"Input data must contain {missing_dims} as dimensions.\n" - ) - missing_coords = [] - for dim, coords in required_dim_coords.items(): - missing_coords = [ - coord for coord in coords if coord not in data.coords.get(dim, []) - ] - if missing_coords: - error_message += ( - "Input data must contain " - f"{missing_coords} in the '{dim}' coordinates." - ) - if error_message: - raise log_error(ValueError, error_message) - - def _raise_error_for_missing_spatial_dim() -> None: raise log_error( ValueError, diff --git a/movement/validators/arrays.py b/movement/validators/arrays.py new file mode 100644 index 00000000..76847571 --- /dev/null +++ b/movement/validators/arrays.py @@ -0,0 +1,61 @@ +"""Validators for data arrays.""" + +import xarray as xr + +from movement.utils.logging import log_error + + +def validate_dims_coords( + data: xr.DataArray, required_dim_coords: dict +) -> None: + """Validate dimensions and coordinates in a data array. + + This function raises a ValueError if the specified dimensions and + coordinates are not present in the input data array. + + Parameters + ---------- + data : xarray.DataArray + The input data array to validate. + required_dim_coords : dict + A dictionary of required dimensions and their corresponding + coordinate values. If you don't need to specify any + coordinate values, you can pass an empty list. + + Examples + -------- + Validate that a data array contains the dimension 'time'. No specific + coordinates are required. + + >>> validate_dims_coords(data, {"time": []}) + + Validate that a data array contains the dimensions 'time' and 'space', + and that the 'space' dimension contains the coordinates 'x' and 'y'. + + >>> validate_dims_coords(data, {"time": [], "space": ["x", "y"]}) + + Raises + ------ + ValueError + If the input data does not contain the required dimension(s) + and/or the required coordinate(s). + + """ + missing_dims = [dim for dim in required_dim_coords if dim not in data.dims] + error_message = "" + if missing_dims: + error_message += ( + f"Input data must contain {missing_dims} as dimensions.\n" + ) + missing_coords = [] + for dim, coords in required_dim_coords.items(): + missing_coords = [ + coord for coord in coords if coord not in data.coords.get(dim, []) + ] + if missing_coords: + error_message += ( + "Input data must contain " + f"{missing_coords} in the '{dim}' coordinates." + ) + if error_message: + raise log_error(ValueError, error_message) diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index 7641aeeb..a1b933e0 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -181,4 +181,4 @@ def test_approximate_derivative_with_invalid_order(order): data = np.arange(10) expected_exception = ValueError if isinstance(order, int) else TypeError with pytest.raises(expected_exception): - kinematics._compute_approximate_time_derivative(data, order=order) + kinematics.compute_time_derivative(data, order=order) diff --git a/tests/test_unit/test_validators/test_array_validators.py b/tests/test_unit/test_validators/test_array_validators.py new file mode 100644 index 00000000..a1a4412c --- /dev/null +++ b/tests/test_unit/test_validators/test_array_validators.py @@ -0,0 +1,56 @@ +import re +from contextlib import nullcontext as does_not_raise + +import pytest + +from movement.validators.arrays import validate_dims_coords + + +def expect_value_error_with_message(error_msg): + """Expect a ValueError with the specified error message.""" + return pytest.raises(ValueError, match=re.escape(error_msg)) + + +valid_cases = [ + ({"time": []}, does_not_raise()), + ({"time": [0, 1]}, does_not_raise()), + ({"space": ["x", "y"]}, does_not_raise()), + ({"time": [], "space": []}, does_not_raise()), + ({"time": [], "space": ["x", "y"]}, does_not_raise()), +] # Valid cases (no error) + +invalid_cases = [ + ( + {"spacetime": []}, + expect_value_error_with_message( + "Input data must contain ['spacetime'] as dimensions." + ), + ), + ( + {"time": [0, 100], "space": ["x", "y"]}, + expect_value_error_with_message( + "Input data must contain [100] in the 'time' coordinates." + ), + ), + ( + {"space": ["x", "y", "z"]}, + expect_value_error_with_message( + "Input data must contain ['z'] in the 'space' coordinates." + ), + ), +] # Invalid cases (raise ValueError) + + +@pytest.mark.parametrize( + "required_dims_coords, expected_exception", + valid_cases + invalid_cases, +) +def test_validate_dims_coords( + valid_poses_dataset_uniform_linear_motion, # fixture from conftest.py + required_dims_coords, + expected_exception, +): + """Test validate_dims_coords for both valid and invalid inputs.""" + position_array = valid_poses_dataset_uniform_linear_motion["position"] + with expected_exception: + validate_dims_coords(position_array, required_dims_coords)