Skip to content

Commit

Permalink
Tweaked compute_forward_vector() to use new validator
Browse files Browse the repository at this point in the history
  • Loading branch information
b-peri committed Sep 24, 2024
1 parent 054d50c commit df6bed9
Showing 1 changed file with 9 additions and 44 deletions.
53 changes: 9 additions & 44 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import xarray as xr

from movement.utils.logging import log_error
from movement.validators.arrays import validate_dims_coords
from movement.utils.vector import compute_norm
from movement.validators.arrays import validate_dims_coords


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Expand Down Expand Up @@ -230,7 +230,14 @@ def compute_forward_vector(
"""
# Validate input data
_validate_type_data_array(data)
_validate_time_keypoints_space_dimensions(data)
validate_dims_coords(
data,
{
"time": [],
"keypoints": [left_keypoint, right_keypoint],
"space": [],
},
)
if len(data.space) != 2:
raise log_error(
ValueError,
Expand Down Expand Up @@ -315,26 +322,6 @@ def compute_head_direction_vector(
)


def _validate_time_dimension(data: xr.DataArray) -> None:
"""Validate the input data contains a ``time`` dimension.
Parameters
----------
data : xarray.DataArray
The input data to validate.
Raises
------
ValueError
If the input data does not contain a ``time`` dimension.
"""
if "time" not in data.dims:
raise log_error(
ValueError, "Input data must contain 'time' as a dimension."
)


def _validate_type_data_array(data: xr.DataArray) -> None:
"""Validate the input data is an xarray DataArray.
Expand All @@ -354,25 +341,3 @@ def _validate_type_data_array(data: xr.DataArray) -> None:
TypeError,
f"Input data must be an xarray.DataArray, but got {type(data)}.",
)


def _validate_time_keypoints_space_dimensions(data: xr.DataArray) -> None:
"""Validate if input data contains ``time``, ``keypoints`` and ``space``.
Parameters
----------
data : xarray.DataArray
The input data to validate.
Raises
------
ValueError
If the input data is not an xarray DataArray.
"""
if not all(coord in data.dims for coord in ["time", "keypoints", "space"]):
raise log_error(
AttributeError,
"Input data must contain 'time', 'space', and 'keypoints' as "
"dimensions.",
)

0 comments on commit df6bed9

Please sign in to comment.