Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added siemens data reader and phase shifter #89

Merged
merged 15 commits into from
Apr 29, 2024
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cufinufft = ["cufinufft", "cupy-cuda11x"]
finufft = ["finufft"]
pynfft = ["pynfft2", "cython<3.0.0"]
pynufft = ["pynufft"]
io = ["pymapvbvd"]

test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"]
dev = ["black", "isort", "ruff"]
Expand Down
88 changes: 87 additions & 1 deletion src/mrinufft/io/nsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def read_trajectory(
grad_filename: str,
dwell_time: float = DEFAULT_RASTER_TIME,
num_adc_samples: int = None,
gamma: float = Gammas.HYDROGEN,
gamma: Gammas | float = Gammas.HYDROGEN,
raster_time: float = DEFAULT_RASTER_TIME,
read_shots: bool = False,
normalize_factor: float = KMAX,
Expand Down Expand Up @@ -390,3 +390,89 @@ def read_trajectory(
Kmax = img_size / 2 / fov
kspace_loc = kspace_loc / Kmax * normalize_factor
return kspace_loc, params


def read_siemens_rawdat(
filename: str,
removeOS: bool = False,
squeeze: bool = True,
data_type: str = "ARBGRAD_VE11C",
): # pragma: no cover
"""Read raw data from a Siemens MRI file.

Parameters
----------
filename : str
The path to the Siemens MRI file.
removeOS : bool, optional
Whether to remove the oversampling, by default False.
squeeze : bool, optional
Whether to squeeze the dimensions of the data, by default True.
data_type : str, optional
The type of data to read, by default 'ARBGRAD_VE11C'.

Returns
-------
data: ndarray
Imported data formatted as n_coils X n_samples X n_slices X n_contrasts
hdr: dict
Extra information about the data parsed from the twix file

Raises
------
ImportError
If the mapVBVD module is not available.

Notes
-----
This function requires the mapVBVD module to be installed.
You can install it using the following command:
`pip install pymapVBVD`
"""
try:
from mapvbvd import mapVBVD
except ImportError as err:
raise ImportError(
"The mapVBVD module is not available. Please install it using "
"the following command: pip install pymapVBVD"
) from err
twixObj = mapVBVD(filename)
if isinstance(twixObj, list):
twixObj = twixObj[-1]
twixObj.image.flagRemoveOS = removeOS
twixObj.image.squeeze = squeeze
raw_kspace = twixObj.image[""]
data = np.moveaxis(raw_kspace, 0, 2)
hdr = {
"n_coils": int(twixObj.image.NCha),
"n_shots": int(twixObj.image.NLin),
"n_contrasts": int(twixObj.image.NSet),
"n_adc_samples": int(twixObj.image.NCol),
"n_slices": int(twixObj.image.NSli),
}
data = data.reshape(
hdr["n_coils"],
hdr["n_shots"] * hdr["n_adc_samples"],
hdr["n_slices"],
hdr["n_contrasts"],
)
if "ARBGRAD_VE11C" in data_type:
hdr["type"] = "ARBGRAD_GRE"
hdr["shifts"] = ()
for s in [7, 6, 8]:
shift = twixObj.search_header_for_val(
"Phoenix", ("sWiPMemBlock", "adFree", str(s))
)
hdr["shifts"] += (0,) if shift == [] else (shift[0],)
hdr["oversampling_factor"] = twixObj.search_header_for_val(
"Phoenix", ("sWiPMemBlock", "alFree", "4")
)[0]
hdr["trajectory_name"] = twixObj.search_header_for_val(
"Phoenix", ("sWipMemBlock", "tFree")
)[0][1:-1]
if hdr["n_contrasts"] > 1:
hdr["turboFactor"] = twixObj.search_header_for_val(
"Phoenix", ("sFastImaging", "lTurboFactor")
)[0]
hdr["type"] = "ARBGRAD_MP2RAGE"
return data, hdr
37 changes: 37 additions & 0 deletions src/mrinufft/io/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Module containing utility functions for IO in MRI NUFFT."""

import numpy as np


def add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, normalized_shifts):
"""
Add phase shifts to k-space data.

Parameters
----------
kspace_data : np.ndarray
The k-space data.
kspace_loc : np.ndarray
The k-space locations.
normalized_shifts : tuple
The normalized shifts to apply to each dimension of k-space.

Returns
-------
ndarray
The k-space data with phase shifts applied.

Raises
------
ValueError
If the dimension of normalized_shifts does not match the number of
dimensions in kspace_loc.
"""
if len(normalized_shifts) != kspace_loc.shape[1]:
raise ValueError(
"Dimension mismatch between shift and kspace locations! "
"Ensure that shifts are right"
)
phi = np.sum(kspace_loc * normalized_shifts, axis=-1)
phase = np.exp(-2 * np.pi * 1j * phi)
return kspace_data * phase
20 changes: 20 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import numpy as np
from mrinufft.io import read_trajectory, write_trajectory
from mrinufft.io.utils import add_phase_to_kspace_with_shifts
from mrinufft.trajectories.trajectory2D import initialize_2D_radial
from mrinufft.trajectories.trajectory3D import initialize_3D_cones
from pytest_cases import parametrize_with_cases
from case_trajectories import CasesTrajectories


class CasesIO:
Expand Down Expand Up @@ -67,3 +69,21 @@ def test_write_n_read(
np.testing.assert_almost_equal(params["FOV"], FOV, decimal=6)
np.testing.assert_equal(params["img_size"], img_size)
np.testing.assert_almost_equal(read_traj, trajectory, decimal=5)


@parametrize_with_cases(
"kspace_loc, shape",
cases=[CasesTrajectories.case_random2D, CasesTrajectories.case_random3D],
)
def test_add_shift(kspace_loc, shape):
"""Test the add_phase_to_kspace_with_shifts function."""
n_samples = np.prod(kspace_loc.shape[:-1])
kspace_data = np.random.randn(n_samples) + 1j * np.random.randn(n_samples)
shifts = np.random.rand(kspace_loc.shape[-1])

shifted_data = add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, shifts)

assert np.allclose(np.abs(shifted_data), np.abs(kspace_data))

phase = np.exp(-2 * np.pi * 1j * np.sum(kspace_loc * shifts, axis=-1))
np.testing.assert_almost_equal(shifted_data / phase, kspace_data, decimal=5)
Loading