From f997f54c53e853b794d4692b36e2f9a195bbbda7 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Wed, 10 Apr 2024 16:21:42 +0200 Subject: [PATCH 01/50] Added siemens and add_raw shifts --- src/mrinufft/io/siemens.py | 87 ++++++++++++++++++++++++++++++++++++++ src/mrinufft/io/utils.py | 38 +++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 src/mrinufft/io/siemens.py create mode 100644 src/mrinufft/io/utils.py diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py new file mode 100644 index 00000000..7f8b068c --- /dev/null +++ b/src/mrinufft/io/siemens.py @@ -0,0 +1,87 @@ +import numpy as np + +try: + from mapvbvd import mapVBVD + MAPVBVD_FOUND = True +except ImportError: + MAPVBVD_FOUND = False + + +def read_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, + data_type: str = "SPARKLING_VE11C"): + """Read raw data from a Siemens MRI file. + + Parameters + ---------- + filename : str + The path to the Siemens MRI file. + removeOS : bool, optional + Whether to remove the oversampling, by default False. + squeeze : bool, optional + Whether to squeeze the dimensions of the data, by default True. + data_type : str, optional + The type of data to read, by default 'SPARKLING_VE11C'. + + Returns + ------- + tuple + A tuple containing the raw data and the header information. + + Raises + ------ + ImportError + If the mapVBVD module is not available. + + Notes + ----- + This function requires the mapVBVD module to be installed. + You can install it using the following command: + `pip install pymapVBVD` + """ + if not MAPVBVD_FOUND: + raise ImportError( + "The mapVBVD module is not available. Please install it using " + "the following command: pip install pymapVBVD" + ) + twixObj = mapVBVD(filename) + if isinstance(twixObj, list): + twixObj = twixObj[-1] + twixObj.image.flagRemoveOS = removeOS + twixObj.image.squeeze = squeeze + raw_kspace = twixObj.image[''] + data = np.moveaxis(raw_kspace, 0, 2) + hdr = { + "num_coils": int(twixObj.image.NCha), + "num_shots": int(twixObj.image.NLin), + "num_contrasts": int(twixObj.image.NSet), + "num_adc_samples": int(twixObj.image.NCol), + "num_slices": int(twixObj.image.NSli), + } + data = data.reshape( + hdr["num_coils"], + hdr["num_shots"]*hdr["num_adc_samples"], + hdr["num_slices"], + hdr["num_contrasts"] + ) + if "SPARKLING_VE11C" in data_type: + hdr["shifts"] = tuple([ + 0 if twixObj.search_header_for_val( + "Phoenix", ("sWiPMemBlock", "adFree", str(s)) + ) == [] + else twixObj.search_header_for_val( + "Phoenix", ("sWiPMemBlock", "adFree", str(s)) + )[0] + for s in [7, 6, 8] + ]) + hdr["oversampling_factor"] = twixObj.search_header_for_val( + "Phoenix", ("sWiPMemBlock", "alFree", "4") + )[0] + hdr["trajectory_name"] = twixObj.search_header_for_val( + "Phoenix", ("sWipMemBlock", "tFree") + )[0][1:-1] + if(hdr["num_contrasts"] > 1): + hdr["turboFactor"] = twixObj.search_header_for_val( + "Phoenix", ("sFastImaging", "lTurboFactor") + )[0] + hdr["type"] = "MP2RAGE" + return data, hdr \ No newline at end of file diff --git a/src/mrinufft/io/utils.py b/src/mrinufft/io/utils.py new file mode 100644 index 00000000..8a5e7f3f --- /dev/null +++ b/src/mrinufft/io/utils.py @@ -0,0 +1,38 @@ +from mrinufft.operators.base import with_numpy_cupy, get_array_module + + +@with_numpy_cupy +def add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, normalized_shifts): + """ + Add phase shifts to k-space data. + + Parameters + ---------- + kspace_data : ndarray + The k-space data. + kspace_loc : ndarray + The k-space locations. + normalized_shifts : tuple + The normalized shifts to apply to each dimension of k-space. + + Returns + ------- + ndarray + The k-space data with phase shifts applied. + + Raises + ------ + ValueError + If the dimension of normalized_shifts does not match the number of + dimensions in kspace_loc. + """ + if len(normalized_shifts) != kspace_loc.shape[1]: + raise ValueError( + "Dimension mismatch between shift and kspace locations! " + "Ensure that shifts are right" + ) + xp = get_array_module(kspace_data) + phi = xp.sum(kspace_loc*normalized_shifts, axis=-1) + phase = xp.exp(-2 * xp.pi * 1j * phi) + return kspace_data * phase + From 2870cc9a2eed51d16e9262d0c939e9fd29ff02bc Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 09:27:42 +0200 Subject: [PATCH 02/50] Update src/mrinufft/io/siemens.py Co-authored-by: Pierre-Antoine Comby --- src/mrinufft/io/siemens.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py index 7f8b068c..cf32ecb1 100644 --- a/src/mrinufft/io/siemens.py +++ b/src/mrinufft/io/siemens.py @@ -24,8 +24,10 @@ def read_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, Returns ------- - tuple - A tuple containing the raw data and the header information. + data: ndarray + Imported data formatted as XXX + hdr: dict + Extra information about the data parsed from the twix file Raises ------ From 04773f26fb21603814c4e4c6273502a340674cec Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 09:36:29 +0200 Subject: [PATCH 03/50] Update --- src/mrinufft/io/siemens.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py index 7f8b068c..1e86a77b 100644 --- a/src/mrinufft/io/siemens.py +++ b/src/mrinufft/io/siemens.py @@ -63,16 +63,14 @@ def read_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, hdr["num_slices"], hdr["num_contrasts"] ) - if "SPARKLING_VE11C" in data_type: - hdr["shifts"] = tuple([ - 0 if twixObj.search_header_for_val( + if "ARBGRAD_VE11C" in data_type: + hdr["type"] = "ARBGRAD_GRE" + hdr["shifts"] = () + for s in [7, 6, 8]: + shift = twixObj.search_header_for_val( "Phoenix", ("sWiPMemBlock", "adFree", str(s)) - ) == [] - else twixObj.search_header_for_val( - "Phoenix", ("sWiPMemBlock", "adFree", str(s)) - )[0] - for s in [7, 6, 8] - ]) + ) + hdr["shifts"] += (0,) if shift == [] else (shift[0],) hdr["oversampling_factor"] = twixObj.search_header_for_val( "Phoenix", ("sWiPMemBlock", "alFree", "4") )[0] @@ -83,5 +81,5 @@ def read_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, hdr["turboFactor"] = twixObj.search_header_for_val( "Phoenix", ("sFastImaging", "lTurboFactor") )[0] - hdr["type"] = "MP2RAGE" + hdr["type"] = "ARBGRAD_MP2RAGE" return data, hdr \ No newline at end of file From a1966eb446e0adacaaa6e7ab11606358de22fed2 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 09:38:09 +0200 Subject: [PATCH 04/50] Fixed some more --- src/mrinufft/io/siemens.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py index 02fffa0d..e8454d12 100644 --- a/src/mrinufft/io/siemens.py +++ b/src/mrinufft/io/siemens.py @@ -25,7 +25,7 @@ def read_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, Returns ------- data: ndarray - Imported data formatted as XXX + Imported data formatted as n_coils X n_samples X n_slices X n_contrasts hdr: dict Extra information about the data parsed from the twix file @@ -53,17 +53,17 @@ def read_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, raw_kspace = twixObj.image[''] data = np.moveaxis(raw_kspace, 0, 2) hdr = { - "num_coils": int(twixObj.image.NCha), - "num_shots": int(twixObj.image.NLin), - "num_contrasts": int(twixObj.image.NSet), - "num_adc_samples": int(twixObj.image.NCol), - "num_slices": int(twixObj.image.NSli), + "n_coils": int(twixObj.image.NCha), + "n_shots": int(twixObj.image.NLin), + "n_contrasts": int(twixObj.image.NSet), + "n_adc_samples": int(twixObj.image.NCol), + "n_slices": int(twixObj.image.NSli), } data = data.reshape( - hdr["num_coils"], - hdr["num_shots"]*hdr["num_adc_samples"], - hdr["num_slices"], - hdr["num_contrasts"] + hdr["n_coils"], + hdr["n_shots"]*hdr["n_adc_samples"], + hdr["n_slices"], + hdr["n_contrasts"] ) if "ARBGRAD_VE11C" in data_type: hdr["type"] = "ARBGRAD_GRE" From ab337dbd66f6b3352486e15211f2c8684ba4fe49 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 09:41:40 +0200 Subject: [PATCH 05/50] Moved codes around --- pyproject.toml | 1 + src/mrinufft/io/nsp.py | 82 +++++++++++++++++++++++++++++++++++ src/mrinufft/io/siemens.py | 87 -------------------------------------- 3 files changed, 83 insertions(+), 87 deletions(-) delete mode 100644 src/mrinufft/io/siemens.py diff --git a/pyproject.toml b/pyproject.toml index dfe63e58..d3274501 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ cufinufft = ["cufinufft", "cupy-cuda11x"] finufft = ["finufft"] pynfft = ["pynfft2", "cython<3.0.0"] pynufft = ["pynufft"] +io = ["pymapvbvd"] test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"] dev = ["black", "isort", "ruff"] diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index 180aee0d..173e636e 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -390,3 +390,85 @@ def read_trajectory( Kmax = img_size / 2 / fov kspace_loc = kspace_loc / Kmax * normalize_factor return kspace_loc, params + + +def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, + data_type: str = "SPARKLING_VE11C"): + """Read raw data from a Siemens MRI file. + + Parameters + ---------- + filename : str + The path to the Siemens MRI file. + removeOS : bool, optional + Whether to remove the oversampling, by default False. + squeeze : bool, optional + Whether to squeeze the dimensions of the data, by default True. + data_type : str, optional + The type of data to read, by default 'SPARKLING_VE11C'. + + Returns + ------- + data: ndarray + Imported data formatted as n_coils X n_samples X n_slices X n_contrasts + hdr: dict + Extra information about the data parsed from the twix file + + Raises + ------ + ImportError + If the mapVBVD module is not available. + + Notes + ----- + This function requires the mapVBVD module to be installed. + You can install it using the following command: + `pip install pymapVBVD` + """ + try: + from mapvbvd import mapVBVD + except ImportError as err: + raise ImportError( + "The mapVBVD module is not available. Please install it using " + "the following command: pip install pymapVBVD" + ) from err + twixObj = mapVBVD(filename) + if isinstance(twixObj, list): + twixObj = twixObj[-1] + twixObj.image.flagRemoveOS = removeOS + twixObj.image.squeeze = squeeze + raw_kspace = twixObj.image[''] + data = np.moveaxis(raw_kspace, 0, 2) + hdr = { + "n_coils": int(twixObj.image.NCha), + "n_shots": int(twixObj.image.NLin), + "n_contrasts": int(twixObj.image.NSet), + "n_adc_samples": int(twixObj.image.NCol), + "n_slices": int(twixObj.image.NSli), + } + data = data.reshape( + hdr["n_coils"], + hdr["n_shots"]*hdr["n_adc_samples"], + hdr["n_slices"], + hdr["n_contrasts"] + ) + if "ARBGRAD_VE11C" in data_type: + hdr["type"] = "ARBGRAD_GRE" + hdr["shifts"] = () + for s in [7, 6, 8]: + shift = twixObj.search_header_for_val( + "Phoenix", ("sWiPMemBlock", "adFree", str(s)) + ) + hdr["shifts"] += (0,) if shift == [] else (shift[0],) + hdr["oversampling_factor"] = twixObj.search_header_for_val( + "Phoenix", ("sWiPMemBlock", "alFree", "4") + )[0] + hdr["trajectory_name"] = twixObj.search_header_for_val( + "Phoenix", ("sWipMemBlock", "tFree") + )[0][1:-1] + if(hdr["num_contrasts"] > 1): + hdr["turboFactor"] = twixObj.search_header_for_val( + "Phoenix", ("sFastImaging", "lTurboFactor") + )[0] + hdr["type"] = "ARBGRAD_MP2RAGE" + return data, hdr \ No newline at end of file diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py deleted file mode 100644 index e8454d12..00000000 --- a/src/mrinufft/io/siemens.py +++ /dev/null @@ -1,87 +0,0 @@ -import numpy as np - -try: - from mapvbvd import mapVBVD - MAPVBVD_FOUND = True -except ImportError: - MAPVBVD_FOUND = False - - -def read_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, - data_type: str = "SPARKLING_VE11C"): - """Read raw data from a Siemens MRI file. - - Parameters - ---------- - filename : str - The path to the Siemens MRI file. - removeOS : bool, optional - Whether to remove the oversampling, by default False. - squeeze : bool, optional - Whether to squeeze the dimensions of the data, by default True. - data_type : str, optional - The type of data to read, by default 'SPARKLING_VE11C'. - - Returns - ------- - data: ndarray - Imported data formatted as n_coils X n_samples X n_slices X n_contrasts - hdr: dict - Extra information about the data parsed from the twix file - - Raises - ------ - ImportError - If the mapVBVD module is not available. - - Notes - ----- - This function requires the mapVBVD module to be installed. - You can install it using the following command: - `pip install pymapVBVD` - """ - if not MAPVBVD_FOUND: - raise ImportError( - "The mapVBVD module is not available. Please install it using " - "the following command: pip install pymapVBVD" - ) - twixObj = mapVBVD(filename) - if isinstance(twixObj, list): - twixObj = twixObj[-1] - twixObj.image.flagRemoveOS = removeOS - twixObj.image.squeeze = squeeze - raw_kspace = twixObj.image[''] - data = np.moveaxis(raw_kspace, 0, 2) - hdr = { - "n_coils": int(twixObj.image.NCha), - "n_shots": int(twixObj.image.NLin), - "n_contrasts": int(twixObj.image.NSet), - "n_adc_samples": int(twixObj.image.NCol), - "n_slices": int(twixObj.image.NSli), - } - data = data.reshape( - hdr["n_coils"], - hdr["n_shots"]*hdr["n_adc_samples"], - hdr["n_slices"], - hdr["n_contrasts"] - ) - if "ARBGRAD_VE11C" in data_type: - hdr["type"] = "ARBGRAD_GRE" - hdr["shifts"] = () - for s in [7, 6, 8]: - shift = twixObj.search_header_for_val( - "Phoenix", ("sWiPMemBlock", "adFree", str(s)) - ) - hdr["shifts"] += (0,) if shift == [] else (shift[0],) - hdr["oversampling_factor"] = twixObj.search_header_for_val( - "Phoenix", ("sWiPMemBlock", "alFree", "4") - )[0] - hdr["trajectory_name"] = twixObj.search_header_for_val( - "Phoenix", ("sWipMemBlock", "tFree") - )[0][1:-1] - if(hdr["num_contrasts"] > 1): - hdr["turboFactor"] = twixObj.search_header_for_val( - "Phoenix", ("sFastImaging", "lTurboFactor") - )[0] - hdr["type"] = "ARBGRAD_MP2RAGE" - return data, hdr \ No newline at end of file From db0669e3ba0943b55b8276f2fc09fa9360d60ba9 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 09:53:23 +0200 Subject: [PATCH 06/50] Added np.ndarray --- src/mrinufft/io/utils.py | 12 +++++------- src/mrinufft/operators/base.py | 2 -- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/mrinufft/io/utils.py b/src/mrinufft/io/utils.py index 8a5e7f3f..5595d5cf 100644 --- a/src/mrinufft/io/utils.py +++ b/src/mrinufft/io/utils.py @@ -1,16 +1,15 @@ -from mrinufft.operators.base import with_numpy_cupy, get_array_module +import numpy as np -@with_numpy_cupy def add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, normalized_shifts): """ Add phase shifts to k-space data. Parameters ---------- - kspace_data : ndarray + kspace_data : np.ndarray The k-space data. - kspace_loc : ndarray + kspace_loc : np.ndarray The k-space locations. normalized_shifts : tuple The normalized shifts to apply to each dimension of k-space. @@ -31,8 +30,7 @@ def add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, normalized_shifts): "Dimension mismatch between shift and kspace locations! " "Ensure that shifts are right" ) - xp = get_array_module(kspace_data) - phi = xp.sum(kspace_loc*normalized_shifts, axis=-1) - phase = xp.exp(-2 * xp.pi * 1j * phi) + phi = np.sum(kspace_loc*normalized_shifts, axis=-1) + phase = np.exp(-2 * np.pi * 1j * phi) return kspace_data * phase diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 8aee5120..401f8718 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -130,11 +130,9 @@ def wrapper(self, data, output=None, *args, **kwargs): if xp.__name__ == "torch" and is_cuda_array(data): # Move them to cupy data_ = cp.from_dlpack(data) - output_ = cp.from_dlpack(output) if output is not None else None elif xp.__name__ == "torch": # Move to numpy data_ = data.to("cpu").numpy() - output_ = output.to("cpu").numpy() if output is not None else None else: data_ = data output_ = output From e13948da773ceafc1e438b17e17baca5932c22d2 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 10:05:04 +0200 Subject: [PATCH 07/50] Fix movement --- src/mrinufft/io/nsp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index 173e636e..63bef63f 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -2,7 +2,7 @@ import warnings import os -from typing import Tuple, Optional +from typing import Tuple, Optional, Union import numpy as np from datetime import datetime from array import array @@ -253,7 +253,7 @@ def read_trajectory( grad_filename: str, dwell_time: float = DEFAULT_RASTER_TIME, num_adc_samples: int = None, - gamma: float = Gammas.HYDROGEN, + gamma: Union[Gammas,float] = Gammas.HYDROGEN, raster_time: float = DEFAULT_RASTER_TIME, read_shots: bool = False, normalize_factor: float = KMAX, @@ -393,7 +393,7 @@ def read_trajectory( def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, - data_type: str = "SPARKLING_VE11C"): + data_type: str = "ARBGRAD_VE11C"): """Read raw data from a Siemens MRI file. Parameters @@ -405,7 +405,7 @@ def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = T squeeze : bool, optional Whether to squeeze the dimensions of the data, by default True. data_type : str, optional - The type of data to read, by default 'SPARKLING_VE11C'. + The type of data to read, by default 'ARBGRAD_VE11C'. Returns ------- From c817c91225e2e5a032f65c39363f089a3d6bbe85 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 10:05:30 +0200 Subject: [PATCH 08/50] Fix movement --- src/mrinufft/io/nsp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index 63bef63f..c5debc2b 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -466,7 +466,7 @@ def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = T hdr["trajectory_name"] = twixObj.search_header_for_val( "Phoenix", ("sWipMemBlock", "tFree") )[0][1:-1] - if(hdr["num_contrasts"] > 1): + if(hdr["n_contrasts"] > 1): hdr["turboFactor"] = twixObj.search_header_for_val( "Phoenix", ("sFastImaging", "lTurboFactor") )[0] From 6f4280aad0f7b5518914ffb465b2b510fdbe5f36 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 11:10:11 +0200 Subject: [PATCH 09/50] Fix flake --- src/mrinufft/io/nsp.py | 28 ++++++++++++++++------------ src/mrinufft/io/utils.py | 5 ++--- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index c5debc2b..cc49a89f 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -253,7 +253,7 @@ def read_trajectory( grad_filename: str, dwell_time: float = DEFAULT_RASTER_TIME, num_adc_samples: int = None, - gamma: Union[Gammas,float] = Gammas.HYDROGEN, + gamma: Union[Gammas, float] = Gammas.HYDROGEN, raster_time: float = DEFAULT_RASTER_TIME, read_shots: bool = False, normalize_factor: float = KMAX, @@ -392,10 +392,14 @@ def read_trajectory( return kspace_loc, params -def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = True, - data_type: str = "ARBGRAD_VE11C"): +def read_siemens_rawdat( + filename: str, + removeOS: bool = False, + squeeze: bool = True, + data_type: str = "ARBGRAD_VE11C", +): """Read raw data from a Siemens MRI file. - + Parameters ---------- filename : str @@ -421,7 +425,7 @@ def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = T Notes ----- - This function requires the mapVBVD module to be installed. + This function requires the mapVBVD module to be installed. You can install it using the following command: `pip install pymapVBVD` """ @@ -437,7 +441,7 @@ def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = T twixObj = twixObj[-1] twixObj.image.flagRemoveOS = removeOS twixObj.image.squeeze = squeeze - raw_kspace = twixObj.image[''] + raw_kspace = twixObj.image[""] data = np.moveaxis(raw_kspace, 0, 2) hdr = { "n_coils": int(twixObj.image.NCha), @@ -447,10 +451,10 @@ def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = T "n_slices": int(twixObj.image.NSli), } data = data.reshape( - hdr["n_coils"], - hdr["n_shots"]*hdr["n_adc_samples"], - hdr["n_slices"], - hdr["n_contrasts"] + hdr["n_coils"], + hdr["n_shots"] * hdr["n_adc_samples"], + hdr["n_slices"], + hdr["n_contrasts"], ) if "ARBGRAD_VE11C" in data_type: hdr["type"] = "ARBGRAD_GRE" @@ -466,9 +470,9 @@ def read_siemens_rawdat(filename: str, removeOS: bool = False, squeeze: bool = T hdr["trajectory_name"] = twixObj.search_header_for_val( "Phoenix", ("sWipMemBlock", "tFree") )[0][1:-1] - if(hdr["n_contrasts"] > 1): + if hdr["n_contrasts"] > 1: hdr["turboFactor"] = twixObj.search_header_for_val( "Phoenix", ("sFastImaging", "lTurboFactor") )[0] hdr["type"] = "ARBGRAD_MP2RAGE" - return data, hdr \ No newline at end of file + return data, hdr diff --git a/src/mrinufft/io/utils.py b/src/mrinufft/io/utils.py index 5595d5cf..b4e19682 100644 --- a/src/mrinufft/io/utils.py +++ b/src/mrinufft/io/utils.py @@ -22,7 +22,7 @@ def add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, normalized_shifts): Raises ------ ValueError - If the dimension of normalized_shifts does not match the number of + If the dimension of normalized_shifts does not match the number of dimensions in kspace_loc. """ if len(normalized_shifts) != kspace_loc.shape[1]: @@ -30,7 +30,6 @@ def add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, normalized_shifts): "Dimension mismatch between shift and kspace locations! " "Ensure that shifts are right" ) - phi = np.sum(kspace_loc*normalized_shifts, axis=-1) + phi = np.sum(kspace_loc * normalized_shifts, axis=-1) phase = np.exp(-2 * np.pi * 1j * phi) return kspace_data * phase - From e4fe1532d242eb003d3e83ae7db90b832c77b003 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 11:11:06 +0200 Subject: [PATCH 10/50] ruff fix --- src/mrinufft/io/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mrinufft/io/utils.py b/src/mrinufft/io/utils.py index b4e19682..1b2fe619 100644 --- a/src/mrinufft/io/utils.py +++ b/src/mrinufft/io/utils.py @@ -1,3 +1,6 @@ +""" +Module containing utility functions for IO in MRI NUFFT. +""" import numpy as np From 33ff82a29eb37c744165ff2cff72749424c1695c Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 13:10:14 +0200 Subject: [PATCH 11/50] Fix --- src/mrinufft/io/nsp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index cc49a89f..bbdc69fd 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -2,7 +2,7 @@ import warnings import os -from typing import Tuple, Optional, Union +from typing import Tuple, Optional import numpy as np from datetime import datetime from array import array @@ -253,7 +253,7 @@ def read_trajectory( grad_filename: str, dwell_time: float = DEFAULT_RASTER_TIME, num_adc_samples: int = None, - gamma: Union[Gammas, float] = Gammas.HYDROGEN, + gamma: Gammas | = Gammas.HYDROGEN, raster_time: float = DEFAULT_RASTER_TIME, read_shots: bool = False, normalize_factor: float = KMAX, From 65c6bf79b73d87c7e5bc140b18c0839d1e4576d0 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 13:12:12 +0200 Subject: [PATCH 12/50] Remove bymistake add --- src/mrinufft/io/nsp.py | 2 +- src/mrinufft/operators/base.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index bbdc69fd..39770369 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -253,7 +253,7 @@ def read_trajectory( grad_filename: str, dwell_time: float = DEFAULT_RASTER_TIME, num_adc_samples: int = None, - gamma: Gammas | = Gammas.HYDROGEN, + gamma: Gammas | float = Gammas.HYDROGEN, raster_time: float = DEFAULT_RASTER_TIME, read_shots: bool = False, normalize_factor: float = KMAX, diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 401f8718..8aee5120 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -130,9 +130,11 @@ def wrapper(self, data, output=None, *args, **kwargs): if xp.__name__ == "torch" and is_cuda_array(data): # Move them to cupy data_ = cp.from_dlpack(data) + output_ = cp.from_dlpack(output) if output is not None else None elif xp.__name__ == "torch": # Move to numpy data_ = data.to("cpu").numpy() + output_ = output.to("cpu").numpy() if output is not None else None else: data_ = data output_ = output From ebd8d321451f1df5f050be2316ac42c4c37af77f Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Comby Date: Mon, 8 Jan 2024 11:25:25 +0100 Subject: [PATCH 13/50] ci: runs test only for non-style commit. (#73) From f79ffad3e2430507618cb8dbad51cce37ff4dfa3 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Wed, 10 Apr 2024 10:27:22 +0200 Subject: [PATCH 14/50] Added fixSmaps --- src/mrinufft/extras/smaps.py | 135 +++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 src/mrinufft/extras/smaps.py diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py new file mode 100644 index 00000000..e07d553b --- /dev/null +++ b/src/mrinufft/extras/smaps.py @@ -0,0 +1,135 @@ +from mrinufft._utils import MethodRegister +from mrinufft.density.utils import flat_traj +from mrinufft.operators.base import with_numpy_cupy, get_array_module + + +register_smaps = MethodRegister("sensitivity_maps") + + +@flat_traj +def _get_centeral_index(kspace_loc, threshold): + r""" + Extract the index of the k-space center. + + Parameters + ---------- + kspace_loc: numpy.ndarray + The samples location in the k-sapec domain (between [-0.5, 0.5[) + threshold: tuple or float + The threshold used to extract the k_space center (between (0, 1]) + + Returns + ------- + The index of the k-space center. + """ + xp = get_array_module(kspace_loc) + radius = xp.linalg.norm(kspace_loc, axis=-1) + + if isinstance(threshold, float): + threshold = (threshold,) * kspace_loc.shape[-1] + condition = xp.logical_and.reduce(tuple( + xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) + )) + index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64) + index = xp.extract(condition, index) + return index + +def extract_k_space_center_and_locations( + kspace_data, kspace_loc, threshold=None, window_fun=None, + ): + r""" + Extract k-space center and corresponding sampling locations. + + Parameters + ---------- + kspace_data: numpy.ndarray + The value of the samples + kspace_loc: numpy.ndarray + The samples location in the k-sapec domain (between [-0.5, 0.5[) + threshold: tuple or float + The threshold used to extract the k_space center (between (0, 1]) + window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. + The window function to apply to the selected data. It is computed with + the center locations selected. Only works with circular mask. + If window_fun is a callable, it takes as input the array (n_samples x n_dims) + of sample positions and returns an array of n_samples weights to be + applied to the selected k-space values, before the smaps estimation. + + + Returns + ------- + The extracted center of the k-space, i.e. both the kspace locations and + kspace values. If the density compensators are passed, the corresponding + compensators for the center of k-space data will also be returned. The + return stypes for density compensation and kspace data is same as input + + Notes + ----- + The Hann (or Hanning) and Hamming windows of width :math:`2\theta` are defined as: + .. math:: + + w(x,y) = a_0 - (1-a_0) * \cos(\pi * \sqrt{x^2+y^2}/\theta), + \sqrt{x^2+y^2} \le \theta + + In the case of Hann window :math:`a_0=0.5`. + For Hamming window we consider the optimal value in the equiripple sense: + :math:`a_0=0.53836`. + .. Wikipedia:: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows + + """ + xp = get_array_module(kspace_data) + radius = xp.linalg.norm(center_locations, axis=1) + data_ordered = xp.copy(kspace_data) + if isinstance(threshold, float): + threshold = (threshold,) * kspace_loc.shape[1] + condition = xp.logical_and.reduce(tuple( + xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) + )) + index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64) + index = xp.extract(condition, index) + center_locations = kspace_loc[index, :] + data_thresholded = data_ordered[:, index] + if window_fun is not None: + if callable(window_fun): + window = window_fun(center_locations) + else: + if window_fun == "Hann" or window_fun == "Hanning": + a_0 = 0.5 + elif window_fun == "Hamming": + a_0 = 0.53836 + else: + raise ValueError("Unsupported window function.") + + window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) + data_thresholded = window * data_thresholded + + if density_comp is not None: + density_comp = density_comp[index] + return data_thresholded, center_locations, density_comp + else: + return data_thresholded, center_locations + + +@register_smaps +@with_numpy_cupy +@flat_traj +def low_frequency(traj, kspace_data, shape, backend, theshold, *args, **kwargs): + xp = get_array_module(kspace_data) + k_space, samples, dc = extract_k_space_center_and_locations( + kspace_data=kspace_data, + kspace_loc=traj, + threshold=threshold, + img_shape=traj_params['img_size'], + ) + smaps_adj_op = get_operator('gpunufft')( + samples, + shape, + density=dc, + n_coils=k_space.shape[0] + ) + Smaps_ = smaps_adj_op.adj_op(k_space) + SOS = xp.linalg.norm(Smaps_ , axis=0) + thresh = threshold_otsu(SOS) + convex_hull = convex_hull_image(SOS>thresh) + + \ No newline at end of file From 238b358c20de180a4b0e061ade6c1c290353a0fd Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 11:35:02 +0200 Subject: [PATCH 15/50] Fixes updates --- src/mrinufft/extras/smaps.py | 86 ++++++++++++------------------------ 1 file changed, 28 insertions(+), 58 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index e07d553b..3d4eb150 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -6,36 +6,8 @@ register_smaps = MethodRegister("sensitivity_maps") -@flat_traj -def _get_centeral_index(kspace_loc, threshold): - r""" - Extract the index of the k-space center. - - Parameters - ---------- - kspace_loc: numpy.ndarray - The samples location in the k-sapec domain (between [-0.5, 0.5[) - threshold: tuple or float - The threshold used to extract the k_space center (between (0, 1]) - - Returns - ------- - The index of the k-space center. - """ - xp = get_array_module(kspace_loc) - radius = xp.linalg.norm(kspace_loc, axis=-1) - - if isinstance(threshold, float): - threshold = (threshold,) * kspace_loc.shape[-1] - condition = xp.logical_and.reduce(tuple( - xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) - )) - index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64) - index = xp.extract(condition, index) - return index - -def extract_k_space_center_and_locations( - kspace_data, kspace_loc, threshold=None, window_fun=None, +def extract_kspace_center( + kspace_data, kspace_loc, threshold=None, window_fun="ellipse", ): r""" Extract k-space center and corresponding sampling locations. @@ -48,7 +20,8 @@ def extract_k_space_center_and_locations( The samples location in the k-sapec domain (between [-0.5, 0.5[) threshold: tuple or float The threshold used to extract the k_space center (between (0, 1]) - window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. + window_fun: "hann" / "hanning", "hamming", "ellipse", "rect", or a callable, + default "ellipse". The window function to apply to the selected data. It is computed with the center locations selected. Only works with circular mask. If window_fun is a callable, it takes as input the array (n_samples x n_dims) @@ -78,50 +51,47 @@ def extract_k_space_center_and_locations( """ xp = get_array_module(kspace_data) - radius = xp.linalg.norm(center_locations, axis=1) - data_ordered = xp.copy(kspace_data) if isinstance(threshold, float): threshold = (threshold,) * kspace_loc.shape[1] - condition = xp.logical_and.reduce(tuple( - xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) - )) - index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64) - index = xp.extract(condition, index) - center_locations = kspace_loc[index, :] - data_thresholded = data_ordered[:, index] - if window_fun is not None: + + if window_fun == "rect": + data_ordered = xp.copy(kspace_data) + index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64) + condition = xp.logical_and.reduce(tuple( + xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) + )) + index = xp.extract(condition, index) + center_locations = kspace_loc[index, :] + data_thresholded = data_ordered[:, index] + else: if callable(window_fun): window = window_fun(center_locations) else: - if window_fun == "Hann" or window_fun == "Hanning": - a_0 = 0.5 - elif window_fun == "Hamming": - a_0 = 0.53836 + if window_fun in ["hann", "hanning", "hamming"]: + radius = xp.linalg.norm(kspace_loc, axis=1) + a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 + window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) + elif window_fun == "ellipse": + window = xp.sum(kspace_loc**2/ xp.asarray(threshold)**2, axis=1) <= 1 else: raise ValueError("Unsupported window function.") - - window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) data_thresholded = window * data_thresholded - - if density_comp is not None: - density_comp = density_comp[index] - return data_thresholded, center_locations, density_comp - else: - return data_thresholded, center_locations + # Return k-space locations just for consistency + return data_thresholded, kspace_loc @register_smaps -@with_numpy_cupy @flat_traj -def low_frequency(traj, kspace_data, shape, backend, theshold, *args, **kwargs): +def low_frequency(traj, kspace_data, shape, backend, threshold, *args, **kwargs): xp = get_array_module(kspace_data) - k_space, samples, dc = extract_k_space_center_and_locations( + k_space, traj = extract_kspace_center( kspace_data=kspace_data, kspace_loc=traj, threshold=threshold, - img_shape=traj_params['img_size'], + img_shape=shape, + **kwargs, ) - smaps_adj_op = get_operator('gpunufft')( + smaps_adj_op = get_operator(backend)( samples, shape, density=dc, From c17ad556c1448e67721fa840116c8c62bfacbdf4 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 12:55:10 +0200 Subject: [PATCH 16/50] Fix --- src/mrinufft/extras/smaps.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 3d4eb150..666f2306 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -1,6 +1,7 @@ from mrinufft._utils import MethodRegister from mrinufft.density.utils import flat_traj from mrinufft.operators.base import with_numpy_cupy, get_array_module +from mrinufft import get_operator register_smaps = MethodRegister("sensitivity_maps") @@ -82,12 +83,13 @@ def extract_kspace_center( @register_smaps @flat_traj -def low_frequency(traj, kspace_data, shape, backend, threshold, *args, **kwargs): +def low_frequency(traj, kspace_data, shape, backend, threshold, density=None, *args, **kwargs): xp = get_array_module(kspace_data) - k_space, traj = extract_kspace_center( + k_space, samples, dc = extract_kspace_center( kspace_data=kspace_data, kspace_loc=traj, threshold=threshold, + density=density, img_shape=shape, **kwargs, ) From dd2cf1d9b099975b4d5e03a816a43abdf77bf406 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 13:07:40 +0200 Subject: [PATCH 17/50] fix docs --- src/mrinufft/extras/smaps.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 666f2306..c644480f 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -10,8 +10,12 @@ def extract_kspace_center( kspace_data, kspace_loc, threshold=None, window_fun="ellipse", ): - r""" - Extract k-space center and corresponding sampling locations. + r"""Extract k-space center and corresponding sampling locations. + + The extracted center of the k-space, i.e. both the kspace locations and + kspace values. If the density compensators are passed, the corresponding + compensators for the center of k-space data will also be returned. The + return dtypes for density compensation and kspace data is same as input Parameters ---------- @@ -21,22 +25,22 @@ def extract_kspace_center( The samples location in the k-sapec domain (between [-0.5, 0.5[) threshold: tuple or float The threshold used to extract the k_space center (between (0, 1]) - window_fun: "hann" / "hanning", "hamming", "ellipse", "rect", or a callable, - default "ellipse". + window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. The window function to apply to the selected data. It is computed with the center locations selected. Only works with circular mask. If window_fun is a callable, it takes as input the array (n_samples x n_dims) of sample positions and returns an array of n_samples weights to be applied to the selected k-space values, before the smaps estimation. - - + Returns ------- - The extracted center of the k-space, i.e. both the kspace locations and - kspace values. If the density compensators are passed, the corresponding - compensators for the center of k-space data will also be returned. The - return stypes for density compensation and kspace data is same as input - + data_thresholded: ndarray + The k-space values in the center region. + center_loc: ndarray + The locations in the center region. + density_comp: ndarray, optional + The density compensation weights (if requested) + Notes ----- The Hann (or Hanning) and Hamming windows of width :math:`2\theta` are defined as: From 506bce66bfa7efa4f8f6a4301760d8b470265131 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 13:27:48 +0200 Subject: [PATCH 18/50] Added smaps with blurring --- src/mrinufft/extras/smaps.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index c644480f..d13983ba 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -1,8 +1,10 @@ from mrinufft._utils import MethodRegister from mrinufft.density.utils import flat_traj -from mrinufft.operators.base import with_numpy_cupy, get_array_module +from mrinufft.operators.base import get_array_module from mrinufft import get_operator - +from skimage.filters import threshold_otsu, gaussian +from skimage.morphology import convex_hull_image +import numpy as np register_smaps = MethodRegister("sensitivity_maps") @@ -87,15 +89,15 @@ def extract_kspace_center( @register_smaps @flat_traj -def low_frequency(traj, kspace_data, shape, backend, threshold, density=None, *args, **kwargs): - xp = get_array_module(kspace_data) +def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, + extract_args=None, blurr_factor=0): k_space, samples, dc = extract_kspace_center( kspace_data=kspace_data, kspace_loc=traj, threshold=threshold, density=density, img_shape=shape, - **kwargs, + **(extract_args or {}), ) smaps_adj_op = get_operator(backend)( samples, @@ -104,8 +106,14 @@ def low_frequency(traj, kspace_data, shape, backend, threshold, density=None, *a n_coils=k_space.shape[0] ) Smaps_ = smaps_adj_op.adj_op(k_space) - SOS = xp.linalg.norm(Smaps_ , axis=0) + SOS = np.linalg.norm(Smaps_, axis=0) thresh = threshold_otsu(SOS) convex_hull = convex_hull_image(SOS>thresh) - + Smaps = Smaps_ * convex_hull / SOS + # Smooth out the sensitivity maps + if blurr_factor > 0: + Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape)) + SOS = np.linalg.norm(Smaps, axis=0) + Smaps = Smaps / SOS + return Smaps, SOS \ No newline at end of file From 2a62f94eed898ae68f1cea02397508593aad3b35 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 13:28:44 +0200 Subject: [PATCH 19/50] Added doc --- src/mrinufft/extras/smaps.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index d13983ba..98b9e377 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -89,15 +89,44 @@ def extract_kspace_center( @register_smaps @flat_traj -def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, - extract_args=None, blurr_factor=0): +def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, + extract_kwargs=None, blurr_factor=0): + """ + Calculate low-frequency sensitivity maps. + + Parameters + ---------- + traj : numpy.ndarray + The trajectory of the samples. + kspace_data : numpy.ndarray + The k-space data. + shape : tuple + The shape of the image. + threshold : float + The threshold used for extracting the k-space center. + backend : str + The backend used for the operator. + density : numpy.ndarray, optional + The density compensation weights. + extract_kwargs : dict, optional + Additional keyword arguments for the `extract_kspace_center` function. + blurr_factor : float, optional + The blurring factor for smoothing the sensitivity maps. + + Returns + ------- + Smaps : numpy.ndarray + The low-frequency sensitivity maps. + SOS : numpy.ndarray + The sum of squares of the sensitivity maps. + """ k_space, samples, dc = extract_kspace_center( kspace_data=kspace_data, kspace_loc=traj, threshold=threshold, density=density, img_shape=shape, - **(extract_args or {}), + **(extract_kwargs or {}), ) smaps_adj_op = get_operator(backend)( samples, From c1dabee3af9b0ca8ba5ac3cbd83e24ed39b54107 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 13:30:40 +0200 Subject: [PATCH 20/50] Final touchups --- src/mrinufft/extras/smaps.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 98b9e377..4b410686 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -10,7 +10,7 @@ def extract_kspace_center( - kspace_data, kspace_loc, threshold=None, window_fun="ellipse", + kspace_data, kspace_loc, threshold=None, density=None, window_fun="ellipse", ): r"""Extract k-space center and corresponding sampling locations. @@ -70,6 +70,8 @@ def extract_kspace_center( index = xp.extract(condition, index) center_locations = kspace_loc[index, :] data_thresholded = data_ordered[:, index] + dc = density[index] + return data_thresholded, center_locations, dc else: if callable(window_fun): window = window_fun(center_locations) @@ -83,8 +85,8 @@ def extract_kspace_center( else: raise ValueError("Unsupported window function.") data_thresholded = window * data_thresholded - # Return k-space locations just for consistency - return data_thresholded, kspace_loc + # Return k-space locations & density just for consistency + return data_thresholded, kspace_loc, density @register_smaps From ed4d97442a16201e32ecf7814f2d5d3f4750235e Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 13:42:49 +0200 Subject: [PATCH 21/50] Added compute_smaps --- src/mrinufft/extras/smaps.py | 8 +++----- src/mrinufft/operators/base.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 4b410686..f3f52a4f 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -1,15 +1,13 @@ -from mrinufft._utils import MethodRegister from mrinufft.density.utils import flat_traj from mrinufft.operators.base import get_array_module from mrinufft import get_operator from skimage.filters import threshold_otsu, gaussian from skimage.morphology import convex_hull_image +from .utils import register_smaps import numpy as np -register_smaps = MethodRegister("sensitivity_maps") - -def extract_kspace_center( +def _extract_kspace_center( kspace_data, kspace_loc, threshold=None, density=None, window_fun="ellipse", ): r"""Extract k-space center and corresponding sampling locations. @@ -122,7 +120,7 @@ def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, SOS : numpy.ndarray The sum of squares of the sensitivity maps. """ - k_space, samples, dc = extract_kspace_center( + k_space, samples, dc = _extract_kspace_center( kspace_data=kspace_data, kspace_loc=traj, threshold=threshold, diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 8aee5120..8afd8215 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -14,6 +14,7 @@ from mrinufft.operators.interfaces.utils import is_cuda_array from mrinufft.density import get_density +from mrinufft.extras import get_smaps CUPY_AVAILABLE = True try: @@ -225,6 +226,39 @@ def with_off_resonnance_correction(self, B, C, indices): from ..off_resonnance import MRIFourierCorrected return MRIFourierCorrected(self, B, C, indices) + + def compute_smaps(self, kspace_data, method=None): + """Compute the sensitivity maps and set it. + + Parameters + ---------- + kspace_data: np.ndarray + The k-space data to be used to estimate sensitivity maps + method: str or callable or dict + The method to use to compute the sensitivity maps. + If a string, the method should be registered in the smaps registry. + If a callable, it should take the samples and the shape as input. + If a dict, it should have a key 'name', to determine which method to use. + other items will be used as kwargs. + """ + if not method: + self.smaps = None + return None + kwargs = {} + if isinstance(method, dict): + kwargs = method.copy() + method = kwargs.pop("name") + if isinstance(method, str): + method = get_smaps(method) + if not callable(method): + raise ValueError(f"Unknown smaps method: {method}") + self.smaps, self.SOS = method( + self.samples, + self.shape, + density=self.density, + backend=self.backend, + **kwargs + ) def compute_density(self, method=None): """Compute the density compensation weights and set it. From f6596c93a190b9b8e4c4189396cc8e088224da16 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 13:43:29 +0200 Subject: [PATCH 22/50] Added extra files --- src/mrinufft/extras/__init__.py | 10 ++++++++++ src/mrinufft/extras/utils.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 src/mrinufft/extras/__init__.py create mode 100644 src/mrinufft/extras/utils.py diff --git a/src/mrinufft/extras/__init__.py b/src/mrinufft/extras/__init__.py new file mode 100644 index 00000000..fdae0e45 --- /dev/null +++ b/src/mrinufft/extras/__init__.py @@ -0,0 +1,10 @@ +"""Sensitivity map estimation methods.""" + +from .smaps import low_frequency +from .utils import get_density + + +__all__ = [ + "low_frequency", + "get_smaps", +] diff --git a/src/mrinufft/extras/utils.py b/src/mrinufft/extras/utils.py new file mode 100644 index 00000000..faf56a0c --- /dev/null +++ b/src/mrinufft/extras/utils.py @@ -0,0 +1,17 @@ +from mrinufft._utils import MethodRegister + +register_smaps = MethodRegister("sensitivity_maps") + +def get_smaps(name, *args, **kwargs): + """Get the density compensation function from its name.""" + try: + method = register_smaps.registry["sensitivity_maps"][name] + except KeyError as e: + raise ValueError( + f"Unknown density compensation method {name}. Available methods are \n" + f"{list(register_smaps.registry['sensitivity_maps'].keys())}" + ) from e + + if args or kwargs: + return method(*args, **kwargs) + return method \ No newline at end of file From dbb7743f3bc5c951d7d2dd4516987635d9101ff6 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 11 Apr 2024 14:26:29 +0200 Subject: [PATCH 23/50] Added compute_smaps --- src/mrinufft/extras/smaps.py | 1 + src/mrinufft/operators/base.py | 16 ++++++++++------ src/mrinufft/operators/interfaces/cufinufft.py | 2 +- src/mrinufft/operators/interfaces/gpunufft.py | 3 ++- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index f3f52a4f..7829efe0 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -142,6 +142,7 @@ def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, # Smooth out the sensitivity maps if blurr_factor > 0: Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape)) + # Re-normalize the sensitivity maps SOS = np.linalg.norm(Smaps, axis=0) Smaps = Smaps / SOS return Smaps, SOS diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 8afd8215..e69ad196 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -227,20 +227,23 @@ def with_off_resonnance_correction(self, B, C, indices): return MRIFourierCorrected(self, B, C, indices) - def compute_smaps(self, kspace_data, method=None): + def compute_smaps(self, method=None): """Compute the sensitivity maps and set it. Parameters ---------- - kspace_data: np.ndarray - The k-space data to be used to estimate sensitivity maps - method: str or callable or dict + method: callable or dict or array The method to use to compute the sensitivity maps. - If a string, the method should be registered in the smaps registry. - If a callable, it should take the samples and the shape as input. + If an array, it should be of shape (NCoils,XYZ) and will be used as is. If a dict, it should have a key 'name', to determine which method to use. other items will be used as kwargs. + If a callable, it should take the samples and the shape as input. + Note that this callable function should also hold the k-space data + (use funtools.partial) """ + if isinstance(method, np.ndarray): + self.smaps = method + return None if not method: self.smaps = None return None @@ -476,6 +479,7 @@ def __init__( # Density Compensation Setup self.compute_density(density) + self.compute_smaps(smaps) # Multi Coil Setup if n_coils < 1: raise ValueError("n_coils should be ≥ 1") diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 3d512127..0f7a2a7f 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -210,7 +210,7 @@ def __init__( self.density = cp.array(self.density) # Smaps support - self.smaps = smaps + self.compute_smaps(smaps) self.smaps_cached = False if smaps is not None: if not (is_host_array(smaps) or is_cuda_array(smaps)): diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index 7cc5d926..f6f4786b 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -377,12 +377,13 @@ def __init__( self.smaps = smaps self.squeeze_dims = squeeze_dims self.compute_density(density) + self.compute_smaps(smaps) self.impl = RawGpuNUFFT( samples=self.samples, shape=self.shape, n_coils=self.n_coils, density_comp=self.density, - smaps=smaps, + smaps=self.smaps, kernel_width=kwargs.get("kernel_width", -int(np.log10(eps))), **kwargs, ) From f138ee7824ac2d9f1476936c6afbdd7ba3c88fd5 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 24 Apr 2024 10:21:07 +0200 Subject: [PATCH 24/50] Added mask --- src/mrinufft/extras/smaps.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 7829efe0..5e348e8c 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -90,7 +90,7 @@ def _extract_kspace_center( @register_smaps @flat_traj def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, - extract_kwargs=None, blurr_factor=0): + extract_kwargs=None, blurr_factor=0, mask=True): """ Calculate low-frequency sensitivity maps. @@ -112,6 +112,8 @@ def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, Additional keyword arguments for the `extract_kspace_center` function. blurr_factor : float, optional The blurring factor for smoothing the sensitivity maps. + mask: bool, optional default `True` + Whether the Sensitivity maps must be masked Returns ------- @@ -136,9 +138,11 @@ def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, ) Smaps_ = smaps_adj_op.adj_op(k_space) SOS = np.linalg.norm(Smaps_, axis=0) - thresh = threshold_otsu(SOS) - convex_hull = convex_hull_image(SOS>thresh) - Smaps = Smaps_ * convex_hull / SOS + if mask: + thresh = threshold_otsu(SOS) + # Create convex hull from mask + convex_hull = convex_hull_image(SOS>thresh) + Smaps = Smaps_ * convex_hull / SOS # Smooth out the sensitivity maps if blurr_factor > 0: Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape)) From 3c27e7fefc81b841d1286a0fcd1fc3ed4b53dd41 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 24 Apr 2024 10:32:36 +0200 Subject: [PATCH 25/50] Added Smaps --- src/mrinufft/extras/smaps.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 5e348e8c..ff951c42 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -89,7 +89,7 @@ def _extract_kspace_center( @register_smaps @flat_traj -def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, +def low_frequency(traj, shape, kspace_data, threshold, backend, density=None, extract_kwargs=None, blurr_factor=0, mask=True): """ Calculate low-frequency sensitivity maps. @@ -98,10 +98,10 @@ def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, ---------- traj : numpy.ndarray The trajectory of the samples. - kspace_data : numpy.ndarray - The k-space data. shape : tuple The shape of the image. + kspace_data : numpy.ndarray + The k-space data. threshold : float The threshold used for extracting the k-space center. backend : str @@ -136,18 +136,20 @@ def low_frequency(traj, kspace_data, shape, threshold, backend, density=None, density=dc, n_coils=k_space.shape[0] ) - Smaps_ = smaps_adj_op.adj_op(k_space) - SOS = np.linalg.norm(Smaps_, axis=0) + Smaps = smaps_adj_op.adj_op(k_space) + SOS = np.linalg.norm(Smaps, axis=0) if mask: thresh = threshold_otsu(SOS) # Create convex hull from mask convex_hull = convex_hull_image(SOS>thresh) - Smaps = Smaps_ * convex_hull / SOS + Smaps = Smaps * convex_hull # Smooth out the sensitivity maps if blurr_factor > 0: Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape)) # Re-normalize the sensitivity maps - SOS = np.linalg.norm(Smaps, axis=0) - Smaps = Smaps / SOS + if mask or blurr_factor > 0: + # ReCalculate SOS with a minor eps to ensure divide by 0 is ok + SOS = np.linalg.norm(Smaps, axis=0) + 1e-10 + Smaps = Smaps / SOS return Smaps, SOS \ No newline at end of file From d662498da041edd3b2af4c4d1b36e77ca715c07f Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 11:21:45 +0200 Subject: [PATCH 26/50] Updates --- src/mrinufft/extras/__init__.py | 2 +- src/mrinufft/extras/smaps.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/extras/__init__.py b/src/mrinufft/extras/__init__.py index fdae0e45..6ba1d4a4 100644 --- a/src/mrinufft/extras/__init__.py +++ b/src/mrinufft/extras/__init__.py @@ -1,7 +1,7 @@ """Sensitivity map estimation methods.""" from .smaps import low_frequency -from .utils import get_density +from .utils import get_smaps __all__ = [ diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index ff951c42..de8c1236 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -1,6 +1,5 @@ from mrinufft.density.utils import flat_traj from mrinufft.operators.base import get_array_module -from mrinufft import get_operator from skimage.filters import threshold_otsu, gaussian from skimage.morphology import convex_hull_image from .utils import register_smaps @@ -22,7 +21,7 @@ def _extract_kspace_center( kspace_data: numpy.ndarray The value of the samples kspace_loc: numpy.ndarray - The samples location in the k-sapec domain (between [-0.5, 0.5[) + The samples location in the k-space domain (between [-0.5, 0.5[) threshold: tuple or float The threshold used to extract the k_space center (between (0, 1]) window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. @@ -122,6 +121,8 @@ def low_frequency(traj, shape, kspace_data, threshold, backend, density=None, SOS : numpy.ndarray The sum of squares of the sensitivity maps. """ + # defer import to later to prevent circular import + from mrinufft import get_operator k_space, samples, dc = _extract_kspace_center( kspace_data=kspace_data, kspace_loc=traj, From 8ca012cd9d96ebb4488eb291a3401d3584189171 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 14:45:12 +0200 Subject: [PATCH 27/50] Added --- .../operators/interfaces/nudft_numpy.py | 40 +++-- tests/test_autodiff.py | 145 ++++++++++++++++++ tests/test_ndft.py | 59 +++---- 3 files changed, 205 insertions(+), 39 deletions(-) create mode 100644 tests/test_autodiff.py diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 31a9fa18..9d26d02f 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -8,6 +8,7 @@ from ..base import FourierOperatorCPU +<<<<<<< Updated upstream def get_fourier_matrix(ktraj, shape): """Get the NDFT Fourier Matrix.""" n = np.prod(shape) @@ -17,47 +18,64 @@ def get_fourier_matrix(ktraj, shape): grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (ndim, np.prod(shape))) traj_grid = ktraj @ grid_r matrix = np.exp(-2j * np.pi * traj_grid) +======= +def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): + """Get the NDFT Fourier Matrix.""" + n = np.prod(shape) + ndim = len(shape) + matrix = np.zeros((len(ktraj), n), dtype=dtype) + r = [np.linspace(-s/2, s/2-1, s) for s in shape] + grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (ndim, np.prod(shape))) + traj_grid = ktraj @ grid_r + matrix = np.exp(-2j * np.pi * traj_grid, dtype=dtype) + if normalize: + matrix /= (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) +>>>>>>> Stashed changes return matrix -def implicit_type2_ndft(ktraj, image, shape): +def implicit_type2_ndft(ktraj, image, shape, normalize=False): """Compute the NDFT using the implicit type 2 (image -> kspace) algorithm.""" - r = [np.arange(s) for s in shape] + r = [np.linspace(-s/2, s/2-1, s) for s in shape] grid_r = np.reshape( np.meshgrid(*r, indexing="ij"), (len(shape), np.prod(image.shape)) ) res = np.zeros(len(ktraj), dtype=image.dtype) for j in range(np.prod(image.shape)): res += image[j] * np.exp(-2j * np.pi * ktraj @ grid_r[:, j]) + if normalize: + matrix /= (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) return res -def implicit_type1_ndft(ktraj, coeffs, shape): +def implicit_type1_ndft(ktraj, coeffs, shape, normalize=False): """Compute the NDFT using the implicit type 1 (kspace -> image) algorithm.""" - r = [np.arange(s) for s in shape] + r = [np.linspace(-s/2, s/2-1, s) for s in shape] grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (len(shape), np.prod(shape))) res = np.zeros(np.prod(shape), dtype=coeffs.dtype) for i in range(len(ktraj)): res += coeffs[i] * np.exp(2j * np.pi * ktraj[i] @ grid_r) + if normalize: + matrix /= (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) return res -def get_implicit_matrix(ktraj, shape): +def get_implicit_matrix(ktraj, shape, normalize=False): """Get the NDFT Fourier Matrix as implicit operator. This is more memory efficient than the explicit matrix. """ return sp.sparse.linalg.LinearOperator( (len(ktraj), np.prod(shape)), - matvec=lambda x: implicit_type2_ndft(ktraj, x, shape), - rmatvec=lambda x: implicit_type1_ndft(ktraj, x, shape), + matvec=lambda x: implicit_type2_ndft(ktraj, x, shape, normalize), + rmatvec=lambda x: implicit_type1_ndft(ktraj, x, shape, normalize), ) class RawNDFT: """Implementation of the NUDFT using numpy.""" - def __init__(self, samples, shape, explicit_matrix=True): + def __init__(self, samples, shape, explicit_matrix=True, normalize=False): self.samples = samples self.shape = shape self.n_samples = len(samples) @@ -65,13 +83,13 @@ def __init__(self, samples, shape, explicit_matrix=True): if explicit_matrix: try: self._fourier_matrix = sp.sparse.linalg.aslinearoperator( - get_fourier_matrix(self.samples, self.shape) + get_fourier_matrix(self.samples, self.shape, normalize=normalize) ) except MemoryError: warnings.warn("Not enough memory, using an implicit definition anyway") - self._fourier_matrix = get_implicit_matrix(self.samples, self.shape) + self._fourier_matrix = get_implicit_matrix(self.samples, self.shape, normalize) else: - self._fourier_matrix = get_implicit_matrix(self.samples, self.shape) + self._fourier_matrix = get_implicit_matrix(self.samples, self.shape, normalize) def op(self, coeffs, image): """Compute the forward NUDFT.""" diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py new file mode 100644 index 00000000..01de2f91 --- /dev/null +++ b/tests/test_autodiff.py @@ -0,0 +1,145 @@ +"""Test the autodiff functionnality.""" + +import numpy as np +from mrinufft.operators.interfaces.nudft_numpy import get_fourier_matrix +import pytest +from pytest_cases import parametrize_with_cases, parametrize, fixture +from case_trajectories import CasesTrajectories +from mrinufft.operators import get_operator + + +from helpers import ( + kspace_from_op, + image_from_op, + to_interface, +) + + +TORCH_AVAILABLE = True +try: + import torch + import torch.testing as tt +except ImportError: + TORCH_AVAILABLE = False + + +@fixture(scope="module") +@parametrize(backend=["cufinufft", "finufft"]) +@parametrize(autograd=["data"]) +@parametrize_with_cases( + "kspace_loc, shape", + cases=[ + CasesTrajectories.case_grid2D, + CasesTrajectories.case_nyquist_radial2D, + ], # 2D cases only for reduced memory footprint. +) +def operator(kspace_loc, shape, backend, autograd): + """Create NUFFT operator with autodiff capabilities.""" + kspace_loc = kspace_loc.astype(np.float32) + + nufft = get_operator(backend_name=backend, autograd=autograd)( + samples=kspace_loc, + shape=shape, + smaps=None, + ) + + return nufft + + +@fixture(scope="module") +def ndft_matrix(operator): + """Get the NDFT matrix from the operator.""" + return get_fourier_matrix(operator.samples, operator.shape, normalize=True) + + +@pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") +def test_adjoint_and_grad(operator, ndft_matrix, interface): + """Test the adjoint and gradient of the operator.""" + if operator.backend == "finufft" and "gpu" in interface: + pytest.skip("GPU not supported for finufft backend") + ndft_matrix_torch = to_interface(ndft_matrix, interface=interface) + ksp_data = to_interface(kspace_from_op(operator), interface=interface) + img_data = to_interface(image_from_op(operator), interface=interface) + ksp_data.requires_grad = True + + with torch.autograd.set_detect_anomaly(True): + adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) + adj_data_ndft = (ndft_matrix_torch.conj().T @ ksp_data.flatten()).reshape( + adj_data.shape + ) + loss_nufft = torch.mean(torch.abs(adj_data) ** 2) + loss_ndft = torch.mean(torch.abs(adj_data_ndft) ** 2) + + # Check if nufft and ndft are close in the backprop + grad_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] + grad_nufft_kdata = torch.autograd.grad(loss_nufft, ksp_data, retain_graph=True)[0] + tt.assert_close(grad_ndft_kdata, grad_nufft_kdata, rtol=1, atol=1) + + +@pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") +def test_adjoint_and_gradauto(operator, ndft_matrix, interface): + """Test the adjoint and gradient of the operator using autograd gradcheck.""" + if operator.backend == "finufft" and "gpu" in interface: + pytest.skip("GPU not supported for finufft backend") + + ksp_data = to_interface(kspace_from_op(operator), interface=interface) + ksp_data = torch.ones(ksp_data.shape, requires_grad=True, dtype=ksp_data.dtype) + print(ksp_data.shape) + # todo: tighten the tolerance + assert torch.autograd.gradcheck( + operator.adjoint, + ksp_data, + eps=1e-10, + rtol=1, + atol=1, + nondet_tol=1, + raise_exception=True, + ) + + +@pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") +def test_forward_and_grad(operator, ndft_matrix, interface): + """Test the adjoint and gradient of the operator.""" + if operator.backend == "finufft" and "gpu" in interface: + pytest.skip("GPU not supported for finufft backend") + + ndft_matrix_torch = to_interface(ndft_matrix, interface=interface) + ksp_data_ref = to_interface(kspace_from_op(operator), interface=interface) + img_data = to_interface(image_from_op(operator), interface=interface) + img_data.requires_grad = True + + with torch.autograd.set_detect_anomaly(True): + ksp_data = operator.op(img_data).reshape(ksp_data_ref.shape) + ksp_data_ndft = (ndft_matrix_torch @ img_data.flatten()).reshape(ksp_data.shape) + loss_nufft = torch.mean(torch.abs(ksp_data - ksp_data_ref) ** 2) + loss_ndft = torch.mean(torch.abs(ksp_data_ndft - ksp_data_ref) ** 2) + + # Check if nufft and ndft are close in the backprop + grad_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] + grad_nufft_kdata = torch.autograd.grad(loss_nufft, img_data, retain_graph=True)[0] + assert torch.allclose(grad_ndft_kdata, grad_nufft_kdata, atol=6e-3) + + +@pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") +def test_forward_and_gradauto(operator, ndft_matrix, interface): + """Test the forward and gradient of the operator using autograd gradcheck.""" + if operator.backend == "finufft" and "gpu" in interface: + pytest.skip("GPU not supported for finufft backend") + + img_data = to_interface(image_from_op(operator), interface=interface) + img_data = torch.ones(img_data.shape, requires_grad=True, dtype=img_data.dtype) + print(img_data.shape) + # todo: tighten the tolerance + assert torch.autograd.gradcheck( + operator.adjoint, + img_data, + eps=1e-10, + rtol=1, + atol=1, + nondet_tol=1, + raise_exception=True, + ) diff --git a/tests/test_ndft.py b/tests/test_ndft.py index 7ae595a8..5bcb3e8c 100644 --- a/tests/test_ndft.py +++ b/tests/test_ndft.py @@ -13,32 +13,9 @@ from case_trajectories import CasesTrajectories, case_grid1D from helpers import assert_almost_allclose +from mrinufft import get_operator -@parametrize_with_cases( - "kspace_grid, shape", - cases=[ - case_grid1D, - CasesTrajectories.case_grid2D, - ], # 3D is ignored (too much possibility for the reordering) -) -def test_ndft_grid_matrix(kspace_grid, shape): - """Test that the ndft matrix is a good matrix for doing fft.""" - ndft_matrix = get_fourier_matrix(kspace_grid, shape) - # Create a random image - fft_matrix = [None] * len(shape) - for i in range(len(shape)): - fft_matrix[i] = sp.fft.fft(np.eye(shape[i])) - fft_mat = fft_matrix[0] - if len(shape) == 2: - fft_mat = fft_matrix[0].flatten()[:, None] @ fft_matrix[1].flatten()[None, :] - fft_mat = ( - fft_mat.reshape(shape * 2) - .transpose(2, 0, 1, 3) - .reshape((np.prod(shape),) * 2) - ) - assert np.allclose(ndft_matrix, fft_mat) - @parametrize_with_cases( "kspace, shape", @@ -56,7 +33,7 @@ def test_ndft_implicit2(kspace, shape): linop_coef = implicit_type2_ndft(kspace, random_image.flatten(), shape) matrix_coef = matrix @ random_image.flatten() - assert np.allclose(linop_coef, matrix_coef) + assert_almost_allclose(linop_coef, matrix_coef, atol=1e-4, rtol=1e-4, mismatch=5) @parametrize_with_cases( @@ -76,7 +53,32 @@ def test_ndft_implicit1(kspace, shape): linop_coef = implicit_type1_ndft(kspace, random_kspace.flatten(), shape) matrix_coef = matrix.conj().T @ random_kspace.flatten() - assert np.allclose(linop_coef, matrix_coef) + assert_almost_allclose(linop_coef, matrix_coef, atol=1e-4, rtol=1e-4, mismatch=5) + +@parametrize_with_cases( + "kspace, shape", + cases=[ + CasesTrajectories.case_random2D, + CasesTrajectories.case_grid2D, + CasesTrajectories.case_grid3D, + ], +) +def test_ndft_nufft(kspace, shape): + "Test that NDFT matches NUFFT" + ndft_op = RawNDFT(kspace, shape, normalize=True) + random_kspace = 1j * np.random.randn(len(kspace)) + random_kspace += np.random.randn(len(kspace)) + random_image = np.random.randn(*shape) + 1j * np.random.randn(*shape) + operator = get_operator("pynfft")(kspace, shape) # FIXME: @PAC, we need to get ref here + nufft_k = operator.op(random_image) + nufft_i = operator.adj_op(random_kspace) + + ndft_k = np.empty(ndft_op.n_samples, dtype=random_image.dtype) + ndft_i = np.empty(shape, dtype=random_kspace.dtype) + ndft_op.op(ndft_k, random_image) + ndft_op.adj_op(random_kspace, ndft_i) + assert_almost_allclose(nufft_k, ndft_k, atol=1e-4, rtol=1e-4, mismatch=5) + assert_almost_allclose(nufft_i, ndft_i, atol=1e-4, rtol=1e-4, mismatch=5) @parametrize_with_cases( @@ -98,6 +100,7 @@ def test_ndft_fft(kspace_grid, shape): kspace = kspace.reshape(img.shape) if len(shape) >= 2: kspace = kspace.swapaxes(0, 1) - kspace_fft = sp.fft.fftn(img) + kspace_fft = sp.fft.fftn(sp.fft.fftshift(img)) + + assert_almost_allclose(kspace, kspace_fft, atol=1e-4, rtol=1e-4, mismatch=5) - assert_almost_allclose(kspace, kspace_fft, atol=1e-5, rtol=1e-5, mismatch=5) From 093edfbe3def12e22c4aec3657f70377c238525b Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 14:46:50 +0200 Subject: [PATCH 28/50] Fix --- src/mrinufft/operators/interfaces/nudft_numpy.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 9d26d02f..68690cc1 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -8,17 +8,6 @@ from ..base import FourierOperatorCPU -<<<<<<< Updated upstream -def get_fourier_matrix(ktraj, shape): - """Get the NDFT Fourier Matrix.""" - n = np.prod(shape) - ndim = len(shape) - matrix = np.zeros((len(ktraj), n), dtype=complex) - r = [np.arange(shape[i]) for i in range(ndim)] - grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (ndim, np.prod(shape))) - traj_grid = ktraj @ grid_r - matrix = np.exp(-2j * np.pi * traj_grid) -======= def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): """Get the NDFT Fourier Matrix.""" n = np.prod(shape) @@ -30,7 +19,6 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): matrix = np.exp(-2j * np.pi * traj_grid, dtype=dtype) if normalize: matrix /= (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) ->>>>>>> Stashed changes return matrix From 74c1ecd6d0a07dcc141156aa506be9d1967d07e5 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 14:48:12 +0200 Subject: [PATCH 29/50] Remove bymistake add --- tests/test_autodiff.py | 145 ----------------------------------------- 1 file changed, 145 deletions(-) delete mode 100644 tests/test_autodiff.py diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py deleted file mode 100644 index 01de2f91..00000000 --- a/tests/test_autodiff.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Test the autodiff functionnality.""" - -import numpy as np -from mrinufft.operators.interfaces.nudft_numpy import get_fourier_matrix -import pytest -from pytest_cases import parametrize_with_cases, parametrize, fixture -from case_trajectories import CasesTrajectories -from mrinufft.operators import get_operator - - -from helpers import ( - kspace_from_op, - image_from_op, - to_interface, -) - - -TORCH_AVAILABLE = True -try: - import torch - import torch.testing as tt -except ImportError: - TORCH_AVAILABLE = False - - -@fixture(scope="module") -@parametrize(backend=["cufinufft", "finufft"]) -@parametrize(autograd=["data"]) -@parametrize_with_cases( - "kspace_loc, shape", - cases=[ - CasesTrajectories.case_grid2D, - CasesTrajectories.case_nyquist_radial2D, - ], # 2D cases only for reduced memory footprint. -) -def operator(kspace_loc, shape, backend, autograd): - """Create NUFFT operator with autodiff capabilities.""" - kspace_loc = kspace_loc.astype(np.float32) - - nufft = get_operator(backend_name=backend, autograd=autograd)( - samples=kspace_loc, - shape=shape, - smaps=None, - ) - - return nufft - - -@fixture(scope="module") -def ndft_matrix(operator): - """Get the NDFT matrix from the operator.""" - return get_fourier_matrix(operator.samples, operator.shape, normalize=True) - - -@pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) -@pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") -def test_adjoint_and_grad(operator, ndft_matrix, interface): - """Test the adjoint and gradient of the operator.""" - if operator.backend == "finufft" and "gpu" in interface: - pytest.skip("GPU not supported for finufft backend") - ndft_matrix_torch = to_interface(ndft_matrix, interface=interface) - ksp_data = to_interface(kspace_from_op(operator), interface=interface) - img_data = to_interface(image_from_op(operator), interface=interface) - ksp_data.requires_grad = True - - with torch.autograd.set_detect_anomaly(True): - adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) - adj_data_ndft = (ndft_matrix_torch.conj().T @ ksp_data.flatten()).reshape( - adj_data.shape - ) - loss_nufft = torch.mean(torch.abs(adj_data) ** 2) - loss_ndft = torch.mean(torch.abs(adj_data_ndft) ** 2) - - # Check if nufft and ndft are close in the backprop - grad_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] - grad_nufft_kdata = torch.autograd.grad(loss_nufft, ksp_data, retain_graph=True)[0] - tt.assert_close(grad_ndft_kdata, grad_nufft_kdata, rtol=1, atol=1) - - -@pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) -@pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") -def test_adjoint_and_gradauto(operator, ndft_matrix, interface): - """Test the adjoint and gradient of the operator using autograd gradcheck.""" - if operator.backend == "finufft" and "gpu" in interface: - pytest.skip("GPU not supported for finufft backend") - - ksp_data = to_interface(kspace_from_op(operator), interface=interface) - ksp_data = torch.ones(ksp_data.shape, requires_grad=True, dtype=ksp_data.dtype) - print(ksp_data.shape) - # todo: tighten the tolerance - assert torch.autograd.gradcheck( - operator.adjoint, - ksp_data, - eps=1e-10, - rtol=1, - atol=1, - nondet_tol=1, - raise_exception=True, - ) - - -@pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) -@pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") -def test_forward_and_grad(operator, ndft_matrix, interface): - """Test the adjoint and gradient of the operator.""" - if operator.backend == "finufft" and "gpu" in interface: - pytest.skip("GPU not supported for finufft backend") - - ndft_matrix_torch = to_interface(ndft_matrix, interface=interface) - ksp_data_ref = to_interface(kspace_from_op(operator), interface=interface) - img_data = to_interface(image_from_op(operator), interface=interface) - img_data.requires_grad = True - - with torch.autograd.set_detect_anomaly(True): - ksp_data = operator.op(img_data).reshape(ksp_data_ref.shape) - ksp_data_ndft = (ndft_matrix_torch @ img_data.flatten()).reshape(ksp_data.shape) - loss_nufft = torch.mean(torch.abs(ksp_data - ksp_data_ref) ** 2) - loss_ndft = torch.mean(torch.abs(ksp_data_ndft - ksp_data_ref) ** 2) - - # Check if nufft and ndft are close in the backprop - grad_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] - grad_nufft_kdata = torch.autograd.grad(loss_nufft, img_data, retain_graph=True)[0] - assert torch.allclose(grad_ndft_kdata, grad_nufft_kdata, atol=6e-3) - - -@pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) -@pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") -def test_forward_and_gradauto(operator, ndft_matrix, interface): - """Test the forward and gradient of the operator using autograd gradcheck.""" - if operator.backend == "finufft" and "gpu" in interface: - pytest.skip("GPU not supported for finufft backend") - - img_data = to_interface(image_from_op(operator), interface=interface) - img_data = torch.ones(img_data.shape, requires_grad=True, dtype=img_data.dtype) - print(img_data.shape) - # todo: tighten the tolerance - assert torch.autograd.gradcheck( - operator.adjoint, - img_data, - eps=1e-10, - rtol=1, - atol=1, - nondet_tol=1, - raise_exception=True, - ) From 0250aa8d3753bad0191cdc5f42cd1c112f589f44 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 15:38:27 +0200 Subject: [PATCH 30/50] Fix --- tests/test_ndft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ndft.py b/tests/test_ndft.py index 5bcb3e8c..3d972ff5 100644 --- a/tests/test_ndft.py +++ b/tests/test_ndft.py @@ -64,7 +64,7 @@ def test_ndft_implicit1(kspace, shape): ], ) def test_ndft_nufft(kspace, shape): - "Test that NDFT matches NUFFT" + """Test that NDFT matches NUFFT""" ndft_op = RawNDFT(kspace, shape, normalize=True) random_kspace = 1j * np.random.randn(len(kspace)) random_kspace += np.random.randn(len(kspace)) From 060a8bdd125d2140b82dfaa6a492ec78682a80bf Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 15:39:16 +0200 Subject: [PATCH 31/50] Fixed lint --- .../operators/interfaces/nudft_numpy.py | 20 +++++++++++-------- tests/test_ndft.py | 9 +++++---- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 68690cc1..3e8e81aa 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -13,18 +13,18 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): n = np.prod(shape) ndim = len(shape) matrix = np.zeros((len(ktraj), n), dtype=dtype) - r = [np.linspace(-s/2, s/2-1, s) for s in shape] + r = [np.linspace(-s / 2, s / 2 - 1, s) for s in shape] grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (ndim, np.prod(shape))) traj_grid = ktraj @ grid_r matrix = np.exp(-2j * np.pi * traj_grid, dtype=dtype) if normalize: - matrix /= (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) + matrix /= np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape)) return matrix def implicit_type2_ndft(ktraj, image, shape, normalize=False): """Compute the NDFT using the implicit type 2 (image -> kspace) algorithm.""" - r = [np.linspace(-s/2, s/2-1, s) for s in shape] + r = [np.linspace(-s / 2, s / 2 - 1, s) for s in shape] grid_r = np.reshape( np.meshgrid(*r, indexing="ij"), (len(shape), np.prod(image.shape)) ) @@ -32,19 +32,19 @@ def implicit_type2_ndft(ktraj, image, shape, normalize=False): for j in range(np.prod(image.shape)): res += image[j] * np.exp(-2j * np.pi * ktraj @ grid_r[:, j]) if normalize: - matrix /= (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) + matrix /= np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape)) return res def implicit_type1_ndft(ktraj, coeffs, shape, normalize=False): """Compute the NDFT using the implicit type 1 (kspace -> image) algorithm.""" - r = [np.linspace(-s/2, s/2-1, s) for s in shape] + r = [np.linspace(-s / 2, s / 2 - 1, s) for s in shape] grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (len(shape), np.prod(shape))) res = np.zeros(np.prod(shape), dtype=coeffs.dtype) for i in range(len(ktraj)): res += coeffs[i] * np.exp(2j * np.pi * ktraj[i] @ grid_r) if normalize: - matrix /= (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) + matrix /= np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape)) return res @@ -75,9 +75,13 @@ def __init__(self, samples, shape, explicit_matrix=True, normalize=False): ) except MemoryError: warnings.warn("Not enough memory, using an implicit definition anyway") - self._fourier_matrix = get_implicit_matrix(self.samples, self.shape, normalize) + self._fourier_matrix = get_implicit_matrix( + self.samples, self.shape, normalize + ) else: - self._fourier_matrix = get_implicit_matrix(self.samples, self.shape, normalize) + self._fourier_matrix = get_implicit_matrix( + self.samples, self.shape, normalize + ) def op(self, coeffs, image): """Compute the forward NUDFT.""" diff --git a/tests/test_ndft.py b/tests/test_ndft.py index 3d972ff5..7f90d14e 100644 --- a/tests/test_ndft.py +++ b/tests/test_ndft.py @@ -16,7 +16,6 @@ from mrinufft import get_operator - @parametrize_with_cases( "kspace, shape", cases=[ @@ -55,6 +54,7 @@ def test_ndft_implicit1(kspace, shape): assert_almost_allclose(linop_coef, matrix_coef, atol=1e-4, rtol=1e-4, mismatch=5) + @parametrize_with_cases( "kspace, shape", cases=[ @@ -69,10 +69,12 @@ def test_ndft_nufft(kspace, shape): random_kspace = 1j * np.random.randn(len(kspace)) random_kspace += np.random.randn(len(kspace)) random_image = np.random.randn(*shape) + 1j * np.random.randn(*shape) - operator = get_operator("pynfft")(kspace, shape) # FIXME: @PAC, we need to get ref here + operator = get_operator("pynfft")( + kspace, shape + ) # FIXME: @PAC, we need to get ref here nufft_k = operator.op(random_image) nufft_i = operator.adj_op(random_kspace) - + ndft_k = np.empty(ndft_op.n_samples, dtype=random_image.dtype) ndft_i = np.empty(shape, dtype=random_kspace.dtype) ndft_op.op(ndft_k, random_image) @@ -103,4 +105,3 @@ def test_ndft_fft(kspace_grid, shape): kspace_fft = sp.fft.fftn(sp.fft.fftshift(img)) assert_almost_allclose(kspace, kspace_fft, atol=1e-4, rtol=1e-4, mismatch=5) - From aecb844c74ae53ae67deb852204bc9e647ac28fd Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 15:40:50 +0200 Subject: [PATCH 32/50] Lint --- src/mrinufft/operators/interfaces/nudft_numpy.py | 4 ++-- tests/test_ndft.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 3e8e81aa..bcc6c033 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -32,7 +32,7 @@ def implicit_type2_ndft(ktraj, image, shape, normalize=False): for j in range(np.prod(image.shape)): res += image[j] * np.exp(-2j * np.pi * ktraj @ grid_r[:, j]) if normalize: - matrix /= np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape)) + res /= np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape)) return res @@ -44,7 +44,7 @@ def implicit_type1_ndft(ktraj, coeffs, shape, normalize=False): for i in range(len(ktraj)): res += coeffs[i] * np.exp(2j * np.pi * ktraj[i] @ grid_r) if normalize: - matrix /= np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape)) + res /= np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape)) return res diff --git a/tests/test_ndft.py b/tests/test_ndft.py index 7f90d14e..fa66b8b2 100644 --- a/tests/test_ndft.py +++ b/tests/test_ndft.py @@ -64,7 +64,7 @@ def test_ndft_implicit1(kspace, shape): ], ) def test_ndft_nufft(kspace, shape): - """Test that NDFT matches NUFFT""" + """Test that NDFT matches NUFFT.""" ndft_op = RawNDFT(kspace, shape, normalize=True) random_kspace = 1j * np.random.randn(len(kspace)) random_kspace += np.random.randn(len(kspace)) From 3130bc1c5f443294a2f71dcae30178bb8357d392 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 17:15:00 +0200 Subject: [PATCH 33/50] Added refbackend --- tests/test_ndft.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_ndft.py b/tests/test_ndft.py index fa66b8b2..57aedfa6 100644 --- a/tests/test_ndft.py +++ b/tests/test_ndft.py @@ -63,18 +63,18 @@ def test_ndft_implicit1(kspace, shape): CasesTrajectories.case_grid3D, ], ) -def test_ndft_nufft(kspace, shape): +def test_ndft_nufft(kspace, shape, request): """Test that NDFT matches NUFFT.""" ndft_op = RawNDFT(kspace, shape, normalize=True) random_kspace = 1j * np.random.randn(len(kspace)) random_kspace += np.random.randn(len(kspace)) random_image = np.random.randn(*shape) + 1j * np.random.randn(*shape) - operator = get_operator("pynfft")( + operator = get_operator(request.config.getoption("ref"))( kspace, shape - ) # FIXME: @PAC, we need to get ref here + ) nufft_k = operator.op(random_image) nufft_i = operator.adj_op(random_kspace) - + ndft_k = np.empty(ndft_op.n_samples, dtype=random_image.dtype) ndft_i = np.empty(shape, dtype=random_kspace.dtype) ndft_op.op(ndft_k, random_image) From bc014b8973e3b355f17365a9eb933cc57b92fb4b Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Fri, 26 Apr 2024 17:17:48 +0200 Subject: [PATCH 34/50] Fix NDFT --- tests/test_ndft.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_ndft.py b/tests/test_ndft.py index 57aedfa6..7a157d34 100644 --- a/tests/test_ndft.py +++ b/tests/test_ndft.py @@ -69,12 +69,10 @@ def test_ndft_nufft(kspace, shape, request): random_kspace = 1j * np.random.randn(len(kspace)) random_kspace += np.random.randn(len(kspace)) random_image = np.random.randn(*shape) + 1j * np.random.randn(*shape) - operator = get_operator(request.config.getoption("ref"))( - kspace, shape - ) + operator = get_operator(request.config.getoption("ref"))(kspace, shape) nufft_k = operator.op(random_image) nufft_i = operator.adj_op(random_kspace) - + ndft_k = np.empty(ndft_op.n_samples, dtype=random_image.dtype) ndft_i = np.empty(shape, dtype=random_kspace.dtype) ndft_op.op(ndft_k, random_image) From 0cc73c41cf743ea19ffa053f0cd43b54f43f192e Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 29 Apr 2024 10:48:25 +0200 Subject: [PATCH 35/50] feat: use finufft as ref backend. --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 69598fdb..4e89f0ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ def pytest_addoption(parser): ) parser.addoption( "--ref", - default="pynfft", + default="finufft", help="Reference backend on which the tests are performed.", ) From 21e090f21803e9e57cc721010661d30449ce1a0b Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 29 Apr 2024 10:49:37 +0200 Subject: [PATCH 36/50] feat(tests): move ndft vs nufft tests to own file. --- tests/operators/test_operator_ref.py | 74 ++++++++++++++++++++++++++++ tests/test_ndft.py | 26 ---------- 2 files changed, 74 insertions(+), 26 deletions(-) create mode 100644 tests/operators/test_operator_ref.py diff --git a/tests/operators/test_operator_ref.py b/tests/operators/test_operator_ref.py new file mode 100644 index 00000000..b51e1633 --- /dev/null +++ b/tests/operators/test_operator_ref.py @@ -0,0 +1,74 @@ +"""Tests for the reference backend.""" + +from pytest_cases import parametrize_with_cases, fixture +from case_trajectories import CasesTrajectories + +from mrinufft import get_operator +from mrinufft.operators.interfaces.nudft_numpy import MRInumpy +from helpers import assert_almost_allclose, kspace_from_op, image_from_op + + +@fixture(scope="session", autouse=True) +def ref_backend(request): + """Get the reference backend from the CLI.""" + return request.config.getoption("ref") + + +@fixture(scope="module") +@parametrize_with_cases( + "kspace, shape", + cases=[ + CasesTrajectories.case_random2D, + CasesTrajectories.case_grid2D, + CasesTrajectories.case_grid3D, + ], +) +def ref_operator(request, ref_backend, kspace, shape): + """Generate a NFFT operator, matching the property of the first operator.""" + return get_operator(ref_backend)(kspace, shape) + + +@fixture(scope="module") +def ndft_operator(ref_operator): + """Get a NDFT operator matching the reference operator.""" + return MRInumpy(ref_operator.samples, ref_operator.shape) + + +@fixture(scope="module") +def image_data(ref_operator): + """Generate a random image. Remains constant for the module.""" + return image_from_op(ref_operator) + + +@fixture(scope="module") +def kspace_data(ref_operator): + """Generate a random kspace. Remains constant for the module.""" + return kspace_from_op(ref_operator) + + +def test_ref_nufft_forward(ref_operator, ndft_operator, image_data): + """Test that the reference nufft matches the NDFT.""" + nufft_ksp = ref_operator.op(image_data) + ndft_ksp = ndft_operator.op(image_data) + + assert_almost_allclose( + nufft_ksp, + ndft_ksp, + atol=1e-4, + rtol=1e-4, + mismatch=5, + ) + + +def test_ref_nufft_adjoint(ref_operator, ndft_operator, kspace_data): + """Test that the reference nufft matches the NDFT adjoint.""" + nufft_img = ref_operator.adj_op(kspace_data) + ndft_img = ndft_operator.adj_op(kspace_data) + + assert_almost_allclose( + nufft_img, + ndft_img, + atol=1e-4, + rtol=1e-4, + mismatch=5, + ) diff --git a/tests/test_ndft.py b/tests/test_ndft.py index 7a157d34..cd21622e 100644 --- a/tests/test_ndft.py +++ b/tests/test_ndft.py @@ -55,32 +55,6 @@ def test_ndft_implicit1(kspace, shape): assert_almost_allclose(linop_coef, matrix_coef, atol=1e-4, rtol=1e-4, mismatch=5) -@parametrize_with_cases( - "kspace, shape", - cases=[ - CasesTrajectories.case_random2D, - CasesTrajectories.case_grid2D, - CasesTrajectories.case_grid3D, - ], -) -def test_ndft_nufft(kspace, shape, request): - """Test that NDFT matches NUFFT.""" - ndft_op = RawNDFT(kspace, shape, normalize=True) - random_kspace = 1j * np.random.randn(len(kspace)) - random_kspace += np.random.randn(len(kspace)) - random_image = np.random.randn(*shape) + 1j * np.random.randn(*shape) - operator = get_operator(request.config.getoption("ref"))(kspace, shape) - nufft_k = operator.op(random_image) - nufft_i = operator.adj_op(random_kspace) - - ndft_k = np.empty(ndft_op.n_samples, dtype=random_image.dtype) - ndft_i = np.empty(shape, dtype=random_kspace.dtype) - ndft_op.op(ndft_k, random_image) - ndft_op.adj_op(random_kspace, ndft_i) - assert_almost_allclose(nufft_k, ndft_k, atol=1e-4, rtol=1e-4, mismatch=5) - assert_almost_allclose(nufft_i, ndft_i, atol=1e-4, rtol=1e-4, mismatch=5) - - @parametrize_with_cases( "kspace_grid, shape", cases=[ From 7afdd8e82fe4a29ecb51fcf626843f8b6a18bbfb Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 30 Apr 2024 09:06:12 +0200 Subject: [PATCH 37/50] Added rebart --- src/mrinufft/io/__init__.py | 5 ++- src/mrinufft/io/nsp.py | 33 +++-------------- src/mrinufft/io/siemens.py | 73 +++++++++++++++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 28 deletions(-) create mode 100644 src/mrinufft/io/siemens.py diff --git a/src/mrinufft/io/__init__.py b/src/mrinufft/io/__init__.py index 039d0600..f77d3b32 100644 --- a/src/mrinufft/io/__init__.py +++ b/src/mrinufft/io/__init__.py @@ -1,7 +1,8 @@ """Input/Output module for trajectories and data.""" from .cfl import traj2cfl, cfl2traj -from .nsp import read_trajectory, write_trajectory +from .nsp import read_trajectory, write_trajectory, read_arbgrad_rawdat +from .siemens import read_siemens_rawdat __all__ = [ @@ -9,4 +10,6 @@ "cfl2traj", "read_trajectory", "write_trajectory", + "read_arbgrad_rawdat", + "read_siemens_rawdat", ] diff --git a/src/mrinufft/io/nsp.py b/src/mrinufft/io/nsp.py index 8021651b..4b9dbfd1 100644 --- a/src/mrinufft/io/nsp.py +++ b/src/mrinufft/io/nsp.py @@ -6,6 +6,7 @@ import numpy as np from datetime import datetime from array import array +from .siemens import read_siemens_rawdat from mrinufft.trajectories.utils import ( KMAX, @@ -392,7 +393,7 @@ def read_trajectory( return kspace_loc, params -def read_siemens_rawdat( +def read_arbgrad_rawdat( filename: str, removeOS: bool = False, squeeze: bool = True, @@ -429,32 +430,10 @@ def read_siemens_rawdat( You can install it using the following command: `pip install pymapVBVD` """ - try: - from mapvbvd import mapVBVD - except ImportError as err: - raise ImportError( - "The mapVBVD module is not available. Please install it using " - "the following command: pip install pymapVBVD" - ) from err - twixObj = mapVBVD(filename) - if isinstance(twixObj, list): - twixObj = twixObj[-1] - twixObj.image.flagRemoveOS = removeOS - twixObj.image.squeeze = squeeze - raw_kspace = twixObj.image[""] - data = np.moveaxis(raw_kspace, 0, 2) - hdr = { - "n_coils": int(twixObj.image.NCha), - "n_shots": int(twixObj.image.NLin), - "n_contrasts": int(twixObj.image.NSet), - "n_adc_samples": int(twixObj.image.NCol), - "n_slices": int(twixObj.image.NSli), - } - data = data.reshape( - hdr["n_coils"], - hdr["n_shots"] * hdr["n_adc_samples"], - hdr["n_slices"], - hdr["n_contrasts"], + data, hdr, twixObj = read_siemens_rawdat( + filename=filename, + removeOS=removeOS, + squeeze=squeeze, ) if "ARBGRAD_VE11C" in data_type: hdr["type"] = "ARBGRAD_GRE" diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py new file mode 100644 index 00000000..5a230898 --- /dev/null +++ b/src/mrinufft/io/siemens.py @@ -0,0 +1,73 @@ +"""sasa""" +import numpy as np + + +def read_siemens_rawdat( + filename: str, + removeOS: bool = False, + squeeze: bool = True, + return_twix: bool = True, +): # pragma: no cover + """Read raw data from a Siemens MRI file. + + Parameters + ---------- + filename : str + The path to the Siemens MRI file. + removeOS : bool, optional + Whether to remove the oversampling, by default False. + squeeze : bool, optional + Whether to squeeze the dimensions of the data, by default True. + data_type : str, optional + The type of data to read, by default 'ARBGRAD_VE11C'. + return_twix : bool, optional + Whether to return the twix object, by default True. + + Returns + ------- + data: ndarray + Imported data formatted as n_coils X n_samples X n_slices X n_contrasts + hdr: dict + Extra information about the data parsed from the twix file + + Raises + ------ + ImportError + If the mapVBVD module is not available. + + Notes + ----- + This function requires the mapVBVD module to be installed. + You can install it using the following command: + `pip install pymapVBVD` + """ + try: + from mapvbvd import mapVBVD + except ImportError as err: + raise ImportError( + "The mapVBVD module is not available. Please install it using " + "the following command: pip install pymapVBVD" + ) from err + twixObj = mapVBVD(filename) + if isinstance(twixObj, list): + twixObj = twixObj[-1] + twixObj.image.flagRemoveOS = removeOS + twixObj.image.squeeze = squeeze + raw_kspace = twixObj.image[""] + data = np.moveaxis(raw_kspace, 0, 2) + hdr = { + "n_coils": int(twixObj.image.NCha), + "n_shots": int(twixObj.image.NLin), + "n_contrasts": int(twixObj.image.NSet), + "n_adc_samples": int(twixObj.image.NCol), + "n_slices": int(twixObj.image.NSli), + } + data = data.reshape( + hdr["n_coils"], + hdr["n_shots"] * hdr["n_adc_samples"], + hdr["n_slices"], + hdr["n_contrasts"], + ) + if return_twix: + return data, hdr, twixObj + return data, hdr From 140921e46d1c0b2a369bcbdd8de3bcf049ca2fa3 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 30 Apr 2024 16:17:17 +0200 Subject: [PATCH 38/50] Update codes --- pyproject.toml | 1 + src/mrinufft/extras/smaps.py | 22 ++++++++++++------- src/mrinufft/operators/interfaces/gpunufft.py | 1 - 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d3274501..8f7e00f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ finufft = ["finufft"] pynfft = ["pynfft2", "cython<3.0.0"] pynufft = ["pynufft"] io = ["pymapvbvd"] +smaps = ["scikit-image"] test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"] dev = ["black", "isort", "ruff"] diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index de8c1236..b3319914 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -4,6 +4,7 @@ from skimage.morphology import convex_hull_image from .utils import register_smaps import numpy as np +from typing import Tuple def _extract_kspace_center( @@ -81,15 +82,16 @@ def _extract_kspace_center( window = xp.sum(kspace_loc**2/ xp.asarray(threshold)**2, axis=1) <= 1 else: raise ValueError("Unsupported window function.") - data_thresholded = window * data_thresholded + data_thresholded = window * kspace_data # Return k-space locations & density just for consistency return data_thresholded, kspace_loc, density @register_smaps @flat_traj -def low_frequency(traj, shape, kspace_data, threshold, backend, density=None, - extract_kwargs=None, blurr_factor=0, mask=True): +def low_frequency(traj, shape, kspace_data, backend, threshold: float|Tuple[float, ...] = 0.1, + density=None, window_fun: str = "ellipse", blurr_factor: float = 0, + mask: bool = True): """ Calculate low-frequency sensitivity maps. @@ -101,14 +103,19 @@ def low_frequency(traj, shape, kspace_data, threshold, backend, density=None, The shape of the image. kspace_data : numpy.ndarray The k-space data. - threshold : float + threshold : float, or tuple of float, optional The threshold used for extracting the k-space center. + By default it is 0.1 backend : str The backend used for the operator. density : numpy.ndarray, optional The density compensation weights. - extract_kwargs : dict, optional - Additional keyword arguments for the `extract_kspace_center` function. + window_fun: "Hann", "Hanning", "Hamming", or a callable, default None. + The window function to apply to the selected data. It is computed with + the center locations selected. Only works with circular mask. + If window_fun is a callable, it takes as input the array (n_samples x n_dims) + of sample positions and returns an array of n_samples weights to be + applied to the selected k-space values, before the smaps estimation. blurr_factor : float, optional The blurring factor for smoothing the sensitivity maps. mask: bool, optional default `True` @@ -128,8 +135,7 @@ def low_frequency(traj, shape, kspace_data, threshold, backend, density=None, kspace_loc=traj, threshold=threshold, density=density, - img_shape=shape, - **(extract_kwargs or {}), + window_fun=window_fun, ) smaps_adj_op = get_operator(backend)( samples, diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index 092cd68f..f43e5c80 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -377,7 +377,6 @@ def __init__( self.dtype = self.samples.dtype self.n_coils = n_coils self.n_batchs = n_batchs - self.smaps = smaps self.squeeze_dims = squeeze_dims self.compute_density(density) self.compute_smaps(smaps) From 90bf8325c2fbe07aa8a0aa8f065f1bb59093052b Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 30 Apr 2024 16:52:46 +0200 Subject: [PATCH 39/50] updated mask --- src/mrinufft/extras/smaps.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index b3319914..1caf3cd1 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -91,7 +91,7 @@ def _extract_kspace_center( @flat_traj def low_frequency(traj, shape, kspace_data, backend, threshold: float|Tuple[float, ...] = 0.1, density=None, window_fun: str = "ellipse", blurr_factor: float = 0, - mask: bool = True): + mask: bool = False): """ Calculate low-frequency sensitivity maps. @@ -118,7 +118,7 @@ def low_frequency(traj, shape, kspace_data, backend, threshold: float|Tuple[floa applied to the selected k-space values, before the smaps estimation. blurr_factor : float, optional The blurring factor for smoothing the sensitivity maps. - mask: bool, optional default `True` + mask: bool, optional default `False` Whether the Sensitivity maps must be masked Returns From ae50a8e27340e8c8f4de96990b781abf503fa9d7 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 21 May 2024 18:30:35 +0200 Subject: [PATCH 40/50] Fixs --- src/mrinufft/extras/smaps.py | 65 +++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 1caf3cd1..68d28624 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -1,3 +1,5 @@ +"""SMaps module for sensitivity maps estimation.""" + from mrinufft.density.utils import flat_traj from mrinufft.operators.base import get_array_module from skimage.filters import threshold_otsu, gaussian @@ -8,15 +10,19 @@ def _extract_kspace_center( - kspace_data, kspace_loc, threshold=None, density=None, window_fun="ellipse", - ): + kspace_data, + kspace_loc, + threshold=None, + density=None, + window_fun="ellipse", +): r"""Extract k-space center and corresponding sampling locations. - + The extracted center of the k-space, i.e. both the kspace locations and kspace values. If the density compensators are passed, the corresponding compensators for the center of k-space data will also be returned. The return dtypes for density compensation and kspace data is same as input - + Parameters ---------- kspace_data: numpy.ndarray @@ -31,16 +37,16 @@ def _extract_kspace_center( If window_fun is a callable, it takes as input the array (n_samples x n_dims) of sample positions and returns an array of n_samples weights to be applied to the selected k-space values, before the smaps estimation. - + Returns ------- - data_thresholded: ndarray + data_thresholded: ndarray The k-space values in the center region. center_loc: ndarray The locations in the center region. - density_comp: ndarray, optional + density_comp: ndarray, optional The density compensation weights (if requested) - + Notes ----- The Hann (or Hanning) and Hamming windows of width :math:`2\theta` are defined as: @@ -58,13 +64,17 @@ def _extract_kspace_center( xp = get_array_module(kspace_data) if isinstance(threshold, float): threshold = (threshold,) * kspace_loc.shape[1] - + if window_fun == "rect": data_ordered = xp.copy(kspace_data) - index = xp.linspace(0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64) - condition = xp.logical_and.reduce(tuple( - xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) - )) + index = xp.linspace( + 0, kspace_loc.shape[0] - 1, kspace_loc.shape[0], dtype=xp.int64 + ) + condition = xp.logical_and.reduce( + tuple( + xp.abs(kspace_loc[:, i]) <= threshold[i] for i in range(len(threshold)) + ) + ) index = xp.extract(condition, index) center_locations = kspace_loc[index, :] data_thresholded = data_ordered[:, index] @@ -79,7 +89,9 @@ def _extract_kspace_center( a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) elif window_fun == "ellipse": - window = xp.sum(kspace_loc**2/ xp.asarray(threshold)**2, axis=1) <= 1 + window = ( + xp.sum(kspace_loc ** 2 / xp.asarray(threshold) ** 2, axis=1) <= 1 + ) else: raise ValueError("Unsupported window function.") data_thresholded = window * kspace_data @@ -89,9 +101,17 @@ def _extract_kspace_center( @register_smaps @flat_traj -def low_frequency(traj, shape, kspace_data, backend, threshold: float|Tuple[float, ...] = 0.1, - density=None, window_fun: str = "ellipse", blurr_factor: float = 0, - mask: bool = False): +def low_frequency( + traj, + shape, + kspace_data, + backend, + threshold: float | Tuple[float, ...] = 0.1, + density=None, + window_fun: str = "ellipse", + blurr_factor: float = 0, + mask: bool = False, +): """ Calculate low-frequency sensitivity maps. @@ -130,6 +150,7 @@ def low_frequency(traj, shape, kspace_data, backend, threshold: float|Tuple[floa """ # defer import to later to prevent circular import from mrinufft import get_operator + k_space, samples, dc = _extract_kspace_center( kspace_data=kspace_data, kspace_loc=traj, @@ -138,18 +159,15 @@ def low_frequency(traj, shape, kspace_data, backend, threshold: float|Tuple[floa window_fun=window_fun, ) smaps_adj_op = get_operator(backend)( - samples, - shape, - density=dc, - n_coils=k_space.shape[0] + samples, shape, density=dc, n_coils=k_space.shape[0] ) Smaps = smaps_adj_op.adj_op(k_space) SOS = np.linalg.norm(Smaps, axis=0) if mask: thresh = threshold_otsu(SOS) # Create convex hull from mask - convex_hull = convex_hull_image(SOS>thresh) - Smaps = Smaps * convex_hull + convex_hull = convex_hull_image(SOS > thresh) + Smaps = Smaps * convex_hull # Smooth out the sensitivity maps if blurr_factor > 0: Smaps = gaussian(Smaps, sigma=blurr_factor * np.asarray(shape)) @@ -159,4 +177,3 @@ def low_frequency(traj, shape, kspace_data, backend, threshold: float|Tuple[floa SOS = np.linalg.norm(Smaps, axis=0) + 1e-10 Smaps = Smaps / SOS return Smaps, SOS - \ No newline at end of file From 2013cf100d6efa8cd75c3a510334fead97eac8c3 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 21 May 2024 18:31:53 +0200 Subject: [PATCH 41/50] PEP --- src/mrinufft/extras/utils.py | 1 + src/mrinufft/io/siemens.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mrinufft/extras/utils.py b/src/mrinufft/extras/utils.py index faf56a0c..d4d5fa4e 100644 --- a/src/mrinufft/extras/utils.py +++ b/src/mrinufft/extras/utils.py @@ -1,3 +1,4 @@ +"""Utils for extras module.""" from mrinufft._utils import MethodRegister register_smaps = MethodRegister("sensitivity_maps") diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py index 5a230898..b3cc79c2 100644 --- a/src/mrinufft/io/siemens.py +++ b/src/mrinufft/io/siemens.py @@ -1,4 +1,4 @@ -"""sasa""" +"""Siemens specific rawdat reader, wrapper over pymapVBVD.""" import numpy as np From 064f9c1374f91de9726860c3415dc2197c8f548a Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Thu, 23 May 2024 16:20:23 +0200 Subject: [PATCH 42/50] Add lint fixes --- src/mrinufft/extras/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mrinufft/extras/utils.py b/src/mrinufft/extras/utils.py index d4d5fa4e..41988b1f 100644 --- a/src/mrinufft/extras/utils.py +++ b/src/mrinufft/extras/utils.py @@ -3,6 +3,7 @@ register_smaps = MethodRegister("sensitivity_maps") + def get_smaps(name, *args, **kwargs): """Get the density compensation function from its name.""" try: @@ -15,4 +16,4 @@ def get_smaps(name, *args, **kwargs): if args or kwargs: return method(*args, **kwargs) - return method \ No newline at end of file + return method From 7de40f62f58bab59a4ed4f1917f92714f349e32f Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Thu, 23 May 2024 16:21:15 +0200 Subject: [PATCH 43/50] Added PEP fixes --- src/mrinufft/density/geometry_based.py | 2 +- src/mrinufft/operators/base.py | 6 +++--- src/mrinufft/trajectories/maths.py | 6 +++--- src/mrinufft/trajectories/tools.py | 4 ++-- src/mrinufft/trajectories/trajectory3D.py | 10 +++++----- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/mrinufft/density/geometry_based.py b/src/mrinufft/density/geometry_based.py index 4dc0ecc5..cb091831 100644 --- a/src/mrinufft/density/geometry_based.py +++ b/src/mrinufft/density/geometry_based.py @@ -87,7 +87,7 @@ def voronoi_unique(traj, *args, **kwargs): # For edge point (infinite voronoi cells) we extrapolate from neighbours # Initial implementation in Jeff Fessler's MIRT - rho = np.sum(traj**2, axis=1) + rho = np.sum(traj ** 2, axis=1) igood = (rho > 0.6 * np.max(rho)) & ~np.isinf(wi) if len(igood) < 10: print("dubious extrapolation with", len(igood), "points") diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index f60f24be..034d6bcf 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -253,7 +253,7 @@ def with_off_resonnance_correction(self, B, C, indices): from ..off_resonnance import MRIFourierCorrected return MRIFourierCorrected(self, B, C, indices) - + def compute_smaps(self, method=None): """Compute the sensitivity maps and set it. @@ -265,7 +265,7 @@ def compute_smaps(self, method=None): If a dict, it should have a key 'name', to determine which method to use. other items will be used as kwargs. If a callable, it should take the samples and the shape as input. - Note that this callable function should also hold the k-space data + Note that this callable function should also hold the k-space data (use funtools.partial) """ if isinstance(method, np.ndarray): @@ -287,7 +287,7 @@ def compute_smaps(self, method=None): self.shape, density=self.density, backend=self.backend, - **kwargs + **kwargs, ) def make_autograd(self, variable="data"): diff --git a/src/mrinufft/trajectories/maths.py b/src/mrinufft/trajectories/maths.py index d8ad192f..eb8310e9 100644 --- a/src/mrinufft/trajectories/maths.py +++ b/src/mrinufft/trajectories/maths.py @@ -187,19 +187,19 @@ def Ra(vector, theta): return np.array( [ [ - cos_t + v_x**2 * (1 - cos_t), + cos_t + v_x ** 2 * (1 - cos_t), v_x * v_y * (1 - cos_t) + v_z * sin_t, v_x * v_z * (1 - cos_t) - v_y * sin_t, ], [ v_y * v_x * (1 - cos_t) - v_z * sin_t, - cos_t + v_y**2 * (1 - cos_t), + cos_t + v_y ** 2 * (1 - cos_t), v_y * v_z * (1 - cos_t) + v_x * sin_t, ], [ v_z * v_x * (1 - cos_t) + v_y * sin_t, v_z * v_y * (1 - cos_t) - v_x * sin_t, - cos_t + v_z**2 * (1 - cos_t), + cos_t + v_z ** 2 * (1 - cos_t), ], ] ) diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index a3feebcd..972d013d 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -310,7 +310,7 @@ def stack_spherically( # Attribute shots to stacks following density proportional to surface Nc_per_stack = np.ones(nb_stacks).astype(int) - density = radii**2 # simplified version + density = radii ** 2 # simplified version for _ in range(Nc - nb_stacks): idx = np.argmax(density / Nc_per_stack) Nc_per_stack[idx] += 1 @@ -406,7 +406,7 @@ def shellify( ) # Carve upper hemisphere from trajectory - z_coords = KMAX**2 - shell_upper[..., 0] ** 2 - shell_upper[..., 1] ** 2 + z_coords = KMAX ** 2 - shell_upper[..., 0] ** 2 - shell_upper[..., 1] ** 2 z_signs = np.sign(z_coords) shell_upper[..., 2] += z_signs * np.sqrt(np.abs(z_coords)) diff --git a/src/mrinufft/trajectories/trajectory3D.py b/src/mrinufft/trajectories/trajectory3D.py index aab30197..bc2255b5 100644 --- a/src/mrinufft/trajectories/trajectory3D.py +++ b/src/mrinufft/trajectories/trajectory3D.py @@ -240,7 +240,7 @@ def initialize_3D_wave_caipi( elif packing == Packings.CIRCLE: positions = [[0, 0]] counter = 0 - while len(positions) < side**2: + while len(positions) < side ** 2: counter += 1 perimeter = 2 * np.pi * counter nb_shots = int(np.trunc(perimeter)) @@ -353,11 +353,11 @@ def initialize_3D_seiffert_spiral( """ # Normalize ellipses integrations by the requested period spiral = np.zeros((1, Ns // (1 + in_out), 3)) - period = 4 * ellipk(curve_index**2) + period = 4 * ellipk(curve_index ** 2) times = np.linspace(0, nb_revolutions * period, Ns // (1 + in_out), endpoint=False) # Initialize first shot - jacobi = ellipj(times, curve_index**2) + jacobi = ellipj(times, curve_index ** 2) spiral[0, :, 0] = jacobi[0] * np.cos(curve_index * times) spiral[0, :, 1] = jacobi[0] * np.sin(curve_index * times) spiral[0, :, 2] = jacobi[1] @@ -655,7 +655,7 @@ def initialize_3D_seiffert_shells( Nc_per_shell[idx] += 1 # Normalize ellipses integrations by the requested period - period = 4 * ellipk(curve_index**2) + period = 4 * ellipk(curve_index ** 2) times = np.linspace(0, nb_revolutions * period, Ns, endpoint=False) # Create shells one by one @@ -667,7 +667,7 @@ def initialize_3D_seiffert_shells( k0 = radii[i] # Initialize first shot - jacobi = ellipj(times, curve_index**2) + jacobi = ellipj(times, curve_index ** 2) trajectory[count, :, 0] = k0 * jacobi[0] * np.cos(curve_index * times) trajectory[count, :, 1] = k0 * jacobi[0] * np.sin(curve_index * times) trajectory[count, :, 2] = k0 * jacobi[1] From 238ec00a458d4fe37156da441c1dfc522b9bfbaf Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Thu, 23 May 2024 16:36:49 +0200 Subject: [PATCH 44/50] Black --- examples/example_density.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/example_density.py b/examples/example_density.py index 98950aee..38280edd 100644 --- a/examples/example_density.py +++ b/examples/example_density.py @@ -111,7 +111,7 @@ # %% flat_traj = traj.reshape(-1, 2) -weights = np.sqrt(np.sum(flat_traj**2, axis=1)) +weights = np.sqrt(np.sum(flat_traj ** 2, axis=1)) nufft = get_operator("finufft")(traj, shape=mri_2D.shape, density=weights) adjoint_manual = nufft.adj_op(kspace) fig, axs = plt.subplots(1, 3, figsize=(15, 5)) From a58962c2722a16b1473dc845266ef2d840029bd3 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Thu, 23 May 2024 16:48:44 +0200 Subject: [PATCH 45/50] Fix black --- examples/example_density.py | 2 +- src/mrinufft/density/geometry_based.py | 2 +- src/mrinufft/extras/smaps.py | 4 +--- src/mrinufft/extras/utils.py | 1 + src/mrinufft/io/siemens.py | 1 + src/mrinufft/trajectories/maths.py | 6 +++--- src/mrinufft/trajectories/tools.py | 4 ++-- src/mrinufft/trajectories/trajectory3D.py | 10 +++++----- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/example_density.py b/examples/example_density.py index 38280edd..98950aee 100644 --- a/examples/example_density.py +++ b/examples/example_density.py @@ -111,7 +111,7 @@ # %% flat_traj = traj.reshape(-1, 2) -weights = np.sqrt(np.sum(flat_traj ** 2, axis=1)) +weights = np.sqrt(np.sum(flat_traj**2, axis=1)) nufft = get_operator("finufft")(traj, shape=mri_2D.shape, density=weights) adjoint_manual = nufft.adj_op(kspace) fig, axs = plt.subplots(1, 3, figsize=(15, 5)) diff --git a/src/mrinufft/density/geometry_based.py b/src/mrinufft/density/geometry_based.py index cb091831..4dc0ecc5 100644 --- a/src/mrinufft/density/geometry_based.py +++ b/src/mrinufft/density/geometry_based.py @@ -87,7 +87,7 @@ def voronoi_unique(traj, *args, **kwargs): # For edge point (infinite voronoi cells) we extrapolate from neighbours # Initial implementation in Jeff Fessler's MIRT - rho = np.sum(traj ** 2, axis=1) + rho = np.sum(traj**2, axis=1) igood = (rho > 0.6 * np.max(rho)) & ~np.isinf(wi) if len(igood) < 10: print("dubious extrapolation with", len(igood), "points") diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 68d28624..8a33f353 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -89,9 +89,7 @@ def _extract_kspace_center( a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) elif window_fun == "ellipse": - window = ( - xp.sum(kspace_loc ** 2 / xp.asarray(threshold) ** 2, axis=1) <= 1 - ) + window = xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 else: raise ValueError("Unsupported window function.") data_thresholded = window * kspace_data diff --git a/src/mrinufft/extras/utils.py b/src/mrinufft/extras/utils.py index 41988b1f..5c9a7b9d 100644 --- a/src/mrinufft/extras/utils.py +++ b/src/mrinufft/extras/utils.py @@ -1,4 +1,5 @@ """Utils for extras module.""" + from mrinufft._utils import MethodRegister register_smaps = MethodRegister("sensitivity_maps") diff --git a/src/mrinufft/io/siemens.py b/src/mrinufft/io/siemens.py index b3cc79c2..9cc782aa 100644 --- a/src/mrinufft/io/siemens.py +++ b/src/mrinufft/io/siemens.py @@ -1,4 +1,5 @@ """Siemens specific rawdat reader, wrapper over pymapVBVD.""" + import numpy as np diff --git a/src/mrinufft/trajectories/maths.py b/src/mrinufft/trajectories/maths.py index eb8310e9..d8ad192f 100644 --- a/src/mrinufft/trajectories/maths.py +++ b/src/mrinufft/trajectories/maths.py @@ -187,19 +187,19 @@ def Ra(vector, theta): return np.array( [ [ - cos_t + v_x ** 2 * (1 - cos_t), + cos_t + v_x**2 * (1 - cos_t), v_x * v_y * (1 - cos_t) + v_z * sin_t, v_x * v_z * (1 - cos_t) - v_y * sin_t, ], [ v_y * v_x * (1 - cos_t) - v_z * sin_t, - cos_t + v_y ** 2 * (1 - cos_t), + cos_t + v_y**2 * (1 - cos_t), v_y * v_z * (1 - cos_t) + v_x * sin_t, ], [ v_z * v_x * (1 - cos_t) + v_y * sin_t, v_z * v_y * (1 - cos_t) - v_x * sin_t, - cos_t + v_z ** 2 * (1 - cos_t), + cos_t + v_z**2 * (1 - cos_t), ], ] ) diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index 972d013d..a3feebcd 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -310,7 +310,7 @@ def stack_spherically( # Attribute shots to stacks following density proportional to surface Nc_per_stack = np.ones(nb_stacks).astype(int) - density = radii ** 2 # simplified version + density = radii**2 # simplified version for _ in range(Nc - nb_stacks): idx = np.argmax(density / Nc_per_stack) Nc_per_stack[idx] += 1 @@ -406,7 +406,7 @@ def shellify( ) # Carve upper hemisphere from trajectory - z_coords = KMAX ** 2 - shell_upper[..., 0] ** 2 - shell_upper[..., 1] ** 2 + z_coords = KMAX**2 - shell_upper[..., 0] ** 2 - shell_upper[..., 1] ** 2 z_signs = np.sign(z_coords) shell_upper[..., 2] += z_signs * np.sqrt(np.abs(z_coords)) diff --git a/src/mrinufft/trajectories/trajectory3D.py b/src/mrinufft/trajectories/trajectory3D.py index bc2255b5..aab30197 100644 --- a/src/mrinufft/trajectories/trajectory3D.py +++ b/src/mrinufft/trajectories/trajectory3D.py @@ -240,7 +240,7 @@ def initialize_3D_wave_caipi( elif packing == Packings.CIRCLE: positions = [[0, 0]] counter = 0 - while len(positions) < side ** 2: + while len(positions) < side**2: counter += 1 perimeter = 2 * np.pi * counter nb_shots = int(np.trunc(perimeter)) @@ -353,11 +353,11 @@ def initialize_3D_seiffert_spiral( """ # Normalize ellipses integrations by the requested period spiral = np.zeros((1, Ns // (1 + in_out), 3)) - period = 4 * ellipk(curve_index ** 2) + period = 4 * ellipk(curve_index**2) times = np.linspace(0, nb_revolutions * period, Ns // (1 + in_out), endpoint=False) # Initialize first shot - jacobi = ellipj(times, curve_index ** 2) + jacobi = ellipj(times, curve_index**2) spiral[0, :, 0] = jacobi[0] * np.cos(curve_index * times) spiral[0, :, 1] = jacobi[0] * np.sin(curve_index * times) spiral[0, :, 2] = jacobi[1] @@ -655,7 +655,7 @@ def initialize_3D_seiffert_shells( Nc_per_shell[idx] += 1 # Normalize ellipses integrations by the requested period - period = 4 * ellipk(curve_index ** 2) + period = 4 * ellipk(curve_index**2) times = np.linspace(0, nb_revolutions * period, Ns, endpoint=False) # Create shells one by one @@ -667,7 +667,7 @@ def initialize_3D_seiffert_shells( k0 = radii[i] # Initialize first shot - jacobi = ellipj(times, curve_index ** 2) + jacobi = ellipj(times, curve_index**2) trajectory[count, :, 0] = k0 * jacobi[0] * np.cos(curve_index * times) trajectory[count, :, 1] = k0 * jacobi[0] * np.sin(curve_index * times) trajectory[count, :, 2] = k0 * jacobi[1] From ebd61d364692e4841f22aaceff6150fdbd299d67 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Thu, 23 May 2024 16:52:13 +0200 Subject: [PATCH 46/50] Fix --- src/mrinufft/extras/smaps.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 8a33f353..9051b8d3 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -2,8 +2,6 @@ from mrinufft.density.utils import flat_traj from mrinufft.operators.base import get_array_module -from skimage.filters import threshold_otsu, gaussian -from skimage.morphology import convex_hull_image from .utils import register_smaps import numpy as np from typing import Tuple @@ -148,6 +146,8 @@ def low_frequency( """ # defer import to later to prevent circular import from mrinufft import get_operator + from skimage.filters import threshold_otsu, gaussian + from skimage.morphology import convex_hull_image k_space, samples, dc = _extract_kspace_center( kspace_data=kspace_data, From c0aa0c5f725b21c64f09603b250c1b965135e38b Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Thu, 23 May 2024 17:24:50 +0200 Subject: [PATCH 47/50] Added PSF weighting --- src/mrinufft/operators/interfaces/gpunufft.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index b1f354ef..c97b0148 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -478,7 +478,7 @@ def uses_sense(self): return self.impl.uses_sense @classmethod - def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): + def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, normalize=True, **kwargs): """Compute the density compensation weights for a given set of kspace locations. Parameters @@ -491,6 +491,9 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): the number of iterations for density estimation osf: float or int The oversampling factor the volume shape + normalize: bool + Whether to normalize the density compensation. + We normalize such that the energy of PSF = 1 """ if GPUNUFFT_AVAILABLE is False: raise ValueError( @@ -506,6 +509,12 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): density_comp = grid_op.impl.operator.estimate_density_comp( max_iter=num_iterations ) + if normalize: + spike = np.zeros(volume_shape) + mid_loc = (v//2 for v in volume_shape) + spike[mid_loc] = 1 + psf = grid_op.adj_op(grid_op.op(spike)) + density_comp /= np.linalg.norm(psf) return density_comp.squeeze() def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs): From 45bc400b2cfec8abee4c550780f7a6e5f1fa433c Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Thu, 23 May 2024 17:48:10 +0200 Subject: [PATCH 48/50] Move to tuple --- src/mrinufft/operators/interfaces/gpunufft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index c97b0148..ba3f4009 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -511,7 +511,7 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, normalize=True ) if normalize: spike = np.zeros(volume_shape) - mid_loc = (v//2 for v in volume_shape) + mid_loc = tuple(v//2 for v in volume_shape) spike[mid_loc] = 1 psf = grid_op.adj_op(grid_op.op(spike)) density_comp /= np.linalg.norm(psf) From 18f5b349ae07a3baed08a85987056cf1b0a35ace Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Comby Date: Fri, 24 May 2024 16:36:02 +0200 Subject: [PATCH 49/50] lint --- src/mrinufft/operators/interfaces/gpunufft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index ba3f4009..641706e2 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -511,7 +511,7 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, normalize=True ) if normalize: spike = np.zeros(volume_shape) - mid_loc = tuple(v//2 for v in volume_shape) + mid_loc = tuple(v // 2 for v in volume_shape) spike[mid_loc] = 1 psf = grid_op.adj_op(grid_op.op(spike)) density_comp /= np.linalg.norm(psf) From 0639eda5acfe268666a1bf3f1f18a3d7d8a145dc Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 24 May 2024 17:05:17 +0200 Subject: [PATCH 50/50] lint --- src/mrinufft/operators/interfaces/gpunufft.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index 641706e2..5e0d4ac0 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -478,7 +478,15 @@ def uses_sense(self): return self.impl.uses_sense @classmethod - def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, normalize=True, **kwargs): + def pipe( + cls, + kspace_loc, + volume_shape, + num_iterations=10, + osf=2, + normalize=True, + **kwargs, + ): """Compute the density compensation weights for a given set of kspace locations. Parameters