Skip to content

Commit

Permalink
Merge branch 'master' into display_dens
Browse files Browse the repository at this point in the history
  • Loading branch information
chaithyagr authored Sep 27, 2024
2 parents cb076bd + 092c795 commit 165043f
Show file tree
Hide file tree
Showing 7 changed files with 270 additions and 7 deletions.
68 changes: 68 additions & 0 deletions examples/example_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# %%
"""
Example of using the Conjugate Gradient method.
This script demonstrates the use of the Conjugate Gradient (CG) method
for solving systems of linear equations of the form Ax = b, where A is a symmetric
positive-definite matrix. The CG method is an iterative algorithm that is particularly
useful for large, sparse systems where direct methods are computationally expensive.
The Conjugate Gradient method is widely used in various scientific and engineering
applications, including solving partial differential equations, optimization problems,
and machine learning tasks.
References
----------
- Inpirations:
- https://sigpy.readthedocs.io/en/latest/_modules/sigpy/alg.html#ConjugateGradient
- https://aquaulb.github.io/book_solving_pde_mooc/solving_pde_mooc/notebooks/05_IterativeMethods/05_02_Conjugate_Gradient.html
- Wikipedia:
- https://en.wikipedia.org/wiki/Conjugate_gradient_method
- https://en.wikipedia.org/wiki/Momentum
"""

# %%
# Imports
import numpy as np
import mrinufft
from brainweb_dl import get_mri
from mrinufft.extras.gradient import cg
from mrinufft.density import voronoi
from matplotlib import pyplot as plt

# %%
# Setup Inputs
samples_loc = mrinufft.initialize_2D_spiral(Nc=64, Ns=256)
image = get_mri(sub_id=4)
image = np.flipud(image[90])

# %%
# Setup the NUFFT operator
NufftOperator = mrinufft.get_operator("gpunufft") # get the operator
density = voronoi(samples_loc) # get the density

nufft = NufftOperator(
samples_loc, shape=image.shape, density=density, n_coils=1
) # create the NUFFT operator

# %%
# Reconstruct the image using the CG method
kspace_data = nufft.op(image) # get the k-space data
reconstructed_image = cg(nufft, kspace_data) # reconstruct the image

# %%
# Display the results
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(abs(image), cmap="gray")

plt.subplot(1, 3, 2)
plt.title("Reconstructed Image with CG")
plt.imshow(abs(reconstructed_image), cmap="gray")

plt.subplot(1, 3, 3)
plt.title("Reconstructed Image with adjoint")
plt.imshow(abs(nufft.adj_op(kspace_data)), cmap="gray")

