Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smaps estimation module in mri-nufft #90

Merged
merged 58 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
f997f54
Added siemens and add_raw shifts
chaithyagr Apr 10, 2024
2870cc9
Update src/mrinufft/io/siemens.py
chaithyagr Apr 11, 2024
04773f2
Update
chaithyagr Apr 11, 2024
23b1f61
Merge branch 'add_phase' of github.com:chaithyagr/mri-nufft into add_…
chaithyagr Apr 11, 2024
a1966eb
Fixed some more
chaithyagr Apr 11, 2024
ab337db
Moved codes around
chaithyagr Apr 11, 2024
db0669e
Added np.ndarray
chaithyagr Apr 11, 2024
e13948d
Fix movement
chaithyagr Apr 11, 2024
c817c91
Fix movement
chaithyagr Apr 11, 2024
6f4280a
Fix flake
chaithyagr Apr 11, 2024
e4fe153
ruff fix
chaithyagr Apr 11, 2024
33ff82a
Fix
chaithyagr Apr 11, 2024
65c6bf7
Remove bymistake add
chaithyagr Apr 11, 2024
ebd8d32
ci: runs test only for non-style commit. (#73)
paquiteau Jan 8, 2024
f79ffad
Added fixSmaps
chaithyagr Apr 10, 2024
238b358
Fixes updates
chaithyagr Apr 11, 2024
c17ad55
Fix
chaithyagr Apr 11, 2024
dd2cf1d
fix docs
chaithyagr Apr 11, 2024
506bce6
Added smaps with blurring
chaithyagr Apr 11, 2024
2a62f94
Added doc
chaithyagr Apr 11, 2024
c1dabee
Final touchups
chaithyagr Apr 11, 2024
ed4d974
Added compute_smaps
chaithyagr Apr 11, 2024
f6596c9
Added extra files
chaithyagr Apr 11, 2024
dbb7743
Added compute_smaps
chaithyagr Apr 11, 2024
f138ee7
Added mask
Apr 24, 2024
3c27e7f
Added Smaps
Apr 24, 2024
d662498
Updates
Apr 26, 2024
8ca012c
Added
Apr 26, 2024
3a980c2
Merge branch 'mind-inria:master' into master
chaithyagr Apr 26, 2024
093edfb
Fix
Apr 26, 2024
f8a6a4a
Merge branch 'master' of github.com:chaithyagr/mri-nufft
Apr 26, 2024
74c1ecd
Remove bymistake add
Apr 26, 2024
0250aa8
Fix
Apr 26, 2024
060a8bd
Fixed lint
Apr 26, 2024
aecb844
Lint
Apr 26, 2024
3130bc1
Added refbackend
Apr 26, 2024
bc014b8
Fix NDFT
Apr 26, 2024
0cc73c4
feat: use finufft as ref backend.
paquiteau Apr 29, 2024
21e090f
feat(tests): move ndft vs nufft tests to own file.
paquiteau Apr 29, 2024
6869a4a
Merge branch 'master' of github.com:mind-inria/mri-nufft
Apr 29, 2024
d77c3a0
Merge branch 'master' into smaps
paquiteau Apr 29, 2024
f8364d4
Merge branch 'master' of github.com:mind-inria/mri-nufft
Apr 30, 2024
67ff56a
Merge branch 'master' into smaps
Apr 30, 2024
7afdd8e
Added rebart
Apr 30, 2024
140921e
Update codes
Apr 30, 2024
90bf832
updated mask
Apr 30, 2024
ae50a8e
Fixs
May 21, 2024
2013cf1
PEP
May 21, 2024
e7f27ea
Merge branch 'master' into smaps
chaithyagr May 23, 2024
064f9c1
Add lint fixes
May 23, 2024
7de40f6
Added PEP fixes
May 23, 2024
238ec00
Black
May 23, 2024
a58962c
Fix black
May 23, 2024
ebd61d3
Fix
May 23, 2024
c0aa0c5
Added PSF weighting
May 23, 2024
45bc400
Move to tuple
May 23, 2024
18f5b34
lint
paquiteau May 24, 2024
0639eda
lint
paquiteau May 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cufinufft = ["cufinufft", "cupy-cuda11x"]
finufft = ["finufft"]
pynfft = ["pynfft2", "cython<3.0.0"]
pynufft = ["pynufft"]
io = ["pymapvbvd"]

test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"]
dev = ["black", "isort", "ruff"]
Expand Down
10 changes: 10 additions & 0 deletions src/mrinufft/extras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Sensitivity map estimation methods."""

from .smaps import low_frequency
from .utils import get_density


__all__ = [
"low_frequency",
"get_smaps",
]
155 changes: 155 additions & 0 deletions src/mrinufft/extras/smaps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from mrinufft.density.utils import flat_traj
from mrinufft.operators.base import get_array_module
from mrinufft import get_operator
from skimage.filters import threshold_otsu, gaussian
from skimage.morphology import convex_hull_image
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to check, but I think it's a new dependency.

from .utils import register_smaps
import numpy as np


def _extract_kspace_center(
kspace_data, kspace_loc, threshold=None, density=None, window_fun="ellipse",
):
r"""Extract k-space center and corresponding sampling locations.

The extracted center of the k-space, i.e. both the kspace locations and
kspace values. If the density compensators are passed, the corresponding
compensators for the center of k-space data will also be returned. The
return dtypes for density compensation and kspace data is same as input

Parameters
----------
kspace_data: numpy.ndarray
The value of the samples
kspace_loc: numpy.ndarray
The samples location in the k-sapec domain (between [-0.5, 0.5[)
threshold: tuple or float
The threshold used to extract the k_space center (between (0, 1])
window_fun: "Hann", "Hanning", "Hamming", or a callable, default None.
The window function to apply to the selected data. It is computed with
the center locations selected. Only works with circular mask.
If window_fun is a callable, it takes as input the array (n_samples x n_dims)
of sample positions and returns an array of n_samples weights to be
applied to the selected k-space values, before the smaps estimation.

Returns
-------
data_thresholded: ndarray
The k-space values in the center region.
center_loc: ndarray
The locations in the center region.
density_comp: ndarray, optional
The density compensation weights (if requested)

Notes
-----
The Hann (or Hanning) and Hamming windows of width :math:`2\theta` are defined as:
.. math::

w(x,y) = a_0 - (1-a_0) * \cos(\pi * \sqrt{x^2+y^2}/\theta),
\sqrt{x^2+y^2} \le \theta

In the case of Hann window :math:`a_0=0.5`.
For Hamming window we consider the optimal value in the equiripple sense:
:math:`a_0=0.53836`.
.. Wikipedia:: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows

"""
xp = get_array_module(kspace_data)
if isinstance(threshold, float):
threshold = (threshold,) * kspace_loc.shape[1]

if window_fun == "rect":
data_ordered = xp.copy(kspace_data)
index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64)
condition = xp.logical_and.reduce(tuple(
xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold))
))
index = xp.extract(condition, index)
center_locations = kspace_loc[index, :]
data_thresholded = data_ordered[:, index]
dc = density[index]
return data_thresholded, center_locations, dc
else:
if callable(window_fun):
window = window_fun(center_locations)
else:
if window_fun in ["hann", "hanning", "hamming"]:
radius = xp.linalg.norm(kspace_loc, axis=1)
a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836
window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold)
elif window_fun == "ellipse":
window = xp.sum(kspace_loc**2/ xp.asarray(threshold)**2, axis=1) <= 1
else:
raise ValueError("Unsupported window function.")
data_thresholded = window * data_thresholded
# Return k-space locations & density just for consistency
return data_thresholded, kspace_loc, density


@register_smaps
@flat_traj
def low_frequency(traj, shape, kspace_data, threshold, backend, density=None,
extract_kwargs=None, blurr_factor=0, mask=True):
"""
Calculate low-frequency sensitivity maps.

Parameters
----------
traj : numpy.ndarray
The trajectory of the samples.
shape : tuple
The shape of the image.
kspace_data : numpy.ndarray
The k-space data.
threshold : float
The threshold used for extracting the k-space center.
backend : str
The backend used for the operator.
density : numpy.ndarray, optional
The density compensation weights.
extract_kwargs : dict, optional
Additional keyword arguments for the `extract_kspace_center` function.
blurr_factor : float, optional
The blurring factor for smoothing the sensitivity maps.
mask: bool, optional default `True`
Whether the Sensitivity maps must be masked

Returns
-------
Smaps : numpy.ndarray
The low-frequency sensitivity maps.
SOS : numpy.ndarray
The sum of squares of the sensitivity maps.
"""
k_space, samples, dc = _extract_kspace_center(
kspace_data=kspace_data,
kspace_loc=traj,
threshold=threshold,
density=density,
img_shape=shape,
**(extract_kwargs or {}),
)
smaps_adj_op = get_operator(backend)(
samples,
shape,
density=dc,
n_coils=k_space.shape[0]
)
Smaps = smaps_adj_op.adj_op(k_space)
SOS = np.linalg.norm(Smaps, axis=0)
if mask:
thresh = threshold_otsu(SOS)
# Create convex hull from mask
convex_hull = convex_hull_image(SOS>thresh)
Smaps = Smaps * convex_hull
# Smooth out the sensitivity maps
if blurr_factor > 0:
Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape))
# Re-normalize the sensitivity maps
if mask or blurr_factor > 0:
# ReCalculate SOS with a minor eps to ensure divide by 0 is ok
SOS = np.linalg.norm(Smaps, axis=0) + 1e-10
Smaps = Smaps / SOS
return Smaps, SOS

17 changes: 17 additions & 0 deletions src/mrinufft/extras/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from mrinufft._utils import MethodRegister

register_smaps = MethodRegister("sensitivity_maps")

def get_smaps(name, *args, **kwargs):
"""Get the density compensation function from its name."""
try:
method = register_smaps.registry["sensitivity_maps"][name]
except KeyError as e:
raise ValueError(
f"Unknown density compensation method {name}. Available methods are \n"
f"{list(register_smaps.registry['sensitivity_maps'].keys())}"
) from e
Comment on lines +9 to +16
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some copy paste error to correct (smaps instead of density compensation)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget this ;)


if args or kwargs:
return method(*args, **kwargs)
return method
88 changes: 87 additions & 1 deletion src/mrinufft/io/nsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def read_trajectory(
grad_filename: str,
dwell_time: float = DEFAULT_RASTER_TIME,
num_adc_samples: int = None,
gamma: float = Gammas.HYDROGEN,
gamma: Gammas | float = Gammas.HYDROGEN,
raster_time: float = DEFAULT_RASTER_TIME,
read_shots: bool = False,
normalize_factor: float = KMAX,
Expand Down Expand Up @@ -390,3 +390,89 @@ def read_trajectory(
Kmax = img_size / 2 / fov
kspace_loc = kspace_loc / Kmax * normalize_factor
return kspace_loc, params


def read_siemens_rawdat(
filename: str,
removeOS: bool = False,
squeeze: bool = True,
data_type: str = "ARBGRAD_VE11C",
):
"""Read raw data from a Siemens MRI file.

Parameters
----------
filename : str
The path to the Siemens MRI file.
removeOS : bool, optional
Whether to remove the oversampling, by default False.
squeeze : bool, optional
Whether to squeeze the dimensions of the data, by default True.
data_type : str, optional
The type of data to read, by default 'ARBGRAD_VE11C'.

Returns
-------
data: ndarray
Imported data formatted as n_coils X n_samples X n_slices X n_contrasts
hdr: dict
Extra information about the data parsed from the twix file

Raises
------
ImportError
If the mapVBVD module is not available.

Notes
-----
This function requires the mapVBVD module to be installed.
You can install it using the following command:
`pip install pymapVBVD`
"""
try:
from mapvbvd import mapVBVD
except ImportError as err:
raise ImportError(
"The mapVBVD module is not available. Please install it using "
"the following command: pip install pymapVBVD"
) from err
twixObj = mapVBVD(filename)
if isinstance(twixObj, list):
twixObj = twixObj[-1]
twixObj.image.flagRemoveOS = removeOS
twixObj.image.squeeze = squeeze
raw_kspace = twixObj.image[""]
data = np.moveaxis(raw_kspace, 0, 2)
hdr = {
"n_coils": int(twixObj.image.NCha),
"n_shots": int(twixObj.image.NLin),
"n_contrasts": int(twixObj.image.NSet),
"n_adc_samples": int(twixObj.image.NCol),
"n_slices": int(twixObj.image.NSli),
}
data = data.reshape(
hdr["n_coils"],
hdr["n_shots"] * hdr["n_adc_samples"],
hdr["n_slices"],
hdr["n_contrasts"],
)
if "ARBGRAD_VE11C" in data_type:
hdr["type"] = "ARBGRAD_GRE"
hdr["shifts"] = ()
for s in [7, 6, 8]:
shift = twixObj.search_header_for_val(
"Phoenix", ("sWiPMemBlock", "adFree", str(s))
)
hdr["shifts"] += (0,) if shift == [] else (shift[0],)
hdr["oversampling_factor"] = twixObj.search_header_for_val(
"Phoenix", ("sWiPMemBlock", "alFree", "4")
)[0]
hdr["trajectory_name"] = twixObj.search_header_for_val(
"Phoenix", ("sWipMemBlock", "tFree")
)[0][1:-1]
if hdr["n_contrasts"] > 1:
hdr["turboFactor"] = twixObj.search_header_for_val(
"Phoenix", ("sFastImaging", "lTurboFactor")
)[0]
hdr["type"] = "ARBGRAD_MP2RAGE"
return data, hdr
38 changes: 38 additions & 0 deletions src/mrinufft/io/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Module containing utility functions for IO in MRI NUFFT.
"""
import numpy as np


def add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, normalized_shifts):
"""
Add phase shifts to k-space data.

Parameters
----------
kspace_data : np.ndarray
The k-space data.
kspace_loc : np.ndarray
The k-space locations.
normalized_shifts : tuple
The normalized shifts to apply to each dimension of k-space.

Returns
-------
ndarray
The k-space data with phase shifts applied.

Raises
------
ValueError
If the dimension of normalized_shifts does not match the number of
dimensions in kspace_loc.
"""
if len(normalized_shifts) != kspace_loc.shape[1]:
raise ValueError(
"Dimension mismatch between shift and kspace locations! "
"Ensure that shifts are right"
)
phi = np.sum(kspace_loc * normalized_shifts, axis=-1)
phase = np.exp(-2 * np.pi * 1j * phi)
return kspace_data * phase
38 changes: 38 additions & 0 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mrinufft.operators.interfaces.utils import is_cuda_array

from mrinufft.density import get_density
from mrinufft.extras import get_smaps

CUPY_AVAILABLE = True
try:
Expand Down Expand Up @@ -225,6 +226,42 @@ def with_off_resonnance_correction(self, B, C, indices):
from ..off_resonnance import MRIFourierCorrected

return MRIFourierCorrected(self, B, C, indices)

def compute_smaps(self, method=None):
"""Compute the sensitivity maps and set it.

Parameters
----------
method: callable or dict or array
The method to use to compute the sensitivity maps.
If an array, it should be of shape (NCoils,XYZ) and will be used as is.
If a dict, it should have a key 'name', to determine which method to use.
other items will be used as kwargs.
If a callable, it should take the samples and the shape as input.
Note that this callable function should also hold the k-space data
(use funtools.partial)
"""
if isinstance(method, np.ndarray):
self.smaps = method
return None
if not method:
self.smaps = None
return None
kwargs = {}
if isinstance(method, dict):
kwargs = method.copy()
method = kwargs.pop("name")
if isinstance(method, str):
method = get_smaps(method)
if not callable(method):
raise ValueError(f"Unknown smaps method: {method}")
self.smaps, self.SOS = method(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mention that you have the SOS available as well in this case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do I mention though?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring should be fine. Maybe rename it as _SOS , it's not something that users should rely onto, as it is not always available. If we want to expose it we could use a property wrapping around it (and compute the SOS on the fly when this method is not used). Somethink like:

@property()
def SOS(self):
    """Returns the sum of square of the smaps. 
    
    If available, else return None"""
    if not self.uses_sense:
         return None 
    if hasattr(self, "_sos"):
       return self._sos
    else: 
      xp = get_array_module(self.smaps)
      return xp.sum(self.smaps**2, axis=0)

self.samples,
self.shape,
density=self.density,
backend=self.backend,
**kwargs
)

def compute_density(self, method=None):
"""Compute the density compensation weights and set it.
Expand Down Expand Up @@ -442,6 +479,7 @@ def __init__(

# Density Compensation Setup
self.compute_density(density)
self.compute_smaps(smaps)
# Multi Coil Setup
if n_coils < 1:
raise ValueError("n_coils should be ≥ 1")
Expand Down
Loading
Loading