Skip to content

Commit

Permalink
Gather extra dependencies, homogeneize extra imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Daval-G committed Dec 10, 2024
1 parent 82172be commit 34144b8
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 21 deletions.
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ cufinufft = ["cufinufft<2.3", "cupy-cuda12x"]
finufft = ["finufft"]
pynfft = ["pynfft2>=1.4.3; python_version < '3.12'", "numpy>=2.0.0; python_version < '3.12'"]
pynufft = ["pynufft"]
io = ["pymapvbvd"]
smaps = ["scikit-image"]
sampling = ["pywavelets", "scikit-learn"]
extra = ["pymapvbvd", "scikit-image", "scikit-learn", "pywavelets"]
autodiff = ["torch"]


Expand Down
11 changes: 9 additions & 2 deletions src/mrinufft/extras/smaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,15 @@ def low_frequency(
"""
# defer import to later to prevent circular import
from mrinufft import get_operator
from skimage.filters import threshold_otsu, gaussian
from skimage.morphology import convex_hull_image
try:
from skimage.filters import threshold_otsu, gaussian
from skimage.morphology import convex_hull_image
except ImportError as err:
raise ImportError(
"The scikit-image module is not available. Please install "
"it along with the [extra] dependencies "
"or using `pip install scikit-image`."
) from err

k_space, samples, dc = _extract_kspace_center(
kspace_data=kspace_data,
Expand Down
6 changes: 6 additions & 0 deletions src/mrinufft/io/siemens.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def read_siemens_rawdat(
"The mapVBVD module is not available. Please install it using "
"the following command: pip install pymapVBVD"
) from err
except ImportError as err:
raise ImportError(
"The mapVBVD module is not available. Please install "
"it along with the [extra] dependencies "
"or using `pip install pymapVBVD`."
) from err
twixObj = mapVBVD(filename)
if isinstance(twixObj, list):
twixObj = twixObj[-1]
Expand Down
1 change: 0 additions & 1 deletion src/mrinufft/operators/stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
try:
import cupy as cp
from cupyx.scipy import fft as cpfft

except ImportError:
CUPY_AVAILABLE = False

Expand Down
39 changes: 24 additions & 15 deletions src/mrinufft/trajectories/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,10 @@
import numpy.fft as nf
import numpy.linalg as nl
import numpy.random as nr
from sklearn.cluster import BisectingKMeans, KMeans
from tqdm.auto import tqdm

from .utils import KMAX

try:
import pywt as pw
except ImportError:
pw = None


def sample_from_density(
nb_samples, density, method="random", *, dim_compensation="auto"
Expand Down Expand Up @@ -59,7 +53,15 @@ def sample_from_density(
"Variable density sampling with continuous trajectories."
SIAM Journal on Imaging Sciences 7, no. 4 (2014): 1962-1992.
"""
rng = nr.default_rng()
try:
from sklearn.cluster import BisectingKMeans, KMeans
except ImportError as err:
raise ImportError(
"The scikit-learn module is not available. Please install "
"it along with the [extra] dependencies "
"or using `pip install scikit-learn`."
) from err


# Define dimension variables
shape = np.array(density.shape)
Expand All @@ -80,6 +82,7 @@ def sample_from_density(
density = density / np.sum(density)

# Sample using specified method
rng = nr.default_rng()
if method == "random":
choices = rng.choice(
np.arange(max_nb_samples),
Expand Down Expand Up @@ -259,11 +262,14 @@ def create_chauffert_density(shape, wavelet_basis, nb_wavelet_scales, verbose=Fa
In 2013 IEEE 10th International Symposium on Biomedical Imaging,
pp. 298-301. IEEE, 2013.
"""
if pw is None:
try:
import pywt
except ImportError as err:
raise ImportError(
"The PyWavelets package must be installed "
"as an additional dependency for this function."
)
"The PyWavelets module is not available. Please install "
"it along with the [extra] dependencies "
"or using `pip install pywavelets`."
) from err

nb_dims = len(shape)
indices = np.indices(shape).reshape((nb_dims, -1)).T
Expand Down Expand Up @@ -330,11 +336,14 @@ def create_fast_chauffert_density(shape, wavelet_basis, nb_wavelet_scales):
In 2013 IEEE 10th International Symposium on Biomedical Imaging,
pp. 298-301. IEEE, 2013.
"""
if pw is None:
try:
import pywt
except ImportError as err:
raise ImportError(
"The PyWavelets package must be installed "
"as an additional dependency for this function."
)
"The PyWavelets module is not available. Please install "
"it along with the [extra] dependencies "
"or using `pip install pywavelets`."
) from err

nb_dims = len(shape)

Expand Down

0 comments on commit 34144b8

Please sign in to comment.