Skip to content

Commit

Permalink
Merge pull request #68 from jhlegarreta/AddMiscTypeHints
Browse files Browse the repository at this point in the history
ENH: Add type hints across miscellaneous methods
  • Loading branch information
oesteban authored Jan 27, 2025
2 parents 7e363d8 + f005b7e commit 18e9de5
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
13 changes: 7 additions & 6 deletions src/nifreeze/data/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import h5py
import nibabel as nb
import numpy as np
import numpy.typing as npt
from nibabel.spatialimages import SpatialImage
from nitransforms.linear import Affine

Expand Down Expand Up @@ -369,11 +370,11 @@ def load(


def find_shelling_scheme(
bvals,
num_bins=DEFAULT_NUM_BINS,
multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR,
bval_cap=DEFAULT_HIGHB_THRESHOLD,
):
bvals: np.ndarray,
num_bins: int = DEFAULT_NUM_BINS,
multishell_nonempty_bin_count_thr: int = DEFAULT_MULTISHELL_BIN_COUNT_THR,
bval_cap: float = DEFAULT_HIGHB_THRESHOLD,
) -> tuple[str, list[npt.NDArray[np.floating]], list[np.floating]]:
"""
Find the shelling scheme on the given b-values.
Expand All @@ -390,7 +391,7 @@ def find_shelling_scheme(
Number of bins.
multishell_nonempty_bin_count_thr : :obj:`int`, optional
Bin count to consider a multi-shell scheme.
bval_cap : :obj:`int`, optional
bval_cap : :obj:`float`, optional
Maximum b-value to be considered in a multi-shell scheme.
Returns
Expand Down
4 changes: 2 additions & 2 deletions src/nifreeze/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ModelFactory:
"""A factory for instantiating data models."""

@staticmethod
def init(model=None, **kwargs):
def init(model: str | None = None, **kwargs):
"""
Instantiate a diffusion model.
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(self, dataset, stat="median", **kwargs):
super().__init__(dataset, **kwargs)
self._stat = stat

def fit_predict(self, index, **kwargs):
def fit_predict(self, index: int, **kwargs):
"""
Return the expectation map.
Expand Down
21 changes: 15 additions & 6 deletions src/nifreeze/model/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from nifreeze.data.dmri import (
DEFAULT_CLIP_PERCENTILE,
DTI_MIN_ORIENTATIONS,
DWI,
)
from nifreeze.model.base import BaseModel, ExpectationModel

Expand All @@ -51,7 +52,7 @@ class BaseDWIModel(BaseModel):
"_modelargs": "Arguments acceptable by the underlying DIPY-like model.",
}

def __init__(self, dataset, **kwargs):
def __init__(self, dataset: DWI, **kwargs):
r"""Initialization.
Parameters
Expand Down Expand Up @@ -117,7 +118,7 @@ def _fit(self, index, n_jobs=None, **kwargs):
self._model = None # Preempt further actions on the model
return n_jobs

def fit_predict(self, index, **kwargs):
def fit_predict(self, index: int, **kwargs):
"""
Predict asynchronously chunk-by-chunk the diffusion signal.
Expand All @@ -140,7 +141,7 @@ def fit_predict(self, index, **kwargs):
if n_models == 1:
predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0}))
else:
S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models
S0 = np.array_split(S0, n_models) if S0 is not None else np.full(n_models, None)

predicted = [None] * n_models

Expand Down Expand Up @@ -173,7 +174,15 @@ class AverageDWIModel(ExpectationModel):

__slots__ = ("_th_low", "_th_high", "_detrend")

def __init__(self, dataset, stat="median", th_low=100, th_high=100, detrend=False, **kwargs):
def __init__(
self,
dataset: DWI,
stat: str = "median",
th_low: float = 100.0,
th_high: float = 100.0,
detrend: bool = False,
**kwargs,
):
r"""
Implement object initialization.
Expand All @@ -183,10 +192,10 @@ def __init__(self, dataset, stat="median", th_low=100, th_high=100, detrend=Fals
Reference to a DWI object.
stat : :obj:`str`, optional
Whether the summary statistic to apply is ``"mean"`` or ``"median"``.
th_low : :obj:`numbers.Number`, optional
th_low : :obj:`float`, optional
A lower bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
th_high : :obj:`numbers.Number`, optional
th_high : :obj:`float`, optional
An upper bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
detrend : :obj:`bool`, optional
Expand Down

0 comments on commit 18e9de5

Please sign in to comment.