diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index e6f6bbb6cf..ec21baaa85 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -17,13 +17,13 @@
 import numpy as np
 import torch
-import torch.linalg as LA
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.modules.loss import _Loss
 
 from monai.losses.focal_loss import FocalLoss
 from monai.losses.spatial_mask import MaskedLoss
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
 
 
@@ -67,6 +67,7 @@ def __init__(
         smooth_dr: float = 1e-5,
         batch: bool = False,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -98,6 +99,7 @@ def __init__(
                 of the sequence should be the same as the number of classes. If not ``include_background``,
                 the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values or not
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -123,6 +125,7 @@ def __init__(
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -183,22 +186,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        if self.squared_pred:
-            ground_o = torch.sum(target**2, dim=reduce_axis)
-            pred_o = torch.sum(input**2, dim=reduce_axis)
-            difference = LA.vector_norm(input - target, ord=2, dim=reduce_axis) ** 2
-        else:
-            ground_o = torch.sum(target, dim=reduce_axis)
-            pred_o = torch.sum(input, dim=reduce_axis)
-            difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)
-
-        denominator = ground_o + pred_o
-        intersection = (denominator - difference) / 2
-
-        if self.jaccard:
-            denominator = 2.0 * (denominator - intersection)
+        ord = 2 if self.squared_pred else 1
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
+        if not self.jaccard:
+            fp *= 0.5
+            fn *= 0.5
+        numerator = 2 * tp + self.smooth_nr
+        denominator = 2 * (tp + fp + fn) + self.smooth_dr
 
-        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
+        f = 1 - numerator / denominator
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -282,6 +278,7 @@ def __init__(
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -305,6 +302,7 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, intersection over union is computed from each item in the batch.
                 If True, the class-weighted intersection and union areas are first summed across the batches.
+            soft_label: whether the target contains non-binary values or not
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -329,6 +327,7 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def w_func(self, grnd):
         if self.w_type == str(Weight.SIMPLE):
@@ -381,13 +380,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if self.batch:
             reduce_axis = [0] + reduce_axis
 
-        ground_o = torch.sum(target, reduce_axis)
-        pred_o = torch.sum(input, reduce_axis)
-        difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)
-
-        denominator = ground_o + pred_o
-        intersection = (denominator - difference) / 2
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
+        fp *= 0.5
+        fn *= 0.5
+        denominator = 2 * (tp + fp + fn)
 
+        ground_o = torch.sum(target, reduce_axis)
         w = self.w_func(ground_o.float())
         infs = torch.isinf(w)
         if self.batch:
@@ -399,7 +397,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             w = w + infs * max_values
 
         final_reduce_dim = 0 if self.batch else 1
-        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
+        numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
         denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
         f: torch.Tensor = 1.0 - (numer / denom)
diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py
index a498882adf..d21aa76537 100644
--- a/monai/losses/tversky.py
+++ b/monai/losses/tversky.py
@@ -15,9 +15,9 @@
 from collections.abc import Callable
 
 import torch
-import torch.linalg as LA
 from torch.nn.modules.loss import _Loss
 
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import LossReduction
 
@@ -50,6 +50,7 @@ def __init__(
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -74,6 +75,7 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
+            soft_label: whether the target contains non-binary values or not
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -97,6 +99,7 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -144,13 +147,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        pred_o = torch.sum(input, reduce_axis)
-        ground_o = torch.sum(target, reduce_axis)
-        difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)
-
-        tp = (pred_o + ground_o - difference) / 2
-        fp = self.alpha * (pred_o - tp)
-        fn = self.beta * (ground_o - tp)
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
+        fp *= self.alpha
+        fn *= self.beta
 
         numerator = tp + self.smooth_nr
         denominator = tp + fp + fn + self.smooth_dr
diff --git a/monai/losses/utils.py b/monai/losses/utils.py
new file mode 100644
index 0000000000..646ebd6d62
--- /dev/null
+++ b/monai/losses/utils.py
@@ -0,0 +1,60 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+
+import torch
+import torch.linalg as LA
+
+
+def compute_tp_fp_fn(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    reduce_axis: list[int],
+    ord: int,
+    soft_label: bool,
+    decoupled: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Adapted from:
+        https://github.com/zifuwanggg/JDTLosses
+    """
+    if torch.unique(target).shape[0] > 2 and not soft_label:
+        warnings.warn("soft labels are used, but `soft_label == False`.")
+
+    # the original implementation that is erroneous with soft labels
+    if ord == 1 and not soft_label:
+        tp = torch.sum(input * target, dim=reduce_axis)
+        # the original implementation of Dice and Jaccard loss
+        if decoupled:
+            fp = torch.sum(input, dim=reduce_axis) - tp
+            fn = torch.sum(target, dim=reduce_axis) - tp
+        # the original implementation of Tversky loss
+        else:
+            fp = torch.sum(input * (1 - target), dim=reduce_axis)
+            fn = torch.sum((1 - input) * target, dim=reduce_axis)
+    else:
+        pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
+        ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
+        difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)
+
+        if ord > 1:
+            pred_o = torch.pow(pred_o, exponent=ord)
+            ground_o = torch.pow(ground_o, exponent=ord)
+            difference = torch.pow(difference, exponent=ord)
+
+        tp = (pred_o + ground_o - difference) / 2
+        fp = pred_o - tp
+        fn = ground_o - tp
+
+    return tp, fp, fn
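Not part of the patch, but a quick numerical sanity check of the identity `compute_tp_fp_fn` relies on may help review: for a hard binary target and predictions in `[0, 1]`, the norm-based formulation recovers the usual product-based soft TP, so the old and new paths agree exactly where the old one was valid (tensor shapes and seed below are arbitrary):

```python
import torch
import torch.linalg as LA

torch.manual_seed(0)
input = torch.rand(2, 3, 8, 8)                      # predictions in [0, 1]
target = torch.randint(0, 2, (2, 3, 8, 8)).float()  # hard binary target
dims = (2, 3)                                       # reduce spatial dims only

# product-based TP (the `ord == 1 and not soft_label` branch)
tp_product = torch.sum(input * target, dim=dims)

# norm-based TP (the branch taken when soft_label=True)
pred_o = LA.vector_norm(input, ord=1, dim=dims)
ground_o = LA.vector_norm(target, ord=1, dim=dims)
difference = LA.vector_norm(input - target, ord=1, dim=dims)
tp_norm = (pred_o + ground_o - difference) / 2

# the two agree for binary targets; they diverge for non-binary targets,
# which is exactly the case the soft_label flag is meant to handle
assert torch.allclose(tp_product, tp_norm, atol=1e-5)
```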
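And a minimal usage sketch of the new flag, assuming this patch is applied (the loss classes and their other arguments are the existing MONAI API; only `soft_label` is new):

```python
import torch
from monai.losses import DiceLoss, TverskyLoss

pred = torch.randn(4, 1, 32, 32)        # raw logits
soft_target = torch.rand(4, 1, 32, 32)  # non-binary target, e.g. fuzzy masks

dice = DiceLoss(sigmoid=True, soft_label=True)
tversky = TverskyLoss(sigmoid=True, alpha=0.3, beta=0.7, soft_label=True)

print(dice(pred, soft_target).item())
print(tversky(pred, soft_target).item())
```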