Commit

wip: testing lstsq

mcencini committed Apr 10, 2024
1 parent 8c74940 commit d660851
Showing 12 changed files with 425 additions and 241 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
@@ -49,9 +49,11 @@ python_requires = >=3.9
# For more information, check out https://semver.org/.
install_requires =
importlib-metadata; python_version<"3.9"
click
dacite
h5py
ismrmrd
lightning
mat73
matplotlib
nibabel
3 changes: 3 additions & 0 deletions src/deepmr/optim/__init__.py
@@ -9,14 +9,17 @@
from . import cg as _cg
from . import admm as _admm
from . import pgd as _pgd
from . import lstsq as _lstsq

from .spectral import * # noqa
from .cg import * # noqa
from .admm import * # noqa
from .pgd import * # noqa
from .lstsq import * # noqa

__all__ = []
__all__.extend(_spectral.__all__)
__all__.extend(_cg.__all__)
__all__.extend(_admm.__all__)
__all__.extend(_pgd.__all__)
__all__.extend(_lstsq.__all__)
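For context, the aggregation above re-exports every name listed in a submodule's ``__all__`` at the ``deepmr.optim`` package level, so whatever the new ``lstsq`` module defines becomes importable directly from the package. A minimal illustration, using a hypothetical name because the actual exports of ``lstsq.py`` are not shown in this diff:

# "lstsq_solve" is an assumed placeholder name, for illustration only.
from deepmr.optim import lstsq_solve        # re-exported via lstsq.__all__
from deepmr.optim.lstsq import lstsq_solve  # equivalent explicit import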
89 changes: 67 additions & 22 deletions src/deepmr/optim/admm.py
@@ -2,6 +2,8 @@

__all__ = ["admm_solve", "ADMMStep"]

import time

import numpy as np
import torch

@@ -18,12 +20,12 @@ def admm_solve(
step,
AHA,
D,
P=None,
niter=10,
device=None,
dc_niter=10,
dc_tol=1e-4,
dc_ndim=None,
tol=None,
save_history=False,
verbose=False,
):
@@ -41,9 +43,6 @@
Normal operator AHA = AH * A.
D : Callable
Signal denoiser for plug-n-play restoration.
P : Callable, optional
Polynomial preconditioner for data consistency.
The default is ``None`` (standard CG for data consistency).
niter : int, optional
Number of iterations. The default is ``10``.
device : str, optional
@@ -59,6 +58,9 @@
Number of spatial dimensions of the problem for inner data consistency step.
It is used to infer the batch axes. If ``AHA`` is a ``deepmr.linop.Linop``
operator, this is inferred from ``AHA.ndim`` and ``ndim`` is ignored.
tol : float, optional
Stopping condition for the outer ADMM loop. If not provided, run until ``niter``.
The default is ``None``.
save_history : bool, optional
Record cost function. The default is ``False``.
verbose : bool, optional
@@ -76,6 +78,10 @@
input = torch.as_tensor(input)
else:
isnumpy = False

# assert inputs are correct
if verbose:
assert save_history is True, "We need to record history to print information."

# keep original device
idevice = input.device
@@ -94,15 +100,45 @@
AHy = input.clone()

# initialize algorithm
ADMM = ADMMStep(step, AHA, AHy, D, P, niter=dc_niter, tol=dc_tol, ndim=dc_ndim)
ADMM = ADMMStep(step, AHA, AHy, D, niter=dc_niter, tol=dc_tol, atol=tol, ndim=dc_ndim)

# initialize
input = 0 * input
history = []

# start timer
if verbose:
t0 = time.time()
nprint = np.linspace(0, niter, 5)
nprint = nprint.astype(int).tolist()
print("================================= ADMM =====================================")
print("| nsteps | n_AHA | data consistency | regularization | total cost | t-t0 [s]")
print("============================================================================")

# run algorithm
for n in range(niter):
output = ADMM(input)

# check if we reached convergence (only when tol is provided)
if ADMM.check_convergence(output, input, step):
break

# update variable
input = output.clone()

# if required, save history
if save_history:
r = output - AHy
dc = 0.5 * torch.linalg.norm(r).item() ** 2
reg = np.sum([d.g(output) for d in D])
history.append(dc + reg)
if verbose and n in nprint:
t = time.time()
print(" {:>6} | {:>5} | {:>16.4f} | {:>14.4f} | {:>10.4f} | {:>8.2f}".format(n, n * dc_niter, dc, reg, dc + reg, t - t0))

if verbose:
t1 = time.time()
print(f"Exiting ADMM: total elapsed time: {round(t1-t0, 2)} [s]")

# back to original device
output = output.to(device)
@@ -111,7 +147,7 @@
if isnumpy:
output = output.numpy(force=True)

return output
return output, history


class ADMMStep(nn.Module):
@@ -140,9 +176,12 @@ class ADMMStep(nn.Module):
If ``True``, gradient update step is trainable, otherwise it is not.
The default is ``False``.
niter : int, optional
Number of iterations of inner data consistency step.
Number of iterations of inner data consistency step. The default is ``10``.
tol : float, optional
Stopping condition for inner data consistency step.
Stopping condition for inner data consistency step. The default is ``1e-4``.
atol : float, optional
Stopping condition for ADMM. If not provided, run until ``niter``.
The default is ``None``.
ndim : int, optional
Number of spatial dimensions of the problem for inner data consistency step.
It is used to infer the batch axes. If ``AHA`` is a ``deepmr.linop.Linop``
@@ -151,7 +190,7 @@
"""

def __init__(
self, step, AHA, AHy, D, P, trainable=False, niter=10, tol=1e-4, ndim=None
self, step, AHA, AHy, D, trainable=False, niter=10, tol=1e-4, atol=None, ndim=None
):
super().__init__()
if trainable:
@@ -168,7 +207,6 @@ def __init__(
# assign operators
self.AHA = AHA
self.AHy = AHy
self.P = P

# assign denoisers
if hasattr(D, "__iter__"):
@@ -187,21 +225,18 @@ def __init__(
# dc solver settings
self.niter = niter
self.tol = tol
self.atol = atol

def forward(self, input):
# data consistency step: zk = (AHA + gamma * I).solve(AHy)
if self.P is None:
self.xi[0] = cg_solve(
self.AHy + self.step * (input - self.ui[0]),
self.AHA,
niter=self.niter,
tol=self.tol,
lamda=self.step,
ndim=self.ndim,
)
else:
v = self.step * (input - self.ui[0])
self.xi[0] = v - self.P * (self.AHA(v) - (v + self.AHy))
self.xi[0], _ = cg_solve(
self.AHy + self.step * (input - self.ui[0]),
self.AHA,
niter=self.niter,
tol=self.tol,
lamda=self.step,
ndim=self.ndim,
)

# denoise using each regularizator
for n in range(len(self.D)):
@@ -212,3 +247,13 @@ def forward(self, input):
self.ui += self.xi - output[None, ...]

return output

def check_convergence(self, output, input, step):
if self.atol is None:
return False
resid = torch.linalg.norm(output - input).item() / step
return resid < self.atol
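For reference, the public interface of ``admm_solve`` changes in this commit: the polynomial preconditioner argument ``P`` is removed (the inner data consistency step now always runs ``cg_solve``), a new outer stopping tolerance ``tol`` is added, ``verbose=True`` requires ``save_history=True``, and the function returns an ``(output, history)`` tuple instead of a bare array. The stopping rule in ``check_convergence`` is ||x_{k+1} - x_k||_2 / step < atol; with ``atol=None`` the solver always runs ``niter`` outer iterations. A rough calling sketch under these assumptions (the leading positional argument is not visible in this hunk, but the body treats it as AH applied to the measured data, and the denoiser interface is defined elsewhere in the package):

# Sketch only: the "..." placeholders stand in for problem-specific operators and denoisers
# whose construction is not part of this diff.
from deepmr.optim import admm_solve

AHy = ...   # adjoint of the forward model applied to the measured data
AHA = ...   # normal operator (callable or deepmr.linop.Linop)
D = [...]   # list of plug-and-play denoisers, each also exposing .g(x) for the cost

xhat, history = admm_solve(
    AHy,
    step=0.5,            # ADMM step size (illustrative value)
    AHA=AHA,
    D=D,
    niter=30,
    tol=1e-5,            # new: outer ADMM stopping criterion
    save_history=True,   # required when verbose=True
    verbose=True,
)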
40 changes: 37 additions & 3 deletions src/deepmr/optim/cg.py
@@ -2,6 +2,8 @@

__all__ = ["cg_solve", "CGStep"]

import time

import numpy as np
import torch

@@ -62,7 +64,11 @@ def cg_solve(
input = torch.as_tensor(input)
else:
isnumpy = False


# assert inputs are correct
if verbose:
assert save_history is True, "We need to record history to print information."

# keep original device
idevice = input.device
if device is None:
@@ -95,13 +101,41 @@

# initialize
input = 0 * input

history = []

# start timer
if verbose:
t0 = time.time()
nprint = np.linspace(0, niter, 5)
nprint = nprint.astype(int).tolist()
print("====================== Conjugate Gradient ==========================")
print("| nsteps | data consistency | regularization | total cost | t-t0 [s]")
print("====================================================================")

# run algorithm
for n in range(niter):
output = CG(input)

# if required, compute residual and check if we reached convergence
if CG.check_convergence():
break

# update variable
input = output.clone()

# if required, save history
if save_history:
r = output - AHy
dc = 0.5 * torch.linalg.norm(r).item() ** 2
reg = lamda * torch.linalg.norm(output).item() ** 2
history.append(dc+reg)
if verbose and n in nprint:
t = time.time()
print(" {}{:.4f}{:.4f}{:.4f}{:.2f}".format(n, dc, reg, dc+reg, t-t0))

if verbose:
t1 = time.time()
print(f"Exiting Conjugate Gradient: total elapsed time: {round(t1-t0, 2)} [s]")

# back to original device
output = output.to(device)
@@ -110,7 +144,7 @@
if isnumpy:
output = output.numpy(force=True)

return output
return output, history


class CGStep(nn.Module):
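The same interface change applies to ``cg_solve``, which now returns an ``(output, history)`` tuple (see the ``self.xi[0], _ = cg_solve(...)`` call in admm.py above), so existing callers must unpack the result. A rough usage sketch on a toy problem, assuming the keyword names visible in this diff (``niter``, ``tol``, ``lamda``, ``ndim``, ``save_history``, ``verbose``); defaults and details outside the diff may differ:

# Sketch only; the exact cg_solve signature is not fully shown in this commit.
import torch
from deepmr.optim import cg_solve

A = torch.tensor([[2.0, 0.5], [0.5, 1.5]])  # small symmetric positive-definite test operator
b = torch.tensor([1.0, -1.0])               # right-hand side, i.e. AH applied to the data

def AHA(x):
    # normal operator passed as a plain callable
    return A @ x

xhat, history = cg_solve(b, AHA, niter=50, tol=1e-6, ndim=1, save_history=True)
print(xhat)           # approximate solution of AHA(x) = b
print(len(history))   # one recorded cost value per completed iteration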