From 66cd861706a470c4256dc85153ef11e7467f0a3e Mon Sep 17 00:00:00 2001 From: AsmaTANABEN Date: Thu, 21 Nov 2024 15:48:29 +0100 Subject: [PATCH] Fixes #211: Adjust normalization of density compensation coefficients. Add Cufinufft Pipe function authored by Chaithya G.R. --- .../operators/interfaces/cufinufft.py | 81 +++++++++++++++++++ src/mrinufft/operators/interfaces/gpunufft.py | 15 ++-- 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 9f3c360f..2bb30115 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -39,6 +39,30 @@ DTYPE_R2C = {"float32": "complex64", "float64": "complex128"} +def _next235beven(n, b): + """Find the next even integer not less than n. + + This function finds the next even integer not less than n, with prime factors no + larger than 5, and is a multiple of b (where b is a number that only + has prime factors 2, 3, and 5). + It is used in particular with `pipe` density compensation estimation. + """ + if n <= 2: + return 2 + if n % 2 == 1: + n += 1 # make it even + nplus = n - 2 # to cancel out the +=2 at start of loop + numdiv = 2 # a dummy that is >1 + while numdiv > 1 or nplus % b != 0: + nplus += 2 # stays even + numdiv = nplus + while numdiv % 2 == 0: + numdiv //= 2 # remove all factors of 2, 3, 5... + while numdiv % 3 == 0: + numdiv //= 3 + while numdiv % 5 == 0: + numdiv //= 5 + return nplus def _error_check(ier, msg): if ier != 0: @@ -849,3 +873,60 @@ def toggle_grad_traj(self): if self.uses_sense: self.smaps = self.smaps.conj() self.raw_op.toggle_grad_traj() + + @classmethod + def pipe( + cls, + kspace_loc, + volume_shape, + num_iterations=10, + osf=2, + normalize=True, + **kwargs, + ): + """Compute the density compensation weights for a given set of kspace locations. + + Parameters + ---------- + kspace_loc: np.ndarray + the kspace locations + volume_shape: np.ndarray + the volume shape + num_iterations: int default 10 + the number of iterations for density estimation + osf: float or int + The oversampling factor the volume shape + normalize: bool + Whether to normalize the density compensation. + """ + if CUFINUFFT_AVAILABLE is False: + raise ValueError( + "gpuNUFFT is not available, cannot " "estimate the density compensation" + ) + original_shape = volume_shape + volume_shape = np.array([_next235beven(int(osf * i), 1) for i in volume_shape]) + grid_op = MRICufiNUFFT( + samples=kspace_loc, + shape=volume_shape, + upsampfac=1, + gpu_spreadinterponly=1, + gpu_kerevalmeth=0, + **kwargs, + ) + density_comp = cp.ones(kspace_loc.shape[0], dtype=grid_op.cpx_dtype) + for _ in range(num_iterations): + density_comp /= cp.abs( + grid_op.op( + grid_op.adj_op(density_comp.astype(grid_op.cpx_dtype)) + ).squeeze() + ) + if normalize: + test_op = MRICufiNUFFT( + samples=kspace_loc, + shape=original_shape, + **kwargs + ) + test_im = cp.ones(original_shape, dtype=test_op.cpx_dtype) + test_im_recon = test_op.adj_op(density_comp*test_op.op(test_im)) + density_comp /= cp.mean(cp.abs(test_im_recon)) + return density_comp.squeeze() diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index b5d65af6..54be73bd 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -590,12 +590,12 @@ def pipe( The oversampling factor the volume shape normalize: bool Whether to normalize the density compensation. - We normalize such that the energy of PSF = 1 """ if GPUNUFFT_AVAILABLE is False: raise ValueError( "gpuNUFFT is not available, cannot " "estimate the density compensation" ) + original_shape = volume_shape volume_shape = (np.array(volume_shape) * osf).astype(int) grid_op = MRIGpuNUFFT( samples=kspace_loc, @@ -607,11 +607,14 @@ def pipe( max_iter=num_iterations ) if normalize: - spike = np.zeros(volume_shape) - mid_loc = tuple(v // 2 for v in volume_shape) - spike[mid_loc] = 1 - psf = grid_op.adj_op(grid_op.op(spike)) - density_comp /= np.linalg.norm(psf) + test_op = MRIGpuNUFFT( + samples=kspace_loc, + shape=original_shape, + **kwargs + ) + test_im = np.ones(original_shape, dtype=np.complex64) + test_im_recon = test_op.adj_op(density_comp*test_op.op(test_im)) + density_comp /= np.mean(np.abs(test_im_recon)) return density_comp.squeeze() def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs):