Skip to content

Commit

Permalink
replace time dim validator with more generic validator
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Aug 29, 2024
1 parent f47d6e7 commit d33ac25
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 70 deletions.
25 changes: 3 additions & 22 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import xarray as xr

from movement.utils.logging import log_error
from movement.validators.arrays import validate_dimension_coordinates


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Expand All @@ -26,7 +27,7 @@ def compute_displacement(data: xr.DataArray) -> xr.DataArray:
An xarray DataArray containing the computed displacement.
"""
_validate_time_dimension(data)
validate_dimension_coordinates(data, {"time": []})
result = data.diff(dim="time")
result = result.reindex(data.coords, fill_value=0)
return result
Expand Down Expand Up @@ -113,28 +114,8 @@ def _compute_approximate_time_derivative(
)
if order <= 0:
raise log_error(ValueError, "Order must be a positive integer.")
_validate_time_dimension(data)
validate_dimension_coordinates(data, {"time": []})
result = data
for _ in range(order):
result = result.differentiate("time")
return result


def _validate_time_dimension(data: xr.DataArray) -> None:
"""Validate the input data contains a ``time`` dimension.
Parameters
----------
data : xarray.DataArray
The input data to validate.
Raises
------
ValueError
If the input data does not contain a ``time`` dimension.
"""
if "time" not in data.dims:
raise log_error(
ValueError, "Input data must contain 'time' as a dimension."
)
55 changes: 7 additions & 48 deletions movement/utils/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import xarray as xr

from movement.utils.logging import log_error
from movement.validators.arrays import validate_dimension_coordinates


def compute_norm(data: xr.DataArray) -> xr.DataArray:
Expand Down Expand Up @@ -39,15 +40,15 @@ def compute_norm(data: xr.DataArray) -> xr.DataArray:
"""
if "space" in data.dims:
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dimension_coordinates(data, {"space": ["x", "y"]})
return xr.apply_ufunc(
np.linalg.norm,
data,
input_core_dims=[["space"]],
kwargs={"axis": -1},
)
elif "space_pol" in data.dims:
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
return data.sel(space_pol="rho", drop=True)
else:
_raise_error_for_missing_spatial_dim()
Expand Down Expand Up @@ -78,10 +79,10 @@ def convert_to_unit(data: xr.DataArray) -> xr.DataArray:
"""
if "space" in data.dims:
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dimension_coordinates(data, {"space": ["x", "y"]})
return data / compute_norm(data)
elif "space_pol" in data.dims:
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
# Set both rho and phi values to NaN at null vectors (where rho = 0)
new_data = xr.where(data.sel(space_pol="rho") == 0, np.nan, data)
# Set the rho values to 1 for non-null vectors (phi is preserved)
Expand Down Expand Up @@ -111,7 +112,7 @@ def cart2pol(data: xr.DataArray) -> xr.DataArray:
``phi`` returned are in radians, in the range ``[-pi, pi]``.
"""
_validate_dimension_coordinates(data, {"space": ["x", "y"]})
validate_dimension_coordinates(data, {"space": ["x", "y"]})
rho = compute_norm(data)
phi = xr.apply_ufunc(
np.arctan2,
Expand Down Expand Up @@ -147,7 +148,7 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray:
in the dimension coordinate.
"""
_validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
validate_dimension_coordinates(data, {"space_pol": ["rho", "phi"]})
rho = data.sel(space_pol="rho")
phi = data.sel(space_pol="phi")
x = rho * np.cos(phi)
Expand All @@ -164,48 +165,6 @@ def pol2cart(data: xr.DataArray) -> xr.DataArray:
).transpose(*dims)


def _validate_dimension_coordinates(
data: xr.DataArray, required_dim_coords: dict
) -> None:
"""Validate the input data array.
Ensure that it contains the required dimensions and coordinates.
Parameters
----------
data : xarray.DataArray
The input data to validate.
required_dim_coords : dict
A dictionary of required dimensions and their corresponding
coordinate values.
Raises
------
ValueError
If the input data does not contain the required dimension(s)
and/or the required coordinate(s).
"""
missing_dims = [dim for dim in required_dim_coords if dim not in data.dims]
error_message = ""
if missing_dims:
error_message += (
f"Input data must contain {missing_dims} as dimensions.\n"
)
missing_coords = []
for dim, coords in required_dim_coords.items():
missing_coords = [
coord for coord in coords if coord not in data.coords.get(dim, [])
]
if missing_coords:
error_message += (
"Input data must contain "
f"{missing_coords} in the '{dim}' coordinates."
)
if error_message:
raise log_error(ValueError, error_message)


def _raise_error_for_missing_spatial_dim() -> None:
raise log_error(
ValueError,
Expand Down
49 changes: 49 additions & 0 deletions movement/validators/arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Validators for data arrays."""

import xarray as xr

from movement.utils.logging import log_error


def validate_dimension_coordinates(
data: xr.DataArray, required_dim_coords: dict
) -> None:
"""Validate dimensions and coordinates in a data array.
This function raises a ValueError if the specified dimensions and
coordinates are not present in the input data array..
Parameters
----------
data : xarray.DataArray
The input data array to validate.
required_dim_coords : dict
A dictionary of required dimensions and their corresponding
coordinate values. If you don't need to specify any
coordinate values, you can pass an empty list.
Raises
------
ValueError
If the input data does not contain the required dimension(s)
and/or the required coordinate(s).
"""
missing_dims = [dim for dim in required_dim_coords if dim not in data.dims]
error_message = ""
if missing_dims:
error_message += (
f"Input data must contain {missing_dims} as dimensions.\n"
)
missing_coords = []
for dim, coords in required_dim_coords.items():
missing_coords = [
coord for coord in coords if coord not in data.coords.get(dim, [])
]
if missing_coords:
error_message += (
"Input data must contain "
f"{missing_coords} in the '{dim}' coordinates."
)
if error_message:
raise log_error(ValueError, error_message)

0 comments on commit d33ac25

Please sign in to comment.