Skip to content

Commit

Permalink
fix: pipe and gpunufft cleanup
Browse files Browse the repository at this point in the history
* fix: cleanup test cases.

* fix: gpunufft is default for pipe.

* fix: kwargs extractions.

* fix: setup osf manually for pipe.

* fix: use correct entry point for density.

* refactor(gpunufft)!: remove front facing optional arguments for grid/interp only.

* docs: update pipe density example.

* style: docstring format.
  • Loading branch information
paquiteau authored Jan 10, 2024
1 parent 2f39bcf commit 8682266
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 31 deletions.
10 changes: 7 additions & 3 deletions examples/example_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,23 @@
#
# .. warning::
# If this method is widely used in the literature, there exists no convergence guarantees for it.

#
# .. note::
# The Pipe method is currently only implemented for gpuNUFFT.

# %%
if check_backend("gpunufft"):
flat_traj = traj.reshape(-1, 2)
nufft = get_operator("gpunufft")(traj, shape=mri_2D.shape, density=False)
nufft = get_operator("gpunufft")(
traj, shape=mri_2D.shape, density={"name": "pipe", "osf": 2}
)
adjoint_manual = nufft.adj_op(kspace)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(abs(mri_2D))
axs[0].set_title("Ground Truth")
axs[1].imshow(abs(adjoint))
axs[1].set_title("no density compensation")
axs[2].imshow(abs(adjoint_manual))
axs[2].set_title("manual density compensation")
axs[2].set_title("Pipe density compensation")

print(nufft.density)
2 changes: 1 addition & 1 deletion src/mrinufft/density/nufft_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

