Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transforms module with scale function #384

Merged
Merged
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add docstrings to transforms module and tests
stellaprins committed Jan 22, 2025
commit 1f6c4529f126c6a185ae7f2094a5eb33dcf53278
59 changes: 43 additions & 16 deletions movement/transforms.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,68 @@
"""Transforms module."""
"""Transform and add unit attributes to xarray.DataArray datasets."""

import numpy as np
import xarray as xr


def scale(
data_array: xr.DataArray,
data: xr.DataArray,
factor: float | np.ndarray[float] = 1.0,
unit: str | None = None,
) -> xr.DataArray:
"""Scale data by a given factor with an optional unit."""
"""Scale data by a given factor with an optional unit.

Parameters
----------
data : xarray.DataArray
The input data to be scaled.
factor : float or np.ndarray of floats
The scaling factor to apply to the data. If factor is a scalar, all
dimensions of the data array are scaled by the same factor. If factor
is a list or an 1D array, the length of the array must match the length
of one of the data array's dimensions. The factor is broadcast
along the first matching dimension.
unit : str or None
The unit of the scaled data stored as a property in
xarray.DataArray.attrs['unit']. In case of the default (``None``) the
``unit`` attribute is dropped.

Returns
-------
xarray.DataArray
The scaled data array.

Notes
-----
When scale is used multiple times on the same xarray.DataArray,
xarray.DataArray.attrs["unit"] is overwritten each time or is dropped if
``None`` is passed by default or explicitly.

When the factor is a scalar (a single number), the scaling factor is
applied to all dimensions, while if the factor is a list or array, the
factor is broadcasted along the first matching dimension.

"""
if not np.isscalar(factor):
factor = np.array(factor).squeeze()
if factor.ndim != 1:
raise ValueError(
f"Factor must be a scalar or a 1D array, got {factor.ndim}D"
)
elif factor.shape[0] not in data_array.shape:
elif factor.shape[0] not in data.shape:
raise ValueError(
f"Factor shape {factor.shape} does not match "
f"the length of any data axes: {data_array.shape}"
f"the length of any data axes: {data.shape}"
)
else:
# To figure out which dimension to broadcast along.
# Find dimensions with as many values as we have factors.
matching_dims = np.array(data_array.shape) == factor.shape[0]
# Find first dimension that matches.
matching_dims = np.array(data.shape) == factor.shape[0]
first_matching_dim = np.argmax(matching_dims).item()
# Reshape factor to broadcast along the matching dimension.
factor_dims = [1] * data_array.ndim
factor_dims = [1] * data.ndim
factor_dims[first_matching_dim] = factor.shape[0]
# Reshape factor for broadcasting.
factor = factor.reshape(factor_dims)
scaled_data_array = data_array * factor
scaled_data = data * factor

if unit is not None:
scaled_data_array.attrs["unit"] = unit
scaled_data.attrs["unit"] = unit
elif unit is None:
scaled_data_array.attrs.pop("unit", None)
return scaled_data_array
scaled_data.attrs.pop("unit", None)
return scaled_data
41 changes: 26 additions & 15 deletions tests/test_unit/test_transforms.py
Original file line number Diff line number Diff line change
@@ -8,11 +8,12 @@


def nparray_0_to_23() -> np.ndarray:
"""Create a 2D nparray from 0 to 23."""
return np.arange(0, 24).reshape(12, 2)


@pytest.fixture
def sample_data_array() -> xr.DataArray:
def sample_data() -> xr.DataArray:
"""Turn the nparray_0_to_23 into a DataArray."""
return data_array_with_dims_and_coords(nparray_0_to_23())

@@ -23,7 +24,9 @@ def data_array_with_dims_and_coords(
coords: dict[str, list[str]] = {"space": ["x", "y"]},
**attributes: Any,
) -> xr.DataArray:
""""""
"""Create a DataArray with given data, dimensions, coordinates, and
attributes (e.g. unit or factor).
"""
return xr.DataArray(
data,
dims=dims,
@@ -81,22 +84,22 @@ def data_array_with_dims_and_coords(
],
)
def test_scale(
sample_data_array: xr.DataArray,
sample_data: xr.DataArray,
optional_arguments: dict[str, Any],
expected_output: xr.DataArray,
):
expected_output = xr.DataArray(
expected_output,
dims=["time", "space"],
coords={"space": ["x", "y"]},
)
"""Test scaling with different factors and units."""
scaled_data = scale(sample_data, **optional_arguments)
xr.testing.assert_equal(scaled_data, expected_output)
assert scaled_data.attrs == expected_output.attrs

output_data_array = scale(sample_data_array, **optional_arguments)
xr.testing.assert_equal(output_data_array, expected_output)
assert output_data_array.attrs == expected_output.attrs

def test_scale_inverted_data():
"""Test scaling with transposed data along the correct dimension.

def test_scale_inverted_data() -> None:
The factor is reshaped to (1, 1, 4, 1) so that it can be broadcasted along
the third dimension ("y") which matches the length of the scaling factor.
"""
factor = [0.5, 2]
transposed_data = data_array_with_dims_and_coords(
nparray_0_to_23().transpose(), dims=["space", "time"]
@@ -117,6 +120,11 @@ def test_scale_inverted_data() -> None:
output_array, input_data * np.array(factor).reshape(1, 1, 4, 1)
)


def test_scale_first_matching_axis():
"""Test scaling when multiple axes match the scaling factor's length.
The scaling factor should be broadcasted along the first matching axis.
"""
factor = [0.5, 1]
data_shape = (2, 2)
numerical_data = np.arange(np.prod(data_shape)).reshape(data_shape)
@@ -125,7 +133,6 @@ def test_scale_inverted_data() -> None:
assert output_array.shape == input_data.shape
assert np.isclose(input_data.values[0] * 0.5, output_array.values[0]).all()
assert np.isclose(input_data.values[1], output_array.values[1]).all()
pass


@pytest.mark.parametrize(
@@ -154,13 +161,17 @@ def test_scale_inverted_data() -> None:
],
)
def test_scale_twice(
sample_data_array: xr.DataArray,
sample_data: xr.DataArray,
optional_arguments_1: dict[str, Any],
optional_arguments_2: dict[str, Any],
expected_output: xr.DataArray,
):
"""Test scaling when applied twice.
The second scaling operation should update the unit attribute if provided,
or remove it if None is passed explicitly or by default.
"""
output_data_array = scale(
scale(sample_data_array, **optional_arguments_1),
scale(sample_data, **optional_arguments_1),
**optional_arguments_2,
)