Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for normalization issues #214

Merged
merged 7 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,32 @@
DTYPE_R2C = {"float32": "complex64", "float64": "complex128"}


def _next235beven(n, b):
"""Find the next even integer not less than n.

This function finds the next even integer not less than n, with prime factors no
larger than 5, and is a multiple of b (where b is a number that only
has prime factors 2, 3, and 5).
It is used in particular with `pipe` density compensation estimation.
"""
if n <= 2:
return 2
if n % 2 == 1:
n += 1 # make it even
nplus = n - 2 # to cancel out the +=2 at start of loop
numdiv = 2 # a dummy that is >1
while numdiv > 1 or nplus % b != 0:
nplus += 2 # stays even
numdiv = nplus
while numdiv % 2 == 0:
numdiv //= 2 # remove all factors of 2, 3, 5...
while numdiv % 3 == 0:
numdiv //= 3
while numdiv % 5 == 0:
numdiv //= 5
return nplus


def _error_check(ier, msg):
if ier != 0:
raise RuntimeError(msg)
Expand Down Expand Up @@ -849,3 +875,56 @@ def toggle_grad_traj(self):
if self.uses_sense:
self.smaps = self.smaps.conj()
self.raw_op.toggle_grad_traj()

@classmethod
def pipe(
cls,
kspace_loc,
volume_shape,
num_iterations=10,
osf=2,
normalize=True,
**kwargs,
):
"""Compute the density compensation weights for a given set of kspace locations.

Parameters
----------
kspace_loc: np.ndarray
the kspace locations
volume_shape: np.ndarray
the volume shape
num_iterations: int default 10
the number of iterations for density estimation
osf: float or int
The oversampling factor the volume shape
normalize: bool
Whether to normalize the density compensation.
"""
if CUFINUFFT_AVAILABLE is False:
raise ValueError(
"gpuNUFFT is not available, cannot " "estimate the density compensation"
)
original_shape = volume_shape
volume_shape = np.array([_next235beven(int(osf * i), 1) for i in volume_shape])
grid_op = MRICufiNUFFT(
samples=kspace_loc,
shape=volume_shape,
upsampfac=1,
gpu_spreadinterponly=1,
gpu_kerevalmeth=0,
**kwargs,
)
density_comp = cp.ones(kspace_loc.shape[0], dtype=grid_op.cpx_dtype)
for _ in range(num_iterations):
density_comp /= cp.abs(
grid_op.op(
grid_op.adj_op(density_comp.astype(grid_op.cpx_dtype))
).squeeze()
)
if normalize:
test_op = MRICufiNUFFT(samples=kspace_loc, shape=original_shape, **kwargs)
test_im = cp.ones(original_shape, dtype=test_op.cpx_dtype)
test_im_recon = test_op.adj_op(density_comp * test_op.op(test_im))
density_comp /= cp.mean(cp.abs(test_im_recon))
return density_comp.squeeze()
11 changes: 5 additions & 6 deletions src/mrinufft/operators/interfaces/gpunufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,12 @@ def pipe(
The oversampling factor the volume shape
normalize: bool
Whether to normalize the density compensation.
We normalize such that the energy of PSF = 1
"""
if GPUNUFFT_AVAILABLE is False:
raise ValueError(
"gpuNUFFT is not available, cannot " "estimate the density compensation"
)
original_shape = volume_shape
volume_shape = (np.array(volume_shape) * osf).astype(int)
grid_op = MRIGpuNUFFT(
samples=kspace_loc,
Expand All @@ -607,11 +607,10 @@ def pipe(
max_iter=num_iterations
)
if normalize:
spike = np.zeros(volume_shape)
mid_loc = tuple(v // 2 for v in volume_shape)
spike[mid_loc] = 1
psf = grid_op.adj_op(grid_op.op(spike))
density_comp /= np.linalg.norm(psf)
test_op = MRIGpuNUFFT(samples=kspace_loc, shape=original_shape, **kwargs)
test_im = np.ones(original_shape, dtype=np.complex64)
test_im_recon = test_op.adj_op(density_comp * test_op.op(test_im))
density_comp /= np.mean(np.abs(test_im_recon))
return density_comp.squeeze()

def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs):
Expand Down
Loading