From 868226644b68db3299c742172350ff4a4b288459 Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Comby Date: Wed, 10 Jan 2024 20:08:54 +0100 Subject: [PATCH] fix: pipe and gpunufft cleanup * fix: cleanup test cases. * fix: gpunufft is default for pipe. * fix: kwargs extractions. * fix: setup osf manually for pipe. * fix: use correct entry point for density. * refactor(gpunufft)!: remove front facing optional arguments for grid/interp only. * docs: update pipe density example. * style: docstring format. --- examples/example_density.py | 10 +++-- src/mrinufft/density/nufft_based.py | 2 +- src/mrinufft/operators/base.py | 4 +- src/mrinufft/operators/interfaces/gpunufft.py | 41 +++++++++++-------- tests/test_batch.py | 25 +++++++---- 5 files changed, 51 insertions(+), 31 deletions(-) diff --git a/examples/example_density.py b/examples/example_density.py index a22f5480..98950aee 100644 --- a/examples/example_density.py +++ b/examples/example_density.py @@ -132,14 +132,16 @@ # # .. warning:: # If this method is widely used in the literature, there exists no convergence guarantees for it. - +# # .. note:: # The Pipe method is currently only implemented for gpuNUFFT. # %% if check_backend("gpunufft"): flat_traj = traj.reshape(-1, 2) - nufft = get_operator("gpunufft")(traj, shape=mri_2D.shape, density=False) + nufft = get_operator("gpunufft")( + traj, shape=mri_2D.shape, density={"name": "pipe", "osf": 2} + ) adjoint_manual = nufft.adj_op(kspace) fig, axs = plt.subplots(1, 3, figsize=(15, 5)) axs[0].imshow(abs(mri_2D)) @@ -147,4 +149,6 @@ axs[1].imshow(abs(adjoint)) axs[1].set_title("no density compensation") axs[2].imshow(abs(adjoint_manual)) - axs[2].set_title("manual density compensation") + axs[2].set_title("Pipe density compensation") + + print(nufft.density) diff --git a/src/mrinufft/density/nufft_based.py b/src/mrinufft/density/nufft_based.py index 1b1b9cf5..c2d4d7fa 100644 --- a/src/mrinufft/density/nufft_based.py +++ b/src/mrinufft/density/nufft_based.py @@ -5,7 +5,7 @@ @register_density @flat_traj -def pipe(traj, shape, backend="cufinufft", **kwargs): +def pipe(traj, shape, backend="gpunufft", **kwargs): """Compute the density compensation weights using the pipe method. Parameters diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index c7905967..22561f60 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -176,8 +176,8 @@ def compute_density(self, method=None): kwargs = {} if isinstance(method, dict): - method = method["name"] # should be a string ! - kwargs = method.copy().remove("name") + kwargs = method.copy() + method = kwargs.pop("name") # must be a string ! if method == "pipe" and "backend" not in kwargs: kwargs["backend"] = self.backend if isinstance(method, str): diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index ca821091..7ea74340 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -300,13 +300,7 @@ def __init__( self.dtype = self.samples.dtype self.n_coils = n_coils self.smaps = smaps - if density is True: - self.density = self.pipe(self.samples, shape) - elif isinstance(density, np.ndarray): - self.density = density - else: - self.density = None - self.kwargs = kwargs + self.compute_density(density) self.impl = RawGpuNUFFT( samples=self.samples, shape=self.shape, @@ -314,38 +308,45 @@ def __init__( density_comp=self.density, smaps=smaps, kernel_width=kwargs.get("kernel_width", -int(np.log10(eps))), - **self.kwargs, + **kwargs, ) - def op(self, data, *args, **kwargs): + def op(self, data, coeffs=None): """Compute forward non-uniform Fourier Transform. Parameters ---------- img: np.ndarray input N-D array with the same shape as self.shape. + coeffs: np.ndarray, optional + output Array. Should be pinned memory for best performances. Returns ------- np.ndarray Masked Fourier transform of the input image. """ - return self.impl.op(data, *args, **kwargs) + return self.impl.op( + data, + coeffs, + ) - def adj_op(self, coeffs, *args, **kwargs): + def adj_op(self, coeffs, data=None): """Compute adjoint Non Unform Fourier Transform. Parameters ---------- - x: np.ndarray + coeffs: np.ndarray masked non-uniform Fourier transform 1D data. + data: np.ndarray, optional + output image array. Should be pinned memory for best performances. Returns ------- np.ndarray Inverse discrete Fourier transform of the input coefficients. """ - return self.impl.adj_op(coeffs, *args, **kwargs) + return self.impl.adj_op(coeffs, data) @property def uses_sense(self): @@ -353,7 +354,7 @@ def uses_sense(self): return self.impl.uses_sense @classmethod - def pipe(cls, kspace_loc, volume_shape, num_iterations=10): + def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): """Compute the density compensation weights for a given set of kspace locations. Parameters @@ -364,19 +365,27 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10): the volume shape num_iterations: int default 10 the number of iterations for density estimation + osf: float or int + The oversampling factor the volume shape """ if GPUNUFFT_AVAILABLE is False: raise ValueError( "gpuNUFFT is not available, cannot " "estimate the density compensation" ) + volume_shape = tuple(int(osf * s) for s in volume_shape) grid_op = MRIGpuNUFFT( samples=kspace_loc, shape=volume_shape, osf=1, + **kwargs, ) density_comp = np.ones(kspace_loc.shape[0]) for _ in range(num_iterations): density_comp = density_comp / np.abs( - grid_op.op(grid_op.adj_op(density_comp, None, True), None, True) + grid_op.impl.op( + grid_op.impl.adj_op(density_comp, None, True), + None, + True, + ) ) - return density_comp + return density_comp.squeeze() diff --git a/tests/test_batch.py b/tests/test_batch.py index 29c18bb1..2ddaee28 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -17,12 +17,14 @@ [ (1, 1, 1, False), (3, 1, 1, False), - (1, 4, 1, False), (1, 4, 1, True), + (1, 4, 1, False), + (1, 4, 2, True), (1, 4, 2, False), - (3, 2, 1, False), (3, 2, 1, True), + (3, 2, 1, False), (3, 4, 2, True), + (3, 4, 2, False), ], ) @parametrize_with_cases( @@ -65,7 +67,12 @@ def operator( def flat_operator(operator): """Generate a batch operator with n_batch=1.""" return get_operator(operator.backend)( - operator.samples, operator.shape, n_coils=operator.n_coils, smaps=operator.smaps + operator.samples, + operator.shape, + n_coils=operator.n_coils, + smaps=operator.smaps, + squeeze_dims=False, + n_trans=1, ) @@ -92,7 +99,7 @@ def kspace_data(operator): def test_batch_op(operator, flat_operator, image_data): """Test the batch type 2 (forward).""" - kspace_data = operator.op(image_data) + kspace_batched = operator.op(image_data) if operator.uses_sense: image_flat = image_data.reshape(-1, *operator.shape) @@ -107,7 +114,7 @@ def test_batch_op(operator, flat_operator, image_data): (operator.n_batchs, operator.n_coils, operator.n_samples), ) - npt.assert_array_almost_equal_nulp(kspace_data, kspace_flat) + npt.assert_array_almost_equal(kspace_batched, kspace_flat) def test_batch_adj_op(operator, flat_operator, kspace_data): @@ -127,9 +134,9 @@ def test_batch_adj_op(operator, flat_operator, kspace_data): shape, ) - image_data = operator.adj_op(kspace_data) + image_batched = operator.adj_op(kspace_data) # Reduced accuracy for the GPU cases... - npt.assert_allclose(image_data, image_flat, atol=1e-3, rtol=1e-3) + npt.assert_allclose(image_batched, image_flat, atol=1e-3, rtol=1e-3) def test_data_consistency(operator, image_data, kspace_data): @@ -166,12 +173,12 @@ def test_data_consistency_readonly(operator, image_data, kspace_data): def test_gradient_lipschitz(operator, image_data, kspace_data): """Test the gradient lipschitz constant.""" C = 1 if operator.uses_sense else operator.n_coils - img = image_data.copy().reshape(operator.n_batchs, C, *operator.shape) + img = image_data.copy().reshape(operator.n_batchs, C, *operator.shape).squeeze() for _ in range(10): grad = operator.data_consistency(img, kspace_data) norm = np.linalg.norm(grad) grad /= norm - np.copyto(img, grad) + np.copyto(img, grad.squeeze()) norm_prev = norm # TODO: check that the value is "not too far" from 1