diff --git a/pyproject.toml b/pyproject.toml index df3993a..79c2ea1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "scikit-image>=0.15.0", "scikit_learn>=1.3.0", "scipy>=1.8.0", + "typing_extensions >=4.12", ] dynamic = ["version"] diff --git a/src/nifreeze/data/base.py b/src/nifreeze/data/base.py index 4fee5a0..4cf47ff 100644 --- a/src/nifreeze/data/base.py +++ b/src/nifreeze/data/base.py @@ -27,7 +27,7 @@ from collections import namedtuple from pathlib import Path from tempfile import mkdtemp -from typing import Any, Generic, TypeVarTuple +from typing import Any, Generic import attr import h5py @@ -35,6 +35,7 @@ import numpy as np from nibabel.spatialimages import SpatialHeader, SpatialImage from nitransforms.linear import Affine +from typing_extensions import TypeVarTuple, Unpack from nifreeze.utils.ndimage import load_api @@ -58,7 +59,7 @@ def _cmp(lh: Any, rh: Any) -> bool: @attr.s(slots=True) -class BaseDataset(Generic[*Ts]): +class BaseDataset(Generic[Unpack[Ts]]): """ Base dataset representation structure. @@ -99,13 +100,12 @@ def __len__(self) -> int: return self.dataobj.shape[-1] - def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[*Ts]: - # PY312: Default values for TypeVarTuples are not yet supported + def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[Unpack[Ts]]: return () # type: ignore[return-value] def __getitem__( self, idx: int | slice | tuple | np.ndarray - ) -> tuple[np.ndarray, np.ndarray | None, *Ts]: + ) -> tuple[np.ndarray, np.ndarray | None, Unpack[Ts]]: """ Returns volume(s) and corresponding affine(s) through fancy indexing. diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index be1ed33..4c8a171 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -26,9 +26,10 @@ from pathlib import Path from tempfile import TemporaryDirectory -from typing import Self, TypeVar +from typing import TypeVar from tqdm import tqdm +from typing_extensions import Self from nifreeze.data.base import BaseDataset from nifreeze.model.base import BaseModel, ModelFactory diff --git a/test/test_data_base.py b/test/test_data_base.py index 4afa644..d8c65e2 100644 --- a/test/test_data_base.py +++ b/test/test_data_base.py @@ -53,25 +53,25 @@ def test_len(random_dataset: BaseDataset): assert len(random_dataset) == 5 # last dimension is 5 volumes -def test_getitem_volume_index(random_dataset: BaseDataset[()]): +def test_getitem_volume_index(random_dataset: BaseDataset): """ Test that __getitem__ returns the correct (volume, affine) tuple. By default, motion_affines is None, so we expect to get None for the affine. """ - # Single volume - volume0, aff0 = random_dataset[0] + # Single volume # Note that the type ignore can be removed once we can use *Ts + volume0, aff0 = random_dataset[0] # type: ignore[misc] # PY310 assert volume0.shape == (32, 32, 32) # No transforms have been applied yet, so there's no motion_affines array assert aff0 is None # Slice of volumes - volume_slice, aff_slice = random_dataset[2:4] + volume_slice, aff_slice = random_dataset[2:4] # type: ignore[misc] # PY310 assert volume_slice.shape == (32, 32, 32, 2) assert aff_slice is None -def test_set_transform(random_dataset: BaseDataset[()]): +def test_set_transform(random_dataset: BaseDataset): """ Test that calling set_transform changes the data and motion_affines. For simplicity, we'll apply an identity transform and check that motion_affines is updated. @@ -83,7 +83,7 @@ def test_set_transform(random_dataset: BaseDataset[()]): random_dataset.set_transform(idx, affine, order=1) # Data shouldn't have changed (since transform is identity). - volume0, aff0 = random_dataset[idx] + volume0, aff0 = random_dataset[idx] # type: ignore[misc] # PY310 assert np.allclose(data_before, volume0) # motion_affines should be created and match the transform matrix.