-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into display_dens
- Loading branch information
Showing
7 changed files
with
270 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# %% | ||
""" | ||
Example of using the Conjugate Gradient method. | ||
This script demonstrates the use of the Conjugate Gradient (CG) method | ||
for solving systems of linear equations of the form Ax = b, where A is a symmetric | ||
positive-definite matrix. The CG method is an iterative algorithm that is particularly | ||
useful for large, sparse systems where direct methods are computationally expensive. | ||
The Conjugate Gradient method is widely used in various scientific and engineering | ||
applications, including solving partial differential equations, optimization problems, | ||
and machine learning tasks. | ||
References | ||
---------- | ||
- Inpirations: | ||
- https://sigpy.readthedocs.io/en/latest/_modules/sigpy/alg.html#ConjugateGradient | ||
- https://aquaulb.github.io/book_solving_pde_mooc/solving_pde_mooc/notebooks/05_IterativeMethods/05_02_Conjugate_Gradient.html | ||
- Wikipedia: | ||
- https://en.wikipedia.org/wiki/Conjugate_gradient_method | ||
- https://en.wikipedia.org/wiki/Momentum | ||
""" | ||
|
||
# %% | ||
# Imports | ||
import numpy as np | ||
import mrinufft | ||
from brainweb_dl import get_mri | ||
from mrinufft.extras.gradient import cg | ||
from mrinufft.density import voronoi | ||
from matplotlib import pyplot as plt | ||
|
||
# %% | ||
# Setup Inputs | ||
samples_loc = mrinufft.initialize_2D_spiral(Nc=64, Ns=256) | ||
image = get_mri(sub_id=4) | ||
image = np.flipud(image[90]) | ||
|
||
# %% | ||
# Setup the NUFFT operator | ||
NufftOperator = mrinufft.get_operator("gpunufft") # get the operator | ||
density = voronoi(samples_loc) # get the density | ||
|
||
nufft = NufftOperator( | ||
samples_loc, shape=image.shape, density=density, n_coils=1 | ||
) # create the NUFFT operator | ||
|
||
# %% | ||
# Reconstruct the image using the CG method | ||
kspace_data = nufft.op(image) # get the k-space data | ||
reconstructed_image = cg(nufft, kspace_data) # reconstruct the image | ||
|
||
# %% | ||
# Display the results | ||
plt.figure(figsize=(10, 5)) | ||
plt.subplot(1, 3, 1) | ||
plt.title("Original Image") | ||
plt.imshow(abs(image), cmap="gray") | ||
|
||
plt.subplot(1, 3, 2) | ||
plt.title("Reconstructed Image with CG") | ||
plt.imshow(abs(reconstructed_image), cmap="gray") | ||
|
||
plt.subplot(1, 3, 3) | ||
plt.title("Reconstructed Image with adjoint") | ||
plt.imshow(abs(nufft.adj_op(kspace_data)), cmap="gray") | ||
|
||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
"""Conjugate gradient optimization algorithm for image reconstruction.""" | ||
|
||
import numpy as np | ||
|
||
from mrinufft.operators.base import with_numpy | ||
|
||
|
||
@with_numpy | ||
def cg(operator, kspace_data, x_init=None, num_iter=10, tol=1e-4): | ||
""" | ||
Perform conjugate gradient (CG) optimization for image reconstruction. | ||
The image is updated using the gradient of a data consistency term, | ||
and a velocity vector is used to accelerate convergence. | ||
Parameters | ||
---------- | ||
kspace_data : numpy.ndarray | ||
The k-space data to be used for image reconstruction. | ||
x_init : numpy.ndarray, optional | ||
An initial guess for the image. If None, an image of zeros with the same | ||
shape as the expected output is used. Default is None. | ||
num_iter : int, optional | ||
The maximum number of iterations to perform. Default is 10. | ||
tol : float, optional | ||
The tolerance for convergence. If the norm of the gradient falls below | ||
this value or the dot product between the image and k-space data is | ||
non-positive, the iterations stop. Default is 1e-4. | ||
Returns | ||
------- | ||
image : numpy.ndarray | ||
The reconstructed image after the optimization process. | ||
""" | ||
Lipschitz_cst = operator.get_lipschitz_cst() | ||
image = ( | ||
np.zeros(operator.shape, dtype=type(kspace_data[0])) | ||
if x_init is None | ||
else x_init | ||
) | ||
velocity = np.zeros_like(image) | ||
|
||
grad = operator.data_consistency(image, kspace_data) | ||
velocity = tol * velocity + grad / Lipschitz_cst | ||
image = image - velocity | ||
|
||
for _ in range(num_iter): | ||
grad_new = operator.data_consistency(image, kspace_data) | ||
if np.linalg.norm(grad_new) <= tol: | ||
break | ||
|
||
beta = np.dot(grad_new.flatten(), grad_new.flatten()) / np.dot( | ||
grad.flatten(), grad.flatten() | ||
) | ||
velocity = grad_new + beta * velocity | ||
|
||
image = image - velocity / Lipschitz_cst | ||
|
||
return image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
"""Test for the cg function.""" | ||
|
||
import numpy as np | ||
import pytest | ||
from pytest_cases import parametrize_with_cases, parametrize, fixture | ||
from mrinufft.extras.gradient import cg | ||
from mrinufft import get_operator | ||
from case_trajectories import CasesTrajectories | ||
|
||
from helpers import ( | ||
image_from_op, | ||
param_array_interface, | ||
) | ||
from helpers import assert_almost_allclose | ||
|
||
|
||
@fixture(scope="module") | ||
@parametrize( | ||
"backend", | ||
[ | ||
"bart", | ||
"finufft", | ||
"cufinufft", | ||
"gpunufft", | ||
"sigpy", | ||
"torchkbnufft-cpu", | ||
"torchkbnufft-gpu", | ||
"tensorflow", | ||
], | ||
) | ||
@parametrize_with_cases( | ||
"kspace_locs, shape", | ||
cases=[ | ||
CasesTrajectories.case_random2D, | ||
CasesTrajectories.case_grid2D, | ||
CasesTrajectories.case_grid3D, | ||
], | ||
) | ||
def operator( | ||
request, | ||
backend="pynfft", | ||
kspace_locs=None, | ||
shape=None, | ||
n_coils=1, | ||
): | ||
"""Generate an operator.""" | ||
if backend in ["pynfft", "sigpy"] and kspace_locs.shape[-1] == 3: | ||
pytest.skip("3D for slow cpu is not tested") | ||
return get_operator(backend)(kspace_locs, shape, n_coils=n_coils, smaps=None) | ||
|
||
|
||
@fixture(scope="module") | ||
def image_data(operator): | ||
"""Generate a random image. Remains constant for the module.""" | ||
return image_from_op(operator) | ||
|
||
|
||
@param_array_interface | ||
def test_cg(operator, array_interface, image_data): | ||
"""Compare the interface to the raw NUDFT implementation.""" | ||
kspace_nufft = operator.op(image_data).squeeze() | ||
|
||
image_cg = cg(operator, kspace_nufft) | ||
kspace_cg = operator.op(image_cg).squeeze() | ||
|
||
assert_almost_allclose( | ||
kspace_cg, | ||
kspace_nufft, | ||
atol=2e-1, | ||
rtol=1e-1, | ||
mismatch=20, | ||
) |