diff --git a/pyproject.toml b/pyproject.toml index d3274501..8f7e00f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ finufft = ["finufft"] pynfft = ["pynfft2", "cython<3.0.0"] pynufft = ["pynufft"] io = ["pymapvbvd"] +smaps = ["scikit-image"] test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"] dev = ["black", "isort", "ruff"] diff --git a/src/mrinufft/extras/__init__.py b/src/mrinufft/extras/__init__.py new file mode 100644 index 00000000..6ba1d4a4 --- /dev/null +++ b/src/mrinufft/extras/__init__.py @@ -0,0 +1,10 @@ +"""Sensitivity map estimation methods.""" + +from .smaps import low_frequency +from .utils import get_smaps + + +__all__ = [ + "low_frequency", + "get_smaps", +] diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py new file mode 100644 index 00000000..9051b8d3 --- /dev/null +++ b/src/mrinufft/extras/smaps.py @@ -0,0 +1,177 @@ +"""SMaps module for sensitivity maps estimation.""" + +from mrinufft.density.utils import flat_traj +from mrinufft.operators.base import get_array_module +from .utils import register_smaps +import numpy as np +from typing import Tuple + + +def _extract_kspace_center( + kspace_data, + kspace_loc, + threshold=None, + density=None, + window_fun="ellipse", +): + r"""Extract k-space center and corresponding sampling locations. + + The extracted center of the k-space, i.e. both the kspace locations and + kspace values. If the density compensators are passed, the corresponding + compensators for the center of k-space data will also be returned. The + return dtypes for density compensation and kspace data is same as input + + Parameters + ---------- + kspace_data: numpy.ndarray + The value of the samples + kspace_loc: numpy.ndarray + The samples location in the k-space domain (between [-0.5, 0.5[) + threshold: tuple or float + The threshold used to extract the k_space center (between (0, 1]) + window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. + The window function to apply to the selected data. It is computed with + the center locations selected. Only works with circular mask. + If window_fun is a callable, it takes as input the array (n_samples x n_dims) + of sample positions and returns an array of n_samples weights to be + applied to the selected k-space values, before the smaps estimation. + + Returns + ------- + data_thresholded: ndarray + The k-space values in the center region. + center_loc: ndarray + The locations in the center region. + density_comp: ndarray, optional + The density compensation weights (if requested) + + Notes + ----- + The Hann (or Hanning) and Hamming windows of width :math:`2\theta` are defined as: + .. math:: + + w(x,y) = a_0 - (1-a_0) * \cos(\pi * \sqrt{x^2+y^2}/\theta), + \sqrt{x^2+y^2} \le \theta + + In the case of Hann window :math:`a_0=0.5`. + For Hamming window we consider the optimal value in the equiripple sense: + :math:`a_0=0.53836`. + .. Wikipedia:: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows + + """ + xp = get_array_module(kspace_data) + if isinstance(threshold, float): + threshold = (threshold,) * kspace_loc.shape[1] + + if window_fun == "rect": + data_ordered = xp.copy(kspace_data) + index = xp.linspace( + 0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64 + ) + condition = xp.logical_and.reduce( + tuple( + xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) + ) + ) + index = xp.extract(condition, index) + center_locations = kspace_loc[index, :] + data_thresholded = data_ordered[:, index] + dc = density[index] + return data_thresholded, center_locations, dc + else: + if callable(window_fun): + window = window_fun(center_locations) + else: + if window_fun in ["hann", "hanning", "hamming"]: + radius = xp.linalg.norm(kspace_loc, axis=1) + a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 + window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) + elif window_fun == "ellipse": + window = xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 + else: + raise ValueError("Unsupported window function.") + data_thresholded = window * kspace_data + # Return k-space locations & density just for consistency + return data_thresholded, kspace_loc, density + + +@register_smaps +@flat_traj +def low_frequency( + traj, + shape, + kspace_data, + backend, + threshold: float | Tuple[float, ...] = 0.1, + density=None, + window_fun: str = "ellipse", + blurr_factor: float = 0, + mask: bool = False, +): + """ + Calculate low-frequency sensitivity maps. + + Parameters + ---------- + traj : numpy.ndarray + The trajectory of the samples. + shape : tuple + The shape of the image. + kspace_data : numpy.ndarray + The k-space data. + threshold : float, or tuple of float, optional + The threshold used for extracting the k-space center. + By default it is 0.1 + backend : str + The backend used for the operator. + density : numpy.ndarray, optional + The density compensation weights. + window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. + The window function to apply to the selected data. It is computed with + the center locations selected. Only works with circular mask. + If window_fun is a callable, it takes as input the array (n_samples x n_dims) + of sample positions and returns an array of n_samples weights to be + applied to the selected k-space values, before the smaps estimation. + blurr_factor : float, optional + The blurring factor for smoothing the sensitivity maps. + mask: bool, optional default `False` + Whether the Sensitivity maps must be masked + + Returns + ------- + Smaps : numpy.ndarray + The low-frequency sensitivity maps. + SOS : numpy.ndarray + The sum of squares of the sensitivity maps. + """ + # defer import to later to prevent circular import + from mrinufft import get_operator + from skimage.filters import threshold_otsu, gaussian + from skimage.morphology import convex_hull_image + + k_space, samples, dc = _extract_kspace_center( + kspace_data=kspace_data, + kspace_loc=traj, + threshold=threshold, + density=density, + window_fun=window_fun, + ) + smaps_adj_op = get_operator(backend)( + samples, shape, density=dc, n_coils=k_space.shape[0] + ) + Smaps = smaps_adj_op.adj_op(k_space) + SOS = np.linalg.norm(Smaps, axis=0) + if mask: + thresh = threshold_otsu(SOS) + # Create convex hull from mask + convex_hull = convex_hull_image(SOS > thresh) + Smaps = Smaps * convex_hull + # Smooth out the sensitivity maps + if blurr_factor > 0: + Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape)) + # Re-normalize the sensitivity maps + if mask or blurr_factor > 0: + # ReCalculate SOS with a minor eps to ensure divide by 0 is ok + SOS = np.linalg.norm(Smaps, axis=0) + 1e-10 + Smaps = Smaps / SOS + return Smaps, SOS diff --git a/src/mrinufft/extras/utils.py b/src/mrinufft/extras/utils.py new file mode 100644 index 00000000..5c9a7b9d --- /dev/null +++ b/src/mrinufft/extras/utils.py @@ -0,0 +1,20 @@ +"""Utils for extras module.""" + +from mrinufft._utils import MethodRegister + +register_smaps = MethodRegister("sensitivity_maps") + + +def get_smaps(name, *args, **kwargs): + """Get the density compensation function from its name.""" + try: + method = register_smaps.registry["sensitivity_maps"][name] + except KeyError as e: + raise ValueError( + f"Unknown density compensation method {name}. Available methods are \n" + f"{list(register_smaps.registry['sensitivity_maps'].keys())}" + ) from e + + if args or kwargs: + return method(*args, **kwargs) + return method diff --git a/src/mrinufft/io/__init__.py b/src/mrinufft/io/__init__.py index 039d0600..f77d3b32 100644 --- a/src/mrinufft/io/__init__.py +++ b/src/mrinufft/io/__init__.py @@ -1,7 +1,8 @@ """Input/Output module for trajectories and data.""" from .cfl import traj2cfl, cfl2traj -from .nsp import read_trajectory, write_trajectory +from .nsp import read_trajectory, write_trajectory, read_arbgrad_rawdat +from .siemens import read_siemens_rawdat __all__ = [ @@ -9,4 +10,6 @@ "cfl2traj", "read_trajectory", "write_trajectory", + "read_arbgrad_rawdat", + "read_siemens_rawdat", ] diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index 5b4da407..3fafb306 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -6,6 +6,7 @@ import numpy as np from datetime import datetime from array import array +from .siemens import read_siemens_rawdat from mrinufft.trajectories.utils import ( KMAX, @@ -392,7 +393,7 @@ def read_trajectory( return kspace_loc, params -def read_siemens_rawdat( +def read_arbgrad_rawdat( filename: str, removeOS: bool = False, squeeze: bool = True, @@ -429,32 +430,10 @@ def read_siemens_rawdat( You can install it using the following command: `pip install pymapVBVD` """ - try: - from mapvbvd import mapVBVD - except ImportError as err: - raise ImportError( - "The mapVBVD module is not available. Please install it using " - "the following command: pip install pymapVBVD" - ) from err - twixObj = mapVBVD(filename) - if isinstance(twixObj, list): - twixObj = twixObj[-1] - twixObj.image.flagRemoveOS = removeOS - twixObj.image.squeeze = squeeze - raw_kspace = twixObj.image[""] - data = np.moveaxis(raw_kspace, 0, 2) - hdr = { - "n_coils": int(twixObj.image.NCha), - "n_shots": int(twixObj.image.NLin), - "n_contrasts": int(twixObj.image.NSet), - "n_adc_samples": int(twixObj.image.NCol), - "n_slices": int(twixObj.image.NSli), - } - data = data.reshape( - hdr["n_coils"], - hdr["n_shots"] * hdr["n_adc_samples"], - hdr["n_slices"], - hdr["n_contrasts"], + data, hdr, twixObj = read_siemens_rawdat( + filename=filename, + removeOS=removeOS, + squeeze=squeeze, ) if "ARBGRAD_VE11C" in data_type: hdr["type"] = "ARBGRAD_GRE" diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py new file mode 100644 index 00000000..9cc782aa --- /dev/null +++ b/src/mrinufft/io/siemens.py @@ -0,0 +1,74 @@ +"""Siemens specific rawdat reader, wrapper over pymapVBVD.""" + +import numpy as np + + +def read_siemens_rawdat( + filename: str, + removeOS: bool = False, + squeeze: bool = True, + return_twix: bool = True, +): # pragma: no cover + """Read raw data from a Siemens MRI file. + + Parameters + ---------- + filename : str + The path to the Siemens MRI file. + removeOS : bool, optional + Whether to remove the oversampling, by default False. + squeeze : bool, optional + Whether to squeeze the dimensions of the data, by default True. + data_type : str, optional + The type of data to read, by default 'ARBGRAD_VE11C'. + return_twix : bool, optional + Whether to return the twix object, by default True. + + Returns + ------- + data: ndarray + Imported data formatted as n_coils X n_samples X n_slices X n_contrasts + hdr: dict + Extra information about the data parsed from the twix file + + Raises + ------ + ImportError + If the mapVBVD module is not available. + + Notes + ----- + This function requires the mapVBVD module to be installed. + You can install it using the following command: + `pip install pymapVBVD` + """ + try: + from mapvbvd import mapVBVD + except ImportError as err: + raise ImportError( + "The mapVBVD module is not available. Please install it using " + "the following command: pip install pymapVBVD" + ) from err + twixObj = mapVBVD(filename) + if isinstance(twixObj, list): + twixObj = twixObj[-1] + twixObj.image.flagRemoveOS = removeOS + twixObj.image.squeeze = squeeze + raw_kspace = twixObj.image[""] + data = np.moveaxis(raw_kspace, 0, 2) + hdr = { + "n_coils": int(twixObj.image.NCha), + "n_shots": int(twixObj.image.NLin), + "n_contrasts": int(twixObj.image.NSet), + "n_adc_samples": int(twixObj.image.NCol), + "n_slices": int(twixObj.image.NSli), + } + data = data.reshape( + hdr["n_coils"], + hdr["n_shots"] * hdr["n_adc_samples"], + hdr["n_slices"], + hdr["n_contrasts"], + ) + if return_twix: + return data, hdr, twixObj + return data, hdr diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 0bec7fcf..034d6bcf 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -14,6 +14,7 @@ from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array from mrinufft.density import get_density +from mrinufft.extras import get_smaps CUPY_AVAILABLE = True try: @@ -253,6 +254,42 @@ def with_off_resonnance_correction(self, B, C, indices): return MRIFourierCorrected(self, B, C, indices) + def compute_smaps(self, method=None): + """Compute the sensitivity maps and set it. + + Parameters + ---------- + method: callable or dict or array + The method to use to compute the sensitivity maps. + If an array, it should be of shape (NCoils,XYZ) and will be used as is. + If a dict, it should have a key 'name', to determine which method to use. + other items will be used as kwargs. + If a callable, it should take the samples and the shape as input. + Note that this callable function should also hold the k-space data + (use funtools.partial) + """ + if isinstance(method, np.ndarray): + self.smaps = method + return None + if not method: + self.smaps = None + return None + kwargs = {} + if isinstance(method, dict): + kwargs = method.copy() + method = kwargs.pop("name") + if isinstance(method, str): + method = get_smaps(method) + if not callable(method): + raise ValueError(f"Unknown smaps method: {method}") + self.smaps, self.SOS = method( + self.samples, + self.shape, + density=self.density, + backend=self.backend, + **kwargs, + ) + def make_autograd(self, variable="data"): """Make a new Operator with autodiff support. @@ -499,6 +536,7 @@ def __init__( # Density Compensation Setup self.compute_density(density) + self.compute_smaps(smaps) # Multi Coil Setup if n_coils < 1: raise ValueError("n_coils should be ≥ 1") diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index ee581607..d0017924 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -210,7 +210,7 @@ def __init__( self.density = cp.array(self.density) # Smaps support - self.smaps = smaps + self.compute_smaps(smaps) self.smaps_cached = False if smaps is not None: if not (is_host_array(smaps) or is_cuda_array(smaps)): diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index 08934dc2..5e0d4ac0 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -381,15 +381,15 @@ def __init__( self.dtype = self.samples.dtype self.n_coils = n_coils self.n_batchs = n_batchs - self.smaps = smaps self.squeeze_dims = squeeze_dims self.compute_density(density) + self.compute_smaps(smaps) self.impl = RawGpuNUFFT( samples=self.samples, shape=self.shape, n_coils=self.n_coils, density_comp=self.density, - smaps=smaps, + smaps=self.smaps, kernel_width=kwargs.get("kernel_width", -int(np.log10(eps))), **kwargs, ) @@ -478,7 +478,15 @@ def uses_sense(self): return self.impl.uses_sense @classmethod - def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): + def pipe( + cls, + kspace_loc, + volume_shape, + num_iterations=10, + osf=2, + normalize=True, + **kwargs, + ): """Compute the density compensation weights for a given set of kspace locations. Parameters @@ -491,6 +499,9 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): the number of iterations for density estimation osf: float or int The oversampling factor the volume shape + normalize: bool + Whether to normalize the density compensation. + We normalize such that the energy of PSF = 1 """ if GPUNUFFT_AVAILABLE is False: raise ValueError( @@ -506,6 +517,12 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): density_comp = grid_op.impl.operator.estimate_density_comp( max_iter=num_iterations ) + if normalize: + spike = np.zeros(volume_shape) + mid_loc = tuple(v // 2 for v in volume_shape) + spike[mid_loc] = 1 + psf = grid_op.adj_op(grid_op.op(spike)) + density_comp /= np.linalg.norm(psf) return density_comp.squeeze() def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs):