Skip to content

Commit

Permalink
Combine tests for both cart pol transforms for invalid data arrays an…
Browse files Browse the repository at this point in the history
…d for data arrays with nans
  • Loading branch information
sfmig committed Aug 8, 2024
1 parent db757b8 commit 67ceaec
Showing 1 changed file with 102 additions and 44 deletions.
146 changes: 102 additions & 44 deletions tests/test_unit/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,30 +104,42 @@ def trajectory_random_pol(trajectory_random_cart):


# ---- with nan values ----
@pytest.fixture(
params=[
"trajectory_x_eq_0_cart",
"trajectory_x_eq_y_cart",
"trajectory_random_cart",
]
)
def trajectories_cart_with_nan(request):
trajectory_data_array = request.getfixturevalue(request.param)
trajectory_data_array.loc[{"time": slice(2, 3)}] = np.nan
return trajectory_data_array


@pytest.fixture(
params=[
"trajectory_x_eq_0_pol",
"trajectory_x_eq_y_pol",
"trajectory_random_pol",
]
)
def trajectories_pol_with_nan(request):
trajectory_data_array = request.getfixturevalue(request.param)
trajectory_data_array.loc[{"time": slice(2, 3)}] = np.nan
return trajectory_data_array
@pytest.fixture
def trajectory_random_cart_with_nan(trajectory_random_cart):
trajectory_random_cart.loc[{"time": slice(2, 3)}] = np.nan
return trajectory_random_cart


@pytest.fixture
def trajectory_random_pol_with_nan(trajectory_random_pol):
trajectory_random_pol.loc[{"time": slice(2, 3)}] = np.nan
return trajectory_random_pol


# @pytest.fixture(
# params=[
# "trajectory_x_eq_0_cart",
# "trajectory_x_eq_y_cart",
# "trajectory_random_cart",
# ]
# )
# def trajectories_cart_with_nan(request):
# trajectory_data_array = request.getfixturevalue(request.param)
# trajectory_data_array.loc[{"time": slice(2, 3)}] = np.nan
# return trajectory_data_array


# @pytest.fixture(
# params=[
# "trajectory_x_eq_0_pol",
# "trajectory_x_eq_y_pol",
# "trajectory_random_pol",
# ]
# )
# def trajectories_pol_with_nan(request):
# trajectory_data_array = request.getfixturevalue(request.param)
# trajectory_data_array.loc[{"time": slice(2, 3)}] = np.nan
# return trajectory_data_array


# ---- invalid data arrays ----
Expand All @@ -148,6 +160,23 @@ def trajectory_cart_with_missing_space_coord(trajectory_random_cart):
return trajectory_random_cart


@pytest.fixture
def trajectory_pol_with_missing_space_dim(trajectory_random_pol):
"""Return an xarray.Dataset with Cartesian and polar coordinates,
where the required ``space`` dimension is missing.
"""
return trajectory_random_pol.rename({"space_pol": "spice"})


@pytest.fixture
def trajectory_pol_with_missing_space_coord(trajectory_random_pol):
"""Return an xarray.DataArray where the required ``space["x"]`` and
``space["y"]`` coordinates are missing.
"""
trajectory_random_pol["space_pol"] = ["a", "b"]
return trajectory_random_pol


@pytest.mark.parametrize(
"position_data_array_cart, expected_rho, expected_phi",
[
Expand Down Expand Up @@ -195,36 +224,65 @@ def test_cart2pol(
)


def test_cart2pol_with_nan(trajectories_cart_with_nan):
@pytest.mark.parametrize(
"position_array_with_nan, cart_polar_transform_fn",
[
("trajectory_random_cart_with_nan", vector.cart2pol),
("trajectory_random_pol_with_nan", vector.pol2cart),
],
)
def test_cart_pol_transforms_with_nans(
position_array_with_nan, cart_polar_transform_fn, request
):
"""Test Cartesian to polar coordinates with NaN values."""
position_array_pol = vector.cart2pol(trajectories_cart_with_nan)

n_expected_nans_per_coord = 2
input_array_with_nan = request.getfixturevalue(position_array_with_nan)
output_array_with_nan = cart_polar_transform_fn(input_array_with_nan)

# Check that NaN values are preserved
assert (
sum(np.isnan(position_array_pol.sel(space_pol="rho")))
== n_expected_nans_per_coord
)
assert (
sum(np.isnan(position_array_pol.sel(space_pol="phi")))
== n_expected_nans_per_coord
np.isnan(output_array_with_nan).sum()
== np.isnan(input_array_with_nan).sum()
)
# Check equal to 2 timepoints?


@pytest.mark.parametrize(
"invalid_position_array",
"invalid_position_array, vector_fn, expected_error_msg",
[
"trajectory_cart_with_missing_space_dim",
"trajectory_cart_with_missing_space_coord",
(
"trajectory_cart_with_missing_space_dim",
vector.cart2pol,
"Input data must contain ['x', 'y'] in the 'space' coordinates.",
),
(
"trajectory_cart_with_missing_space_coord",
vector.cart2pol,
"Input data must contain ['x', 'y'] in the 'space' coordinates.",
),
(
"trajectory_pol_with_missing_space_dim",
vector.pol2cart,
(
"Input data must contain ['rho', 'phi'] in the 'space_pol' "
"coordinates."
),
),
(
"trajectory_pol_with_missing_space_coord",
vector.pol2cart,
(
"Input data must contain ['rho', 'phi'] in the 'space_pol' "
"coordinates."
),
),
],
)
def test_cart2pol_invalid_array(invalid_position_array, request):
"""Test Cartesian to polar coordinates with invalid input."""
def test_cart_pol_transforms_invalid_array(
invalid_position_array, vector_fn, expected_error_msg, request
):
"""Test vector utils with invalid input."""
invalid_position_array = request.getfixturevalue(invalid_position_array)
with pytest.raises(ValueError) as excinfo:
vector.cart2pol(request.getfixturevalue(invalid_position_array))
vector_fn(invalid_position_array)

assert (
"Input data must contain ['x', 'y'] in the 'space' coordinates."
in str(excinfo.value)
)
assert expected_error_msg in str(excinfo.value)

0 comments on commit 67ceaec

Please sign in to comment.