From d33ac25db117bcf46f6a852f71b0ef324a9f76e7 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 26 Aug 2024 11:37:50 +0100 Subject: [PATCH] replace time dim validator with more generic validator --- movement/analysis/kinematics.py | 25 ++------------- movement/utils/vector.py | 55 +++++---------------------------- movement/validators/arrays.py | 49 +++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 70 deletions(-) create mode 100644 movement/validators/arrays.py diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py index 15375a534..c9538787b 100644 --- a/movement/analysis/kinematics.py +++ b/movement/analysis/kinematics.py @@ -3,6 +3,7 @@ import xarray as xr from movement.utils.logging import log_error +from movement.validators.arrays import validate_dimension_coordinates def compute_displacement(data: xr.DataArray) -> xr.DataArray: @@ -26,7 +27,7 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray: An xarray DataArray containing the computed displacement. """ - _validate_time_dimension(data) + validate_dimension_coordinates(data, {"time": []}) result = data.diff(dim="time") result = result.reindex(data.coords, fill_value=0) return result @@ -113,28 +114,8 @@ def _compute_approximate_time_derivative( ) if order <= 0: raise log_error(ValueError, "Order must be a positive integer.") - _validate_time_dimension(data) + validate_dimension_coordinates(data, {"time": []}) result = data for _ in range(order): result = result.differentiate("time") return result - - -def _validate_time_dimension(data: xr.DataArray) -> None: - """Validate the input data contains a ``time`` dimension. - - Parameters - ---------- - data : xarray.DataArray - The input data to validate. - - Raises - ------ - ValueError - If the input data does not contain a ``time`` dimension. - - """ - if "time" not in data.dims: - raise log_error( - ValueError, "Input data must contain 'time' as a dimension." 
- ) diff --git a/movement/utils/vector.py b/movement/utils/vector.py index 0d5d88c83..dfe29c5d7 100644 --- a/movement/utils/vector.py +++ b/movement/utils/vector.py @@ -4,6 +4,7 @@ import xarray as xr from movement.utils.logging import log_error +from movement.validators.arrays import validate_dimension_coordinates def compute_norm(data: xr.DataArray) -> xr.DataArray: @@ -39,7 +40,7 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray: """ if "space" in data.dims: - _validate_dimension_coordinates(data, {"space": ["x", "y"]}) + validate_dimension_coordinates(data, {"space": ["x", "y"]}) return xr.apply_ufunc( np.linalg.norm, data, @@ -47,7 +48,7 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray: kwargs={"axis": -1}, ) elif "space_pol" in data.dims: - _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) return data.sel(space_pol="rho", drop=True) else: _raise_error_for_missing_spatial_dim() @@ -78,10 +79,10 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray: """ if "space" in data.dims: - _validate_dimension_coordinates(data, {"space": ["x", "y"]}) + validate_dimension_coordinates(data, {"space": ["x", "y"]}) return data / compute_norm(data) elif "space_pol" in data.dims: - _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) # Set both rho and phi values to NaN at null vectors (where rho = 0) new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data) # Set the rho values to 1 for non-null vectors (phi is preserved) @@ -111,7 +112,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray: ``phi`` returned are in radians, in the range ``[-pi, pi]``. 
""" - _validate_dimension_coordinates(data, {"space": ["x", "y"]}) + validate_dimension_coordinates(data, {"space": ["x", "y"]}) rho = compute_norm(data) phi = xr.apply_ufunc( np.arctan2, @@ -147,7 +148,7 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray: in the dimension coordinate. """ - _validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) + validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]}) rho = data.sel(space_pol="rho") phi = data.sel(space_pol="phi") x = rho * np.cos(phi) @@ -164,48 +165,6 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray: ).transpose(*dims) -def _validate_dimension_coordinates( - data: xr.DataArray, required_dim_coords: dict -) -> None: - """Validate the input data array. - - Ensure that it contains the required dimensions and coordinates. - - Parameters - ---------- - data : xarray.DataArray - The input data to validate. - required_dim_coords : dict - A dictionary of required dimensions and their corresponding - coordinate values. - - Raises - ------ - ValueError - If the input data does not contain the required dimension(s) - and/or the required coordinate(s). - - """ - missing_dims = [dim for dim in required_dim_coords if dim not in data.dims] - error_message = "" - if missing_dims: - error_message += ( - f"Input data must contain {missing_dims} as dimensions.\n" - ) - missing_coords = [] - for dim, coords in required_dim_coords.items(): - missing_coords = [ - coord for coord in coords if coord not in data.coords.get(dim, []) - ] - if missing_coords: - error_message += ( - "Input data must contain " - f"{missing_coords} in the '{dim}' coordinates." 
- ) - if error_message: - raise log_error(ValueError, error_message) - - def _raise_error_for_missing_spatial_dim() -> None: raise log_error( ValueError, diff --git a/movement/validators/arrays.py b/movement/validators/arrays.py new file mode 100644 index 000000000..260480408 --- /dev/null +++ b/movement/validators/arrays.py @@ -0,0 +1,49 @@ +"""Validators for data arrays.""" + +import xarray as xr + +from movement.utils.logging import log_error + + +def validate_dimension_coordinates( + data: xr.DataArray, required_dim_coords: dict +) -> None: + """Validate dimensions and coordinates in a data array. + + This function raises a ValueError if the specified dimensions and + coordinates are not present in the input data array. + + Parameters + ---------- + data : xarray.DataArray + The input data array to validate. + required_dim_coords : dict + A dictionary of required dimensions and their corresponding + coordinate values. If you don't need to specify any + coordinate values, you can pass an empty list. + + Raises + ------ + ValueError + If the input data does not contain the required dimension(s) + and/or the required coordinate(s). + + """ + missing_dims = [dim for dim in required_dim_coords if dim not in data.dims] + error_message = "" + if missing_dims: + error_message += ( + f"Input data must contain {missing_dims} as dimensions.\n" + ) + missing_coords = [] + for dim, coords in required_dim_coords.items(): + missing_coords = [ + coord for coord in coords if coord not in data.coords.get(dim, []) + ] + if missing_coords: + error_message += ( + "Input data must contain " + f"{missing_coords} in the '{dim}' coordinates." + ) + if error_message: + raise log_error(ValueError, error_message)