diff --git a/pyproject.toml b/pyproject.toml index 5e3cb71a..1d594e0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,7 @@ cufinufft = ["cufinufft<2.3", "cupy-cuda12x"] finufft = ["finufft"] pynfft = ["pynfft2>=1.4.3; python_version < '3.12'", "numpy>=2.0.0; python_version < '3.12'"] pynufft = ["pynufft"] -io = ["pymapvbvd"] -smaps = ["scikit-image"] -sampling = ["pywavelets", "scikit-learn"] +extra = ["pymapvbvd", "scikit-image", "scikit-learn", "pywavelets"] autodiff = ["torch"] diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index c3f7018f..5060c3c3 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -149,8 +149,15 @@ def low_frequency( """ # defer import to later to prevent circular import from mrinufft import get_operator - from skimage.filters import threshold_otsu, gaussian - from skimage.morphology import convex_hull_image + try: + from skimage.filters import threshold_otsu, gaussian + from skimage.morphology import convex_hull_image + except ImportError as err: + raise ImportError( + "The scikit-image module is not available. Please install " + "it along with the [extra] dependencies " + "or using `pip install scikit-image`." + ) from err k_space, samples, dc = _extract_kspace_center( kspace_data=kspace_data, diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py index dd8c96d7..d172e355 100644 --- a/src/mrinufft/io/siemens.py +++ b/src/mrinufft/io/siemens.py @@ -60,6 +60,12 @@ def read_siemens_rawdat( "The mapVBVD module is not available. Please install it using " "the following command: pip install pymapVBVD" ) from err + except ImportError as err: + raise ImportError( + "The mapVBVD module is not available. Please install " + "it along with the [extra] dependencies " + "or using `pip install pymapVBVD`." + ) from err twixObj = mapVBVD(filename) if isinstance(twixObj, list): twixObj = twixObj[-1] diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index c3136bd1..cdb2bb62 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -23,7 +23,6 @@ try: import cupy as cp from cupyx.scipy import fft as cpfft - except ImportError: CUPY_AVAILABLE = False diff --git a/src/mrinufft/trajectories/sampling.py b/src/mrinufft/trajectories/sampling.py index cbaa5877..b5fb465d 100644 --- a/src/mrinufft/trajectories/sampling.py +++ b/src/mrinufft/trajectories/sampling.py @@ -4,16 +4,10 @@ import numpy.fft as nf import numpy.linalg as nl import numpy.random as nr -from sklearn.cluster import BisectingKMeans, KMeans from tqdm.auto import tqdm from .utils import KMAX -try: - import pywt as pw -except ImportError: - pw = None - def sample_from_density( nb_samples, density, method="random", *, dim_compensation="auto" @@ -59,7 +53,15 @@ def sample_from_density( "Variable density sampling with continuous trajectories." SIAM Journal on Imaging Sciences 7, no. 4 (2014): 1962-1992. """ - rng = nr.default_rng() + try: + from sklearn.cluster import BisectingKMeans, KMeans + except ImportError as err: + raise ImportError( + "The scikit-learn module is not available. Please install " + "it along with the [extra] dependencies " + "or using `pip install scikit-learn`." + ) from err + # Define dimension variables shape = np.array(density.shape) @@ -80,6 +82,7 @@ def sample_from_density( density = density / np.sum(density) # Sample using specified method + rng = nr.default_rng() if method == "random": choices = rng.choice( np.arange(max_nb_samples), @@ -259,11 +262,14 @@ def create_chauffert_density(shape, wavelet_basis, nb_wavelet_scales, verbose=Fa In 2013 IEEE 10th International Symposium on Biomedical Imaging, pp. 298-301. IEEE, 2013. """ - if pw is None: + try: + import pywt + except ImportError as err: raise ImportError( - "The PyWavelets package must be installed " - "as an additional dependency for this function." - ) + "The PyWavelets module is not available. Please install " + "it along with the [extra] dependencies " + "or using `pip install pywavelets`." + ) from err nb_dims = len(shape) indices = np.indices(shape).reshape((nb_dims, -1)).T @@ -330,11 +336,14 @@ def create_fast_chauffert_density(shape, wavelet_basis, nb_wavelet_scales): In 2013 IEEE 10th International Symposium on Biomedical Imaging, pp. 298-301. IEEE, 2013. """ - if pw is None: + try: + import pywt + except ImportError as err: raise ImportError( - "The PyWavelets package must be installed " - "as an additional dependency for this function." - ) + "The PyWavelets module is not available. Please install " + "it along with the [extra] dependencies " + "or using `pip install pywavelets`." + ) from err nb_dims = len(shape)