Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transforms module with scale function #384

Merged
Merged
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
parametrize and refactor test_scale_space_dimension
stellaprins committed Jan 27, 2025
commit 362196fcdd66554d8648770876bebdc05c34d3aa
51 changes: 16 additions & 35 deletions tests/test_unit/test_transforms.py
Original file line number Diff line number Diff line change
@@ -96,49 +96,30 @@ def test_scale(
assert scaled_data.attrs == expected_output.attrs


def test_scale_space_dimension_two_dims():
@pytest.mark.parametrize(
"dims, data_shape",
[
(["time", "space"], (3, 2)),
(["space", "time"], (2, 3)),
(["time", "individuals", "keypoints", "space"], (3, 6, 4, 2)),
],
ids=["time-space", "space-time", "time-individuals-keypoints-space"],
)
def test_scale_space_dimension(dims: list[str], data_shape):
"""Test scaling with transposed data along the correct dimension.

The scaling factor should be broadcasted along the space axis irrespective
of the order of the dimensions in the input data.
"""
factor = [0.5, 2]

data_space_second = data_array_with_dims_and_coords(
nparray_0_to_23(), dims=["time", "space"]
)
data_space_first = data_array_with_dims_and_coords(
nparray_0_to_23().transpose(), dims=["space", "time"]
)

scaled_data_space_second = scale(data_space_second, factor=factor)
scaled_data_space_first = scale(data_space_first, factor=factor)

xr.testing.assert_equal(
scaled_data_space_second, scaled_data_space_first.transpose()
)


def test_scale_space_dimension_four_dims():
"""Test scaling with data having four dimensions.

The scaling factor should be broadcasted along the space axis irrespective
of the order of the dimensions in the input data.
"""
factor = [0.5, 2]
data_shape = (3, 6, 4, 2)
numerical_data = np.arange(np.prod(data_shape)).reshape(data_shape)
data_space_fourth = xr.DataArray(
numerical_data, dims=["time", "individuals", "keypoints", "space"]
)
scaled_data_space_fourth = scale(data_space_fourth, factor=factor)
data = xr.DataArray(numerical_data, dims=dims)
scaled_data = scale(data, factor=factor)
broadcast_list = [1 if dim != "space" else len(factor) for dim in dims]
expected_output_data = data * np.array(factor).reshape(broadcast_list)

assert scaled_data_space_fourth.shape == data_space_fourth.shape

expected_output_data = data_space_fourth * np.array(factor).reshape(
1, 1, 1, 2
)
xr.testing.assert_equal(scaled_data_space_fourth, expected_output_data)
assert scaled_data.shape == data.shape
xr.testing.assert_equal(scaled_data, expected_output_data)


@pytest.mark.parametrize(