+"""Total generalized variation denoising prior."""
+
+__all__ = ["TGVDenoiser", "tgv_denoise"]
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+[docs]class TGVDenoiser(nn.Module):
+
r"""
+
Proximal operator of (2nd order) Total Generalised Variation operator.
+
+
(see K. Bredies, K. Kunisch, and T. Pock, "Total generalized variation," SIAM J. Imaging Sci., 3(3), 492-526, 2010.)
+
+
This algorithm converges to the unique image :math:`x` (and the auxiliary vector field :math:`r`) minimizing
+
+
.. math::
+
+
\underset{x, r}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \lambda_1 \|r\|_{1,2} + \lambda_2 \|J(Dx-r)\|_{1,F}
+
+
where :math:`D` maps an image to its gradient field and :math:`J` maps a vector field to its Jacobian.
+
For a large value of :math:`\lambda_2`, the TGV behaves like the TV.
+
For a small value, it behaves like the :math:`\ell_1`-Frobenius norm of the Hessian.
+
+
The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method
+
for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and
+
Applications, vol. 158, no. 2, pp. 460-479, 2013.
+
+
Code (and description) adapted from Laurent Condat's matlab version (https://lcondat.github.io/software.html) and
+
Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_.
+
+
Attributes
+
----------
+
ndim : int
+
Number of spatial dimensions, can be either ``2`` or ``3``.
+
ths : float, optional
+
Denoise threshold. The default is ``0.1``.
+
trainable : bool, optional
+
If ``True``, threshold value is trainable, otherwise it is not.
+
The default is ``False``.
+
device : str, optional
+
Device on which the wavelet transform is computed.
+
The default is ``None`` (infer from input).
+
verbose : bool, optional
+
Whether to print computation details or not. The default is ``False``.
+
niter : int, optional,
+
Maximum number of iterations. The default is ``1000``.
+
crit : float, optional
+
Convergence criterion. The default is 1e-5.
+
x2 : torch.Tensor, optional
+
Primary variable for warm restart. The default is ``None``.
+
u2 : torch.Tensor, optional
+
Dual variable for warm restart. The default is ``None``.
+
r2 : torch.Tensor, optional
+
Auxiliary variable for warm restart. The default is ``None``.
+
+
Notes
+
-----
+
The regularization term :math:`\|r\|_{1,2} + \|J(Dx-r)\|_{1,F}` is implicitly normalized by its Lipschitz
+
constant, i.e. :math:`\sqrt{72}`, see e.g. K. Bredies et al., "Total generalized variation," SIAM J. Imaging
+
Sci., 3(3), 492-526, 2010.
+
+
"""
+
+
[docs] def __init__(
+
self,
+
ndim,
+
ths=0.1,
+
trainable=False,
+
device=None,
+
verbose=False,
+
niter=100,
+
crit=1e-5,
+
x2=None,
+
u2=None,
+
r2=None,
+
):
+
super().__init__()
+
+
if trainable:
+
self.ths = nn.Parameter(ths)
+
else:
+
self.ths = ths
+
+
self.denoiser = _TGVDenoiser(
+
ndim=ndim,
+
device=device,
+
verbose=verbose,
+
n_it_max=niter,
+
crit=crit,
+
x2=x2,
+
u2=u2,
+
r2=r2,
+
)
+
self.denoiser.device = device
+
+
def forward(self, input):
+
# get complex
+
if torch.is_complex(input):
+
iscomplex = True
+
else:
+
iscomplex = False
+
+
# default device
+
idevice = input.device
+
if self.denoiser.device is None:
+
device = idevice
+
else:
+
device = self.denoiser.device
+
+
# get input shape
+
ndim = self.denoiser.ndim
+
ishape = input.shape
+
+
# reshape for computation
+
input = input.reshape(-1, *ishape[-ndim:])
+
if iscomplex:
+
input = torch.stack((input.real, input.imag), axis=1)
+
input = input.reshape(-1, *ishape[-ndim:])
+
+
# apply denoising
+
output = self.denoiser(input[:, None, ...].to(device), self.ths).to(
+
idevice
+
) # perform the denoising on the real-valued tensor
+
+
# reshape back
+
if iscomplex:
+
output = (
+
output[::2, ...] + 1j * output[1::2, ...]
+
) # build the denoised complex data
+
output = output.reshape(ishape)
+
+
return output.to(idevice)
+
+
+[docs]def tgv_denoise(
+
input,
+
ndim,
+
ths=0.1,
+
device=None,
+
verbose=False,
+
niter=100,
+
crit=1e-5,
+
x2=None,
+
u2=None,
+
):
+
r"""
+
Apply Total Generalized Variation denoising.
+
+
(see K. Bredies, K. Kunisch, and T. Pock, "Total generalized variation," SIAM J. Imaging Sci., 3(3), 492-526, 2010.)
+
+
This algorithm converges to the unique image :math:`x` (and the auxiliary vector field :math:`r`) minimizing
+
+
.. math::
+
+
\underset{x, r}{\arg\min} \; \frac{1}{2}\|x-y\|_2^2 + \lambda_1 \|r\|_{1,2} + \lambda_2 \|J(Dx-r)\|_{1,F}
+
+
where :math:`D` maps an image to its gradient field and :math:`J` maps a vector field to its Jacobian.
+
For a large value of :math:`\lambda_2`, the TGV behaves like the TV.
+
For a small value, it behaves like the :math:`\ell_1`-Frobenius norm of the Hessian.
+
+
The problem is solved with an over-relaxed Chambolle-Pock algorithm (see L. Condat, "A primal-dual splitting method
+
for convex optimization involving Lipschitzian, proximable and linear composite terms", J. Optimization Theory and
+
Applications, vol. 158, no. 2, pp. 460-479, 2013.
+
+
Code (and description) adapted from Laurent Condat's matlab version (https://lcondat.github.io/software.html) and
+
Daniil Smolyakov's `code <https://github.com/RoundedGlint585/TGVDenoising/blob/master/TGV%20WithoutHist.ipynb>`_.
+
+
Arguments
+
---------
+
input : np.ndarray | torch.Tensor
+
Input image of shape (..., n_ndim, ..., n_0).
+
ndim : int
+
Number of spatial dimensions, can be either ``2`` or ``3``.
+
ths : float, optional
+
Denoise threshold. Default is ``0.1``.
+
ndim : int
+
Number of spatial dimensions, can be either ``2`` or ``3``.
+
ths : float, optional
+
Denoise threshold. The default is ``0.1``.
+
trainable : bool, optional
+
If ``True``, threshold value is trainable, otherwise it is not.
+
The default is ``False``.
+
device : str, optional
+
Device on which the wavelet transform is computed.
+
The default is ``None`` (infer from input).
+
verbose : bool, optional
+
Whether to print computation details or not. The default is ``False``.
+
niter : int, optional,
+
Maximum number of iterations. The default is ``1000``.
+
crit : float, optional
+
Convergence criterion. The default is 1e-5.
+
x2 : torch.Tensor, optional
+
Primary variable for warm restart. The default is ``None``.
+
u2 : torch.Tensor, optional
+
Dual variable for warm restart. The default is ``None``.
+
r2 : torch.Tensor, optional
+
Auxiliary variable for warm restart. The default is ``None``.
+
+
Returns
+
-------
+
output : np.ndarray | torch.Tensor
+
Denoised image of shape (..., n_ndim, ..., n_0).
+
+
"""
+
# cast to numpy if required
+
if isinstance(input, np.ndarray):
+
isnumpy = True
+
input = torch.as_tensor(input)
+
else:
+
isnumpy = False
+
+
# initialize denoiser
+
TV = TGVDenoiser(ndim, ths, False, device, verbose, niter, crit, x2, u2)
+
output = TV(input)
+
+
# cast back to numpy if requried
+
if isnumpy:
+
output = output.numpy(force=True)
+
+
return output
+
+
+# %% local utils
+class _TGVDenoiser(nn.Module):
+ def __init__(
+ self,
+ ndim,
+ device,
+ verbose=False,
+ n_it_max=1000,
+ crit=1e-5,
+ x2=None,
+ u2=None,
+ r2=None,
+ ):
+ super().__init__()
+ self.device = device
+ self.ndim = ndim
+
+ if ndim == 2:
+ self.nabla = self.nabla2
+ self.nabla_adjoint = self.nabla2_adjoint
+ self.epsilon = self.epsilon2
+ self.epsilon_adjoint = self.epsilon2_adjoint
+ elif ndim == 3:
+ self.nabla = self.nabla3
+ self.nabla_adjoint = self.nabla3_adjoint
+ self.epsilon = self.epsilon3
+ self.epsilon_adjoint = self.epsilon3_adjoint
+
+ self.verbose = verbose
+ self.n_it_max = n_it_max
+ self.crit = crit
+ self.restart = True
+
+ self.tau = 0.01 # > 0
+
+ self.rho = 1.99 # in 1,2
+ self.sigma = 1 / self.tau / 72
+
+ self.x2 = x2
+ self.r2 = r2
+ self.u2 = u2
+
+ self.has_converged = False
+
+ def prox_tau_fx(self, x, y):
+ return (x + self.tau * y) / (1 + self.tau)
+
+ def prox_tau_fr(self, r, lambda1):
+ left = torch.sqrt(torch.sum(r**2, axis=-1)) / (self.tau * lambda1)
+ tmp = r - r / (
+ torch.maximum(
+ left, torch.tensor([1], device=left.device).type(left.dtype)
+ ).unsqueeze(-1)
+ )
+ return tmp
+
+ def prox_sigma_g_conj(self, u, lambda2):
+ return u / (
+ torch.maximum(
+ torch.sqrt(torch.sum(u**2, axis=-1)) / lambda2,
+ torch.tensor([1], device=u.device).type(u.dtype),
+ ).unsqueeze(-1)
+ )
+
+ def forward(self, y, ths=None):
+ restart = (
+ True
+ if (self.restart or self.x2 is None or self.x2.shape != y.shape)
+ else False
+ )
+
+ if restart:
+ self.x2 = y.clone()
+ self.r2 = torch.zeros((*self.x2.shape, 2), device=self.x2.device).type(
+ self.x2.dtype
+ )
+ self.u2 = torch.zeros((*self.x2.shape, 4), device=self.x2.device).type(
+ self.x2.dtype
+ )
+ self.restart = False
+
+ if ths is not None:
+ lambda1 = ths * 0.1
+ lambda2 = ths * 0.15
+
+ cy = (y**2).sum() / 2
+ primalcostlowerbound = 0
+
+ for _ in range(self.n_it_max):
+ x_prev = self.x2.clone()
+ tmp = self.tau * self.epsilon_adjoint(self.u2)
+ x = self.prox_tau_fx(self.x2 - self.nabla_adjoint(tmp), y)
+ r = self.prox_tau_fr(self.r2 + tmp, lambda1)
+ u = self.prox_sigma_g_conj(
+ self.u2
+ + self.sigma
+ * self.epsilon(self.nabla(2 * x - self.x2) - (2 * r - self.r2)),
+ lambda2,
+ )
+ self.x2 = self.x2 + self.rho * (x - self.x2)
+ self.r2 = self.r2 + self.rho * (r - self.r2)
+ self.u2 = self.u2 + self.rho * (u - self.u2)
+
+ rel_err = torch.linalg.norm(
+ x_prev.flatten() - self.x2.flatten()
+ ) / torch.linalg.norm(self.x2.flatten() + 1e-12)
+
+ if _ > 1 and rel_err < self.crit:
+ self.has_converged = True
+ if self.verbose:
+ print("TGV prox reached convergence")
+ break
+
+ if self.verbose and _ % 100 == 0:
+ primalcost = (
+ torch.linalg.norm(x.flatten() - y.flatten()) ** 2
+ + lambda1 * torch.sum(torch.sqrt(torch.sum(r**2, axis=-1)))
+ + lambda2
+ * torch.sum(
+ torch.sqrt(
+ torch.sum(self.epsilon(self.nabla(x) - r) ** 2, axis=-1)
+ )
+ )
+ )
+ # dualcost = cy - ((y - nablaT(epsilonT(u))) ** 2).sum() / 2.0
+ tmp = torch.max(
+ torch.sqrt(torch.sum(self.epsilon_adjoint(u) ** 2, axis=-1))
+ ) # to check feasibility: the value will be <= lambda1 only at convergence. Since u is not feasible, the dual cost is not reliable: the gap=primalcost-dualcost can be <0 and cannot be used as stopping criterion.
+ u3 = u / torch.maximum(
+ tmp / lambda1, torch.tensor([1], device=tmp.device).type(tmp.dtype)
+ ) # u3 is a scaled version of u, which is feasible. so, its dual cost is a valid, but very rough lower bound of the primal cost.
+ dualcost2 = (
+ cy
+ - torch.sum((y - self.nabla_adjoint(self.epsilon_adjoint(u3))) ** 2)
+ / 2.0
+ ) # we display the best value of dualcost2 computed so far.
+ primalcostlowerbound = max(primalcostlowerbound, dualcost2.item())
+ if self.verbose:
+ print(
+ "Iter: ",
+ _,
+ " Primal cost: ",
+ primalcost.item(),
+ " Rel err:",
+ rel_err,
+ )
+
+ if _ == self.n_it_max - 1:
+ if self.verbose:
+ print(
+ "The algorithm did not converge, stopped after "
+ + str(_ + 1)
+ + " iterations."
+ )
+
+ return self.x2
+
+ @staticmethod
+ def nabla2(x):
+ r"""
+ Applies the finite differences operator associated with tensors of the same shape as x.
+ """
+ b, c, h, w = x.shape
+ u = torch.zeros((b, c, h, w, 2), device=x.device).type(x.dtype)
+ u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] - x[:, :, :-1]
+ u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] + x[:, :, 1:]
+ u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] - x[..., :-1]
+ u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] + x[..., 1:]
+ return u
+
+ @staticmethod
+ def nabla2_adjoint(x):
+ r"""
+ Applies the adjoint of the finite difference operator.
+ """
+ b, c, h, w = x.shape[:-1]
+ u = torch.zeros((b, c, h, w), device=x.device).type(
+ x.dtype
+ ) # note that we just reversed left and right sides of each line to obtain the transposed operator
+ u[:, :, :-1] = u[:, :, :-1] - x[:, :, :-1, :, 0]
+ u[:, :, 1:] = u[:, :, 1:] + x[:, :, :-1, :, 0]
+ u[..., :-1] = u[..., :-1] - x[..., :-1, 1]
+ u[..., 1:] = u[..., 1:] + x[..., :-1, 1]
+ return u
+
+ @staticmethod
+ def nabla3(x):
+ r"""
+ Applies the finite differences operator associated with tensors of the same shape as x.
+ """
+ b, c, d, h, w = x.shape
+ u = torch.zeros((b, c, d, h, w, 3), device=x.device).type(x.dtype)
+ u[:, :, :-1, :, :, 0] = u[:, :, :-1, :, :, 0] - x[:, :, :-1]
+ u[:, :, :-1, :, :, 0] = u[:, :, :-1, :, :, 0] + x[:, :, 1:]
+ u[:, :, :, :-1, :, 1] = u[:, :, :, :-1, :, 1] - x[:, :, :, :-1]
+ u[:, :, :, :-1, :, 1] = u[:, :, :, :-1, :, 1] + x[:, :, :, 1:]
+ u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] - x[:, :, :, :, :-1]
+ u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] + x[:, :, :, :, 1:]
+
+ return u
+
+ @staticmethod
+ def nabla3_adjoint(x):
+ r"""
+ Applies the adjoint of the finite difference operator.
+ """
+ b, c, d, h, w = x.shape
+ u = torch.zeros((b, c, d, h, w), device=x.device).type(x.dtype)
+ u[:, :, :-1, :, 0] = u[:, :, :-1, :, 0] - x[:, :, :-1]
+ u[:, :, 1:, :, 0] = u[:, :, 1:, :, 0] + x[:, :, :-1]
+ u[:, :, :, :-1, 1] = u[:, :, :, :-1, 1] - x[:, :, :, :-1]
+ u[:, :, :, 1:, 1] = u[:, :, :, 1:, 1] + x[:, :, :, :-1]
+ u[:, :, :, :, :-1, 2] = u[:, :, :, :, :-1, 2] - x[:, :, :, :, :-1]
+ u[:, :, :, :, 1:, 2] = u[:, :, :, :, 1:, 2] + x[:, :, :, :, :-1]
+
+ return u
+
+ @staticmethod
+ def epsilon2(I): # Simplified
+ r"""
+ Applies the jacobian of a vector field.
+ """
+ b, c, h, w, _ = I.shape
+ G = torch.zeros((b, c, h, w, 4), device=I.device).type(I.dtype)
+ G[:, :, 1:, :, 0] = G[:, :, 1:, :, 0] - I[:, :, :-1, :, 0] # xdy
+ G[..., 0] = G[..., 0] + I[..., 0]
+ G[..., 1:, 1] = G[..., 1:, 1] - I[..., :-1, 0] # xdx
+ G[..., 1:, 1] = G[..., 1:, 1] + I[..., 1:, 0]
+ G[..., 1:, 2] = G[..., 1:, 2] - I[..., :-1, 1] # xdx
+ G[..., 2] = G[..., 2] + I[..., 1]
+ G[:, :, :-1, :, 3] = G[:, :, :-1, :, 3] - I[:, :, :-1, :, 1] # xdy
+ G[:, :, :-1, :, 3] = G[:, :, :-1, :, 3] + I[:, :, 1:, :, 1]
+ return G
+
+ @staticmethod
+ def epsilon2_adjoint(G):
+ r"""
+ Applies the adjoint of the jacobian of a vector field.
+ """
+ b, c, h, w, _ = G.shape
+ I = torch.zeros((b, c, h, w, 2), device=G.device).type(G.dtype)
+ I[:, :, :-1, :, 0] = I[:, :, :-1, :, 0] - G[:, :, 1:, :, 0]
+ I[..., 0] = I[..., 0] + G[..., 0]
+ I[..., :-1, 0] = I[..., :-1, 0] - G[..., 1:, 1]
+ I[..., 1:, 0] = I[..., 1:, 0] + G[..., 1:, 1]
+ I[..., :-1, 1] = I[..., :-1, 1] - G[..., 1:, 2]
+ I[..., 1] = I[..., 1] + G[..., 2]
+ I[:, :, :-1, :, 1] = I[:, :, :-1, :, 1] - G[:, :, :-1, :, 3]
+ I[:, :, 1:, :, 1] = I[:, :, 1:, :, 1] + G[:, :, :-1, :, 3]
+ return I
+
+ @staticmethod
+ def epsilon3(I): # Adapted for 3D matrices
+ r"""
+ Applies the jacobian of a vector field.
+ """
+ b, c, d, h, w = I.shape
+ G = torch.zeros((b, c, d, h, w, 6), device=I.device).type(I.dtype)
+ G[:, :, :, 1:, :, 0] = G[:, :, :, 1:, :, 0] - I[:, :, :, :-1, :, 0] # xdy
+ G[..., 0] = G[..., 0] + I[..., 0]
+ G[..., 1:, :, 1] = G[..., 1:, :, 1] - I[..., :, :-1, 0] # xdx
+ G[..., 1:, :, 1] = G[..., 1:, :, 1] + I[..., :, 1:, 0]
+ G[..., 1:, :, 2] = G[..., 1:, :, 2] - I[..., :, :-1, 1] # xdz
+ G[..., 2] = G[..., 2] + I[..., :, 1, 0]
+ G[:, :, :, :-1, :, 3] = G[:, :, :, :-1, :, 3] - I[:, :, :, :-1, :, 1] # xdy
+ G[:, :, :, :-1, :, 3] = G[:, :, :, :-1, :, 3] + I[:, :, :, 1:, :, 1]
+ G[..., 3] = G[..., 3] + I[..., 0]
+ G[..., 1:, :, 4] = G[..., 1:, :, 4] - I[..., 1:, :, :-1, 2] # xdz
+ G[..., 4] = G[..., 4] + I[..., 1, :, :, 0]
+ G[:, :, :, :, :-1, 5] = G[:, :, :, :, :-1, 5] - I[:, :, :, :, :-1, 2] # xdy
+ G[:, :, :, 1:, :, 5] = G[:, :, :, 1:, :, 5] + I[:, :, :, :, :-1, 2]
+ return G
+
+ @staticmethod
+ def epsilon3_adjoint(G): # Adapted for 3D matrices
+ r"""
+ Applies the adjoint of the jacobian of a vector field.
+ """
+ b, c, d, h, w, _ = G.shape
+ I = torch.zeros((b, c, d, h, w, 3), device=G.device).type(G.dtype)
+ I[:, :, :, :-1, :, 0] = I[:, :, :, :-1, :, 0] - G[:, :, :, 1:, :, 0]
+ I[..., 0] = I[..., 0] + G[..., 0]
+ I[..., :-1, :, 0] = I[..., :-1, :, 0] - G[..., 1:, :, 1]
+ I[..., 1:, :, 0] = I[..., 1:, :, 0] + G[..., 1:, :, 1]
+ I[..., :-1, :, 1] = I[..., :-1, :, 1] - G[..., 1:, :, 2]
+ I[..., 0] = I[..., 0] + G[..., 2]
+ I[:, :, :, :-1, :, 1] = I[:, :, :, :-1, :, 1] - G[:, :, :, :-1, :, 3]
+ I[:, :, :, 1:, :, 1] = I[:, :, :, 1:, :, 1] + G[:, :, :, :-1, :, 3]
+ I[..., 1] = I[..., 1] + G[..., 3]
+ I[..., :-1, :, 2] = I[..., :-1, :, 2] - G[..., 1:, :, 4]
+ I[..., 0] = I[..., 0] + G[..., 4]
+ I[:, :, :, :, :-1, 2] = I[:, :, :, :, :-1, 2] - G[:, :, :, :, :-1, 5]
+ I[:, :, :, 1:, :, 2] = I[:, :, :, 1:, :, 2] + G[:, :, :, :, :-1, 5]
+ return I
+
+
+