From 7d5423dcb4f76e4f36fcf8220f4c0cbdebbb71b5 Mon Sep 17 00:00:00 2001 From: Stella <30465823+stellaprins@users.noreply.github.com> Date: Thu, 30 Jan 2025 09:54:06 +0000 Subject: [PATCH] add 3D space coords validation and test --- movement/transforms.py | 9 ++-- tests/test_unit/test_transforms.py | 69 ++++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/movement/transforms.py b/movement/transforms.py index 61d11afe..62435368 100644 --- a/movement/transforms.py +++ b/movement/transforms.py @@ -41,11 +41,12 @@ def scale( xarray.DataArray.attrs["space_unit"] is overwritten each time or is dropped if ``None`` is passed by default or explicitly. - When the factor is a scalar (a single number), the scaling factor is - applied to all dimensions, while if the factor is a list or array, the - factor is broadcasted along the first matching dimension. - """ + if data.space.ndim == 2: + validate_dims_coords(data, {"space": ["x", "y"]}) + else: + validate_dims_coords(data, {"space": ["x", "y", "z"]}) + validate_dims_coords(data, {"space": ["x", "y"]}) if not np.isscalar(factor): diff --git a/tests/test_unit/test_transforms.py b/tests/test_unit/test_transforms.py index 89cb0cf3..d2182352 100644 --- a/tests/test_unit/test_transforms.py +++ b/tests/test_unit/test_transforms.py @@ -6,7 +6,8 @@ from movement.transforms import scale -DEFAULT_SPATIAL_COORDS = {"space": ["x", "y"]} +SPATIAL_COORDS_2D = {"space": ["x", "y"]} +SPATIAL_COORDS_3D = {"space": ["x", "y", "z"]} def nparray_0_to_23() -> np.ndarray: @@ -15,19 +16,29 @@ def nparray_0_to_23() -> np.ndarray: @pytest.fixture -def sample_data() -> xr.DataArray: +def sample_data_2d() -> xr.DataArray: """Turn the nparray_0_to_23 into a DataArray.""" return data_array_with_dims_and_coords(nparray_0_to_23()) +@pytest.fixture +def sample_data_3d() -> xr.DataArray: + """Turn the nparray_0_to_23 into a DataArray with 3D space.""" + return data_array_with_dims_and_coords( + nparray_0_to_23().reshape(8, 3), + coords=SPATIAL_COORDS_3D, + ) + + def data_array_with_dims_and_coords( data: np.ndarray, dims: list | tuple = ("time", "space"), - coords: dict[str, list[str]] = DEFAULT_SPATIAL_COORDS, + coords: dict[str, list[str]] = SPATIAL_COORDS_2D, **attributes: Any, ) -> xr.DataArray: """Create a DataArray with given data, dimensions, coordinates, and - attributes (e.g. space_unit or factor). + attributes (e.g. space_unit or factor). The default space coordinates + are x and y (2D). """ return xr.DataArray( data, @@ -86,12 +97,12 @@ def data_array_with_dims_and_coords( ], ) def test_scale( - sample_data: xr.DataArray, + sample_data_2d: xr.DataArray, optional_arguments: dict[str, Any], expected_output: xr.DataArray, ): """Test scaling with different factors and space_units.""" - scaled_data = scale(sample_data, **optional_arguments) + scaled_data = scale(sample_data_2d, **optional_arguments) xr.testing.assert_equal(scaled_data, expected_output) assert scaled_data.attrs == expected_output.attrs @@ -119,9 +130,7 @@ def test_scale_space_dimension(dims: list[str], data_shape): """ factor = [0.5, 2] numerical_data = np.arange(np.prod(data_shape)).reshape(data_shape) - data = xr.DataArray( - numerical_data, dims=dims, coords=DEFAULT_SPATIAL_COORDS - ) + data = xr.DataArray(numerical_data, dims=dims, coords=SPATIAL_COORDS_2D) scaled_data = scale(data, factor=factor) broadcast_list = [1 if dim != "space" else len(factor) for dim in dims] expected_output_data = data * np.array(factor).reshape(broadcast_list) @@ -158,7 +167,7 @@ def test_scale_space_dimension(dims: list[str], data_shape): ], ) def test_scale_twice( - sample_data: xr.DataArray, + sample_data_2d: xr.DataArray, optional_arguments_1: dict[str, Any], optional_arguments_2: dict[str, Any], expected_output: xr.DataArray, @@ -168,7 +177,7 @@ def test_scale_twice( provided, or remove it if None is passed explicitly or by default. """ output_data_array = scale( - scale(sample_data, **optional_arguments_1), + scale(sample_data_2d, **optional_arguments_1), **optional_arguments_2, ) xr.testing.assert_equal(output_data_array, expected_output) @@ -193,11 +202,45 @@ def test_scale_twice( ], ) def test_scale_value_error( - sample_data: xr.DataArray, + sample_data_2d: xr.DataArray, invalid_factor: np.ndarray, expected_error_message: str, ): """Test invalid factors raise correct error type and message.""" with pytest.raises(ValueError) as error: - scale(sample_data, factor=invalid_factor) + scale(sample_data_2d, factor=invalid_factor) assert str(error.value) == expected_error_message + + +@pytest.mark.parametrize( + "factor", + [2, [1, 2, 0.5]], + ids=["uniform scaling", "multi-axis scaling"], +) +def test_scale_3d_space(factor, sample_data_3d: xr.DataArray): + """Test scaling a DataArray with 3D space.""" + scaled_data = scale(sample_data_3d, factor=factor) + expected_output = data_array_with_dims_and_coords( + nparray_0_to_23().reshape(8, 3) * np.array(factor).reshape(1, -1), + coords=SPATIAL_COORDS_3D, + ) + xr.testing.assert_equal(scaled_data, expected_output) + + +@pytest.mark.parametrize( + "factor", + [2, [1, 2, 0.5]], + ids=["uniform scaling", "multi-axis scaling"], +) +def test_scale_invalid_3d_space(factor): + """Test scaling a DataArray with 3D space.""" + invalid_coords = {"space": ["x", "flubble", "y"]} # "z" is missing + invalid_sample_data_3d = data_array_with_dims_and_coords( + nparray_0_to_23().reshape(8, 3), + coords=invalid_coords, + ) + with pytest.raises(ValueError) as error: + scale(invalid_sample_data_3d, factor=factor) + assert str(error.value) == ( + "Input data must contain ['z'] in the 'space' coordinates." + )