Skip to content

Commit

Permalink
add 3D space coords validation and test
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaprins committed Jan 30, 2025
1 parent fea9230 commit 7d5423d
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 17 deletions.
9 changes: 5 additions & 4 deletions movement/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def scale(
xarray.DataArray.attrs["space_unit"] is overwritten each time or is dropped
if ``None`` is passed by default or explicitly.
When the factor is a scalar (a single number), the scaling factor is
applied to all dimensions, while if the factor is a list or array, the
factor is broadcasted along the first matching dimension.
"""
if data.space.ndim == 2:
validate_dims_coords(data, {"space": ["x", "y"]})
else:
validate_dims_coords(data, {"space": ["x", "y", "z"]})

validate_dims_coords(data, {"space": ["x", "y"]})

if not np.isscalar(factor):
Expand Down
69 changes: 56 additions & 13 deletions tests/test_unit/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from movement.transforms import scale

DEFAULT_SPATIAL_COORDS = {"space": ["x", "y"]}
SPATIAL_COORDS_2D = {"space": ["x", "y"]}
SPATIAL_COORDS_3D = {"space": ["x", "y", "z"]}


def nparray_0_to_23() -> np.ndarray:
Expand All @@ -15,19 +16,29 @@ def nparray_0_to_23() -> np.ndarray:


@pytest.fixture
def sample_data() -> xr.DataArray:
def sample_data_2d() -> xr.DataArray:
"""Turn the nparray_0_to_23 into a DataArray."""
return data_array_with_dims_and_coords(nparray_0_to_23())


@pytest.fixture
def sample_data_3d() -> xr.DataArray:
"""Turn the nparray_0_to_23 into a DataArray with 3D space."""
return data_array_with_dims_and_coords(
nparray_0_to_23().reshape(8, 3),
coords=SPATIAL_COORDS_3D,
)


def data_array_with_dims_and_coords(
data: np.ndarray,
dims: list | tuple = ("time", "space"),
coords: dict[str, list[str]] = DEFAULT_SPATIAL_COORDS,
coords: dict[str, list[str]] = SPATIAL_COORDS_2D,
**attributes: Any,
) -> xr.DataArray:
"""Create a DataArray with given data, dimensions, coordinates, and
attributes (e.g. space_unit or factor).
attributes (e.g. space_unit or factor). The default space coordinates
are x and y (2D).
"""
return xr.DataArray(
data,
Expand Down Expand Up @@ -86,12 +97,12 @@ def data_array_with_dims_and_coords(
],
)
def test_scale(
sample_data: xr.DataArray,
sample_data_2d: xr.DataArray,
optional_arguments: dict[str, Any],
expected_output: xr.DataArray,
):
"""Test scaling with different factors and space_units."""
scaled_data = scale(sample_data, **optional_arguments)
scaled_data = scale(sample_data_2d, **optional_arguments)
xr.testing.assert_equal(scaled_data, expected_output)
assert scaled_data.attrs == expected_output.attrs

Expand Down Expand Up @@ -119,9 +130,7 @@ def test_scale_space_dimension(dims: list[str], data_shape):
"""
factor = [0.5, 2]
numerical_data = np.arange(np.prod(data_shape)).reshape(data_shape)
data = xr.DataArray(
numerical_data, dims=dims, coords=DEFAULT_SPATIAL_COORDS
)
data = xr.DataArray(numerical_data, dims=dims, coords=SPATIAL_COORDS_2D)
scaled_data = scale(data, factor=factor)
broadcast_list = [1 if dim != "space" else len(factor) for dim in dims]
expected_output_data = data * np.array(factor).reshape(broadcast_list)
Expand Down Expand Up @@ -158,7 +167,7 @@ def test_scale_space_dimension(dims: list[str], data_shape):
],
)
def test_scale_twice(
sample_data: xr.DataArray,
sample_data_2d: xr.DataArray,
optional_arguments_1: dict[str, Any],
optional_arguments_2: dict[str, Any],
expected_output: xr.DataArray,
Expand All @@ -168,7 +177,7 @@ def test_scale_twice(
provided, or remove it if None is passed explicitly or by default.
"""
output_data_array = scale(
scale(sample_data, **optional_arguments_1),
scale(sample_data_2d, **optional_arguments_1),
**optional_arguments_2,
)
xr.testing.assert_equal(output_data_array, expected_output)
Expand All @@ -193,11 +202,45 @@ def test_scale_twice(
],
)
def test_scale_value_error(
sample_data: xr.DataArray,
sample_data_2d: xr.DataArray,
invalid_factor: np.ndarray,
expected_error_message: str,
):
"""Test invalid factors raise correct error type and message."""
with pytest.raises(ValueError) as error:
scale(sample_data, factor=invalid_factor)
scale(sample_data_2d, factor=invalid_factor)
assert str(error.value) == expected_error_message


@pytest.mark.parametrize(
"factor",
[2, [1, 2, 0.5]],
ids=["uniform scaling", "multi-axis scaling"],
)
def test_scale_3d_space(factor, sample_data_3d: xr.DataArray):
"""Test scaling a DataArray with 3D space."""
scaled_data = scale(sample_data_3d, factor=factor)
expected_output = data_array_with_dims_and_coords(
nparray_0_to_23().reshape(8, 3) * np.array(factor).reshape(1, -1),
coords=SPATIAL_COORDS_3D,
)
xr.testing.assert_equal(scaled_data, expected_output)


@pytest.mark.parametrize(
"factor",
[2, [1, 2, 0.5]],
ids=["uniform scaling", "multi-axis scaling"],
)
def test_scale_invalid_3d_space(factor):
"""Test scaling a DataArray with 3D space."""
invalid_coords = {"space": ["x", "flubble", "y"]} # "z" is missing
invalid_sample_data_3d = data_array_with_dims_and_coords(
nparray_0_to_23().reshape(8, 3),
coords=invalid_coords,
)
with pytest.raises(ValueError) as error:
scale(invalid_sample_data_3d, factor=factor)
assert str(error.value) == (
"Input data must contain ['z'] in the 'space' coordinates."
)

0 comments on commit 7d5423d

Please sign in to comment.