Skip to content

Commit

Permalink
ENH: Test absent mask warning
Browse files Browse the repository at this point in the history
Test absent mask warning: parameterize the trivial model test so that
the warning raised for the case where the model is instatiated without a
mask can be tested.

Remove the corresponding message from the `filterwarnings` list.
  • Loading branch information
jhlegarreta committed Jan 18, 2025
1 parent a78af6c commit 6cbc7e7
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ filterwarnings = [
"ignore:Updating b0_threshold to.*:UserWarning",
# scikit-learn
"ignore:The optimal value found for dimension.*:sklearn.exceptions.ConvergenceWarning",
# masks
"ignore:No mask provided;.*:UserWarning",
]


Expand Down
6 changes: 5 additions & 1 deletion src/nifreeze/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@

from nifreeze.exceptions import ModelNotFittedError

mask_absence_warn_msg = (
"No mask provided; consider using a mask to avoid issues in model optimization."
)


class ModelFactory:
"""A factory for instantiating diffusion models."""
Expand Down Expand Up @@ -96,7 +100,7 @@ def __init__(self, mask=None, **kwargs):

# Setup brain mask
if mask is None:
warn("No mask provided; consider using a mask to avoid issues in model optimization.")
warn(mask_absence_warn_msg)

self._mask = mask

Expand Down
19 changes: 16 additions & 3 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#
"""Unit tests exercising models."""

import contextlib

import numpy as np
import pytest
from dipy.sims.voxel import single_tensor
Expand All @@ -31,11 +33,13 @@
from nifreeze.data.splitting import lovo_split
from nifreeze.exceptions import ModelNotFittedError
from nifreeze.model._dipy import GaussianProcessModel
from nifreeze.model.base import mask_absence_warn_msg
from nifreeze.model.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0
from nifreeze.testing import simulations as _sim


def test_trivial_model():
@pytest.mark.parametrize("use_mask", (False, True))
def test_trivial_model(use_mask):
"""Check the implementation of the trivial B0 model."""

rng = np.random.default_rng(1234)
Expand All @@ -44,15 +48,24 @@ def test_trivial_model():
with pytest.raises(TypeError):
model.TrivialModel()

_S0 = rng.normal(size=(2, 2, 2))
size = (2, 2, 2)
mask = None
if use_mask:
mask = np.ones(size, dtype=bool)
context = contextlib.nullcontext()
else:
context = pytest.warns(UserWarning, match=mask_absence_warn_msg)

_S0 = rng.normal(size=size)

_clipped_S0 = np.clip(
_S0.astype("float32") / _S0.max(),
a_min=DEFAULT_MIN_S0,
a_max=DEFAULT_MAX_S0,
)

tmodel = model.TrivialModel(predicted=_clipped_S0)
with context:
tmodel = model.TrivialModel(mask=mask, predicted=_clipped_S0)

data = None
assert tmodel.fit(data) is None
Expand Down

0 comments on commit 6cbc7e7

Please sign in to comment.