@register_density
@flat_traj
def pipe(traj, shape, backend="cufinufft", **kwargs):
def pipe(traj, shape, backend="gpunufft", **kwargs):
"""Compute the density compensation weights using the pipe method.
Parameters
Expand Down
4 changes: 2 additions & 2 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ def compute_density(self, method=None):

kwargs = {}
if isinstance(method, dict):
method = method["name"] # should be a string !
kwargs = method.copy().remove("name")
kwargs = method.copy()
method = kwargs.pop("name") # must be a string !
if method == "pipe" and "backend" not in kwargs:
kwargs["backend"] = self.backend
if isinstance(method, str):
Expand Down
41 changes: 25 additions & 16 deletions src/mrinufft/operators/interfaces/gpunufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,60 +300,61 @@ def __init__(
self.dtype = self.samples.dtype
self.n_coils = n_coils
self.smaps = smaps
if density is True:
self.density = self.pipe(self.samples, shape)
elif isinstance(density, np.ndarray):
self.density = density
else:
self.density = None
self.kwargs = kwargs
self.compute_density(density)
self.impl = RawGpuNUFFT(
samples=self.samples,
shape=self.shape,
n_coils=self.n_coils,
density_comp=self.density,
smaps=smaps,
kernel_width=kwargs.get("kernel_width", -int(np.log10(eps))),
**self.kwargs,
**kwargs,
)

def op(self, data, *args, **kwargs):
def op(self, data, coeffs=None):
"""Compute forward non-uniform Fourier Transform.
Parameters
----------
img: np.ndarray
input N-D array with the same shape as self.shape.
coeffs: np.ndarray, optional
output Array. Should be pinned memory for best performances.
Returns
-------
np.ndarray
Masked Fourier transform of the input image.
"""
return self.impl.op(data, *args, **kwargs)
return self.impl.op(
data,
coeffs,
)

def adj_op(self, coeffs, *args, **kwargs):
def adj_op(self, coeffs, data=None):
"""Compute adjoint Non Unform Fourier Transform.
Parameters
----------
x: np.ndarray
coeffs: np.ndarray
masked non-uniform Fourier transform 1D data.
data: np.ndarray, optional
output image array. Should be pinned memory for best performances.
Returns
-------
np.ndarray
Inverse discrete Fourier transform of the input coefficients.
"""
return self.impl.adj_op(coeffs, *args, **kwargs)
return self.impl.adj_op(coeffs, data)

@property
def uses_sense(self):
"""Return True if the Fourier Operator uses the SENSE method."""
return self.impl.uses_sense

@classmethod
def pipe(cls, kspace_loc, volume_shape, num_iterations=10):
def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs):
"""Compute the density compensation weights for a given set of kspace locations.
Parameters
Expand All @@ -364,19 +365,27 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10):
the volume shape
num_iterations: int default 10
the number of iterations for density estimation
osf: float or int
The oversampling factor the volume shape
"""
if GPUNUFFT_AVAILABLE is False:
raise ValueError(
"gpuNUFFT is not available, cannot " "estimate the density compensation"
)
volume_shape = tuple(int(osf * s) for s in volume_shape)
grid_op = MRIGpuNUFFT(
samples=kspace_loc,
shape=volume_shape,
osf=1,
**kwargs,
)
density_comp = np.ones(kspace_loc.shape[0])
for _ in range(num_iterations):
density_comp = density_comp / np.abs(
grid_op.op(grid_op.adj_op(density_comp, None, True), None, True)
grid_op.impl.op(
grid_op.impl.adj_op(density_comp, None, True),
None,
True,
)
)
return density_comp
return density_comp.squeeze()
25 changes: 16 additions & 9 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
[
(1, 1, 1, False),
(3, 1, 1, False),
(1, 4, 1, False),
(1, 4, 1, True),
(1, 4, 1, False),
(1, 4, 2, True),
(1, 4, 2, False),
(3, 2, 1, False),
(3, 2, 1, True),
(3, 2, 1, False),
(3, 4, 2, True),
(3, 4, 2, False),
],
)
@parametrize_with_cases(
Expand Down Expand Up @@ -65,7 +67,12 @@ def operator(
def flat_operator(operator):
"""Generate a batch operator with n_batch=1."""
return get_operator(operator.backend)(
operator.samples, operator.shape, n_coils=operator.n_coils, smaps=operator.smaps
operator.samples,
operator.shape,
n_coils=operator.n_coils,
smaps=operator.smaps,
squeeze_dims=False,
n_trans=1,
)


Expand All @@ -92,7 +99,7 @@ def kspace_data(operator):

def test_batch_op(operator, flat_operator, image_data):
"""Test the batch type 2 (forward)."""
kspace_data = operator.op(image_data)
kspace_batched = operator.op(image_data)

if operator.uses_sense:
image_flat = image_data.reshape(-1, *operator.shape)
Expand All @@ -107,7 +114,7 @@ def test_batch_op(operator, flat_operator, image_data):
(operator.n_batchs, operator.n_coils, operator.n_samples),
)

npt.assert_array_almost_equal_nulp(kspace_data, kspace_flat)
npt.assert_array_almost_equal(kspace_batched, kspace_flat)


def test_batch_adj_op(operator, flat_operator, kspace_data):
Expand All @@ -127,9 +134,9 @@ def test_batch_adj_op(operator, flat_operator, kspace_data):
shape,
)

image_data = operator.adj_op(kspace_data)
image_batched = operator.adj_op(kspace_data)
# Reduced accuracy for the GPU cases...
npt.assert_allclose(image_data, image_flat, atol=1e-3, rtol=1e-3)
npt.assert_allclose(image_batched, image_flat, atol=1e-3, rtol=1e-3)


def test_data_consistency(operator, image_data, kspace_data):
Expand Down Expand Up @@ -166,12 +173,12 @@ def test_data_consistency_readonly(operator, image_data, kspace_data):
def test_gradient_lipschitz(operator, image_data, kspace_data):
"""Test the gradient lipschitz constant."""
C = 1 if operator.uses_sense else operator.n_coils
img = image_data.copy().reshape(operator.n_batchs, C, *operator.shape)
img = image_data.copy().reshape(operator.n_batchs, C, *operator.shape).squeeze()
for _ in range(10):
grad = operator.data_consistency(img, kspace_data)
norm = np.linalg.norm(grad)
grad /= norm
np.copyto(img, grad)
np.copyto(img, grad.squeeze())
norm_prev = norm

# TODO: check that the value is "not too far" from 1
Expand Down

0 comments on commit 8682266

Please sign in to comment.