plt.show()
62 changes: 62 additions & 0 deletions src/mrinufft/extras/gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Conjugate gradient optimization algorithm for image reconstruction."""

import numpy as np

from mrinufft.operators.base import with_numpy


@with_numpy
def cg(operator, kspace_data, x_init=None, num_iter=10, tol=1e-4):
"""
Perform conjugate gradient (CG) optimization for image reconstruction.
The image is updated using the gradient of a data consistency term,
and a velocity vector is used to accelerate convergence.
Parameters
----------
kspace_data : numpy.ndarray
The k-space data to be used for image reconstruction.
x_init : numpy.ndarray, optional
An initial guess for the image. If None, an image of zeros with the same
shape as the expected output is used. Default is None.
num_iter : int, optional
The maximum number of iterations to perform. Default is 10.
tol : float, optional
The tolerance for convergence. If the norm of the gradient falls below
this value or the dot product between the image and k-space data is
non-positive, the iterations stop. Default is 1e-4.
Returns
-------
image : numpy.ndarray
The reconstructed image after the optimization process.
"""
Lipschitz_cst = operator.get_lipschitz_cst()
image = (
np.zeros(operator.shape, dtype=type(kspace_data[0]))
if x_init is None
else x_init
)
velocity = np.zeros_like(image)

grad = operator.data_consistency(image, kspace_data)
velocity = tol * velocity + grad / Lipschitz_cst
image = image - velocity

for _ in range(num_iter):
grad_new = operator.data_consistency(image, kspace_data)
if np.linalg.norm(grad_new) <= tol:
break

beta = np.dot(grad_new.flatten(), grad_new.flatten()) / np.dot(
grad.flatten(), grad.flatten()
)
velocity = grad_new + beta * velocity

image = image - velocity / Lipschitz_cst

return image
8 changes: 4 additions & 4 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def compute_smaps(self, method=None):
Note that this callable function should also hold the k-space data
(use funtools.partial)
"""
if isinstance(method, np.ndarray):
if is_host_array(method) or is_cuda_array(method):
self.smaps = method
return
if not method:
Expand Down Expand Up @@ -442,10 +442,10 @@ def compute_density(self, method=None):
or `torchkbnufft-gpu`.
"""
if isinstance(method, np.ndarray):
self.density = method
self._density = method
return None
if not method:
self.density = None
self._density = None
return None
if method is True:
method = "pipe"
Expand All @@ -466,7 +466,7 @@ def compute_density(self, method=None):
shape,
**kwargs,
)
self.density = method(self.samples, self.shape, **kwargs)
self._density = method(self.samples, self.shape, **kwargs)

def get_lipschitz_cst(self, max_iter=10, **kwargs):
"""Return the Lipschitz constant of the operator.
Expand Down
15 changes: 14 additions & 1 deletion src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def smaps(self, new_smaps):
"""
self._check_smaps_shape(new_smaps)
if new_smaps is not None and hasattr(self, "smaps_cached"):
if self.smaps_cached:
if self.smaps_cached or is_cuda_array(new_smaps):
self.smaps_cached = True
warnings.warn(
f"{sizeof_fmt(new_smaps.size * np.dtype(self.cpx_dtype).itemsize)}"
"used on gpu for smaps."
Expand Down Expand Up @@ -279,6 +280,18 @@ def samples(self, samples):
self.raw_op._set_pts(typ, samples)
self.compute_density(self._density_method)

@FourierOperatorBase.density.setter
def density(self, new_density):
"""Update the density compensation."""
if new_density is None:
self._density = None
return
xp = get_array_module(new_density)
if xp.__name__ == "numpy":
self._density = cp.array(new_density)
elif xp.__name__ == "cupy":
self._density = new_density

@with_numpy_cupy
@nvtx_mark()
def op(self, data, ksp_d=None):
Expand Down
15 changes: 15 additions & 0 deletions src/mrinufft/operators/interfaces/gpunufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,21 @@ def samples(self, samples):
density=self.density,
)

@FourierOperatorBase.density.setter
def density(self, density):
"""Set the density for the Fourier Operator.
Parameters
----------
density: np.ndarray
The density for the Fourier Operator.
"""
self._density = density
self.raw_op.set_pts(
self._samples,
density=density,
)

@classmethod
def pipe(
cls,
Expand Down
37 changes: 35 additions & 2 deletions tests/operators/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from_interface,
)
from mrinufft import get_operator
from mrinufft._utils import get_array_module
from case_trajectories import CasesTrajectories


Expand Down Expand Up @@ -128,7 +129,7 @@ def new_smaps(operator):


@param_array_interface
def test_op(
def test_op_samples(
operator,
array_interface,
image_data,
Expand All @@ -145,7 +146,7 @@ def test_op(


@param_array_interface
def test_adj_op(
def test_adj_op_samples(
operator,
array_interface,
kspace_data,
Expand All @@ -162,6 +163,38 @@ def test_adj_op(
npt.assert_allclose(image_changed, image_true, atol=1e-3, rtol=1e-3)


@param_array_interface
def test_adj_op_density(
operator,
array_interface,
kspace_data,
):
"""Test the batch type 1 (adjoint)."""
kspace_data = to_interface(kspace_data, array_interface)
jitter = np.random.rand(operator.samples.shape[0]).astype(np.float32)
# Add very little noise to the trajectory, variance of 1e-3
if operator.uses_density:
# Test density can be updated
xp = get_array_module(operator.density)

operator.density += xp.array(jitter) / 100
else:
# Test that operator can handle density added later
operator.density = 1e-2 + jitter / 100
new_operator = update_operator(operator)
image_changed = from_interface(operator.adj_op(kspace_data), array_interface)
image_true = from_interface(new_operator.adj_op(kspace_data), array_interface)
# Reduced accuracy for the GPU cases...
npt.assert_allclose(image_changed, image_true, atol=1e-3, rtol=1e-3)
if operator.uses_density:
# Check if the operator can handle removing density compensation
operator.density = None
new_operator = update_operator(operator)
image_changed = from_interface(operator.adj_op(kspace_data), array_interface)
image_true = from_interface(new_operator.adj_op(kspace_data), array_interface)
npt.assert_allclose(image_changed, image_true, atol=1e-3, rtol=1e-3)


@param_array_interface
def test_op_smaps_update(
operator,
Expand Down
72 changes: 72 additions & 0 deletions tests/test_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Test for the cg function."""

import numpy as np
import pytest
from pytest_cases import parametrize_with_cases, parametrize, fixture
from mrinufft.extras.gradient import cg
from mrinufft import get_operator
from case_trajectories import CasesTrajectories

from helpers import (
image_from_op,
param_array_interface,
)
from helpers import assert_almost_allclose


@fixture(scope="module")
@parametrize(
"backend",
[
"bart",
"finufft",
"cufinufft",
"gpunufft",
"sigpy",
"torchkbnufft-cpu",
"torchkbnufft-gpu",
"tensorflow",
],
)
@parametrize_with_cases(
"kspace_locs, shape",
cases=[
CasesTrajectories.case_random2D,
CasesTrajectories.case_grid2D,
CasesTrajectories.case_grid3D,
],
)
def operator(
request,
backend="pynfft",
kspace_locs=None,
shape=None,
n_coils=1,
):
"""Generate an operator."""
if backend in ["pynfft", "sigpy"] and kspace_locs.shape[-1] == 3:
pytest.skip("3D for slow cpu is not tested")
return get_operator(backend)(kspace_locs, shape, n_coils=n_coils, smaps=None)


@fixture(scope="module")
def image_data(operator):
"""Generate a random image. Remains constant for the module."""
return image_from_op(operator)


@param_array_interface
def test_cg(operator, array_interface, image_data):
"""Compare the interface to the raw NUDFT implementation."""
kspace_nufft = operator.op(image_data).squeeze()

image_cg = cg(operator, kspace_nufft)
kspace_cg = operator.op(image_cg).squeeze()

assert_almost_allclose(
kspace_cg,
kspace_nufft,
atol=2e-1,
rtol=1e-1,
mismatch=20,
)

0 comments on commit 165043f

Please sign in to comment.