Skip to content

Commit

Permalink
4599 mri ssim loss (Project-MONAI#4600)
Browse files Browse the repository at this point in the history
Signed-off-by: mersad95zd <[email protected]>
  • Loading branch information
mersad95zd authored Jul 8, 2022
1 parent e270741 commit 4071750
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 1 deletion.
9 changes: 9 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ Registration Losses
.. autoclass:: GlobalMutualInformationLoss
:members:

Reconstruction Losses
---------------------

`SSIMLoss`
~~~~~~~~~~
.. autoclass:: monai.losses.ssim_loss.SSIMLoss
:members:


Loss Wrappers
-------------

Expand Down
4 changes: 4 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ Metrics
.. autoclass:: PSNRMetric
:members:

`Structural similarity index measure`
-------------------------------------
.. autoclass:: monai.metrics.regression.SSIMMetric

`Cumulative average`
--------------------
.. autoclass:: CumulativeAverage
Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
from .spatial_mask import MaskedLoss
from .ssim_loss import SSIMLoss
from .tversky import TverskyLoss
from .unified_focal_loss import AsymmetricUnifiedFocalLoss
93 changes: 93 additions & 0 deletions monai/losses/ssim_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn.functional as F
from torch import nn

from monai.utils.type_conversion import convert_to_dst_type


class SSIMLoss(nn.Module):
"""
Build a Pytorch version of the SSIM loss function based on the original formula of SSIM
Modified and adopted from:
https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py
For more info, visit
https://vicuesoft.com/glossary/term/ssim-ms-ssim/
SSIM reference paper:
Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
similarity." IEEE transactions on image processing 13.4 (2004): 600-612.
"""

def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2):
"""
Args:
win_size: gaussian weighting window size
k1: stability constant used in the luminance denominator
k2: stability constant used in the contrast denominator
spatial_dims: if 2, input shape is expected to be (B,C,W,H). if 3, it is expected to be (B,C,W,H,D)
"""
super().__init__()
self.win_size = win_size
self.k1, self.k2 = k1, k2
self.spatial_dims = spatial_dims
self.register_buffer(
"w", torch.ones([1, 1] + [win_size for _ in range(spatial_dims)]) / win_size**spatial_dims
)
self.cov_norm = (win_size**2) / (win_size**2 - 1)

def forward(self, x: torch.Tensor, y: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor:
"""
Args:
x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
A fastMRI sample should use the 2D format with C being the number of slices.
y: second sample (e.g., the reconstructed image). It has similar shape as x.
data_range: dynamic range of the data
Returns:
1-ssim_value (recall this is meant to be a loss function)
Example:
.. code-block:: python
import torch
x = torch.ones([1,1,10,10])/2
y = torch.ones([1,1,10,10])/2
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(1-SSIMLoss(spatial_dims=2)(x,y,data_range))
"""
data_range = data_range[(None,) * (self.spatial_dims + 2)]
# determine whether to work with 2D convolution or 3D
conv = getattr(F, f"conv{self.spatial_dims}d")
w = convert_to_dst_type(src=self.w, dst=x)[0]

c1 = (self.k1 * data_range) ** 2 # stability constant for luminance
c2 = (self.k2 * data_range) ** 2 # stability constant for contrast
ux = conv(x, w) # mu_x
uy = conv(y, w) # mu_y
uxx = conv(x * x, w) # mu_x^2
uyy = conv(y * y, w) # mu_y^2
uxy = conv(x * y, w) # mu_xy
vx = self.cov_norm * (uxx - ux * ux) # sigma_x
vy = self.cov_norm * (uyy - uy * uy) # sigma_y
vxy = self.cov_norm * (uxy - ux * uy) # sigma_xy

numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
denom = (ux**2 + uy**2 + c1) * (vx + vy + c2)
ssim_value = numerator / denom
loss: torch.Tensor = 1 - ssim_value.mean()
return loss
2 changes: 1 addition & 1 deletion monai/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
from .meandice import DiceMetric, compute_meandice
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric
from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric, SSIMMetric
from .rocauc import ROCAUCMetric, compute_roc_auc
from .surface_dice import SurfaceDiceMetric, compute_surface_dice
from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance
Expand Down
62 changes: 62 additions & 0 deletions monai/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch

from monai.losses.ssim_loss import SSIMLoss
from monai.metrics.utils import do_metric_reduction
from monai.utils import MetricReduction

Expand Down Expand Up @@ -224,3 +225,64 @@ def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func) -> t
# reduction of batch handled inside __call__() using do_metric_reduction() in respective calling class
flt = partial(torch.flatten, start_dim=1)
return torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True)


class SSIMMetric(RegressionMetric):
r"""
Build a Pytorch version of the SSIM metric based on the original formula of SSIM
.. math::
\operatorname {SSIM}(x,y) =\frac {(2 \mu_x \mu_y + c_1)(2 \sigma_{xy} + c_2)}{((\mu_x^2 + \
\mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}
For more info, visit
https://vicuesoft.com/glossary/term/ssim-ms-ssim/
Modified and adopted from:
https://github.com/facebookresearch/fastMRI/blob/main/banding_removal/fastmri/ssim_loss_mixin.py
SSIM reference paper:
Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
similarity." IEEE transactions on image processing 13.4 (2004): 600-612.
Args:
data_range: dynamic range of the data
win_size: gaussian weighting window size
k1: stability constant used in the luminance denominator
k2: stability constant used in the contrast denominator
spatial_dims: if 2, input shape is expected to be (B,C,W,H). if 3, it is expected to be (B,C,W,H,D)
"""

def __init__(
self, data_range: torch.Tensor, win_size: int = 7, k1: float = 0.01, k2: float = 0.03, spatial_dims: int = 2
):
super().__init__()
self.data_range = data_range
self.win_size = win_size
self.k1, self.k2 = k1, k2
self.spatial_dims = spatial_dims

def _compute_metric(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Args:
x: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
A fastMRI sample should use the 2D format with C being the number of slices.
y: second sample (e.g., the reconstructed image). It has similar shape as x
Returns:
ssim_value
Example:
.. code-block:: python
import torch
x = torch.ones([1,1,10,10])/2 # ground truth
y = torch.ones([1,1,10,10])/2 # prediction
data_range = x.max().unsqueeze(0)
# the following line should print 1.0 (or 0.9999)
print(SSIMMetric(data_range=data_range,spatial_dims=2)._compute_metric(x,y))
"""
ssim_value: torch.Tensor = 1 - SSIMLoss(self.win_size, self.k1, self.k2, self.spatial_dims)(
x, y, self.data_range
)
return ssim_value
53 changes: 53 additions & 0 deletions tests/test_ssim_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.losses.ssim_loss import SSIMLoss

x = torch.ones([1, 1, 10, 10]) / 2
y1 = torch.ones([1, 1, 10, 10]) / 2
y2 = torch.zeros([1, 1, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS2D = []
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
TESTS2D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device)))
TESTS2D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device)))

x = torch.ones([1, 1, 10, 10, 10]) / 2
y1 = torch.ones([1, 1, 10, 10, 10]) / 2
y2 = torch.zeros([1, 1, 10, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS3D = []
for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
TESTS3D.append((x.to(device), y1.to(device), data_range.to(device), torch.tensor(1.0).unsqueeze(0).to(device)))
TESTS3D.append((x.to(device), y2.to(device), data_range.to(device), torch.tensor(0.0).unsqueeze(0).to(device)))


class TestSSIMLoss(unittest.TestCase):
@parameterized.expand(TESTS2D)
def test2d(self, x, y, drange, res):
result = 1 - SSIMLoss(spatial_dims=2)(x, y, drange)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)

@parameterized.expand(TESTS3D)
def test3d(self, x, y, drange, res):
result = 1 - SSIMLoss(spatial_dims=3)(x, y, drange)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)


if __name__ == "__main__":
unittest.main()
47 changes: 47 additions & 0 deletions tests/test_ssim_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.metrics.regression import SSIMMetric

x = torch.ones([1, 1, 10, 10]) / 2
y1 = torch.ones([1, 1, 10, 10]) / 2
y2 = torch.zeros([1, 1, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS2D = [(x, y1, data_range, torch.tensor(1.0).unsqueeze(0)), (x, y2, data_range, torch.tensor(0.0).unsqueeze(0))]

x = torch.ones([1, 1, 10, 10, 10]) / 2
y1 = torch.ones([1, 1, 10, 10, 10]) / 2
y2 = torch.zeros([1, 1, 10, 10, 10])
data_range = x.max().unsqueeze(0)
TESTS3D = [(x, y1, data_range, torch.tensor(1.0).unsqueeze(0)), (x, y2, data_range, torch.tensor(0.0).unsqueeze(0))]


class TestSSIMMetric(unittest.TestCase):
@parameterized.expand(TESTS2D)
def test2d(self, x, y, drange, res):
result = SSIMMetric(data_range=drange, spatial_dims=2)._compute_metric(x, y)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)

@parameterized.expand(TESTS3D)
def test3d(self, x, y, drange, res):
result = SSIMMetric(data_range=drange, spatial_dims=3)._compute_metric(x, y)
self.assertTrue(isinstance(result, torch.Tensor))
self.assertTrue(torch.abs(res - result).item() < 0.001)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4071750

Please sign in to comment.