Skip to content

Commit

Permalink
Add helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
zifuwanggg committed Oct 21, 2024
1 parent cfd2d1e commit 3f74183
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 31 deletions.
44 changes: 21 additions & 23 deletions monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

import numpy as np
import torch
import torch.linalg as LA
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.losses.focal_loss import FocalLoss
from monai.losses.spatial_mask import MaskedLoss
from monai.losses.utils import compute_tp_fp_fn
from monai.networks import one_hot
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after

Expand Down Expand Up @@ -67,6 +67,7 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
) -> None:
"""
Args:
Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(
of the sequence should be the same as the number of classes. If not ``include_background``,
the number of classes should not include the background category class 0).
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values or not
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -123,6 +125,7 @@ def __init__(
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
self.soft_label = soft_label

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -183,22 +186,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# reducing spatial dimensions and batch
reduce_axis = [0] + reduce_axis

if self.squared_pred:
ground_o = torch.sum(target**2, dim=reduce_axis)
pred_o = torch.sum(input**2, dim=reduce_axis)
difference = LA.vector_norm(input - target, ord=2, dim=reduce_axis) ** 2
else:
ground_o = torch.sum(target, dim=reduce_axis)
pred_o = torch.sum(input, dim=reduce_axis)
difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)

denominator = ground_o + pred_o
intersection = (denominator - difference) / 2

if self.jaccard:
denominator = 2.0 * (denominator - intersection)
ord = 2 if self.squared_pred else 1
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
if not self.jaccard:
fp *= 0.5
fn *= 0.5
numerator = 2 * tp + self.smooth_nr
denominator = 2 * (tp + fp + fn) + self.smooth_dr

f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
f = 1 - numerator / denominator

num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
Expand Down Expand Up @@ -282,6 +278,7 @@ def __init__(
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
) -> None:
"""
Args:
Expand All @@ -305,6 +302,7 @@ def __init__(
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, intersection over union is computed from each item in the batch.
If True, the class-weighted intersection and union areas are first summed across the batches.
soft_label: whether the target contains non-binary values or not
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -329,6 +327,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label

def w_func(self, grnd):
if self.w_type == str(Weight.SIMPLE):
Expand Down Expand Up @@ -381,13 +380,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if self.batch:
reduce_axis = [0] + reduce_axis

ground_o = torch.sum(target, reduce_axis)
pred_o = torch.sum(input, reduce_axis)
difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)

denominator = ground_o + pred_o
intersection = (denominator - difference) / 2
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
fp *= 0.5
fn *= 0.5
denominator = 2 * (tp + fp + fn)

ground_o = torch.sum(target, reduce_axis)
w = self.w_func(ground_o.float())
infs = torch.isinf(w)
if self.batch:
Expand All @@ -399,7 +397,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
w = w + infs * max_values

final_reduce_dim = 0 if self.batch else 1
numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
f: torch.Tensor = 1.0 - (numer / denom)

Expand Down
15 changes: 7 additions & 8 deletions monai/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from collections.abc import Callable

import torch
import torch.linalg as LA
from torch.nn.modules.loss import _Loss

from monai.losses.utils import compute_tp_fp_fn
from monai.networks import one_hot
from monai.utils import LossReduction

Expand Down Expand Up @@ -50,6 +50,7 @@ def __init__(
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
) -> None:
"""
Args:
Expand All @@ -74,6 +75,7 @@ def __init__(
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, a Dice loss value is computed independently from each item in the batch
before any `reduction`.
soft_label: whether the target contains non-binary values or not
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
Expand All @@ -97,6 +99,7 @@ def __init__(
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -144,13 +147,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# reducing spatial dimensions and batch
reduce_axis = [0] + reduce_axis

pred_o = torch.sum(input, reduce_axis)
ground_o = torch.sum(target, reduce_axis)
difference = LA.vector_norm(input - target, ord=1, dim=reduce_axis)

tp = (pred_o + ground_o - difference) / 2
fp = self.alpha * (pred_o - tp)
fn = self.beta * (ground_o - tp)
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
fp *= self.alpha
fn *= self.beta
numerator = tp + self.smooth_nr
denominator = tp + fp + fn + self.smooth_dr

Expand Down
60 changes: 60 additions & 0 deletions monai/losses/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
import torch.linalg as LA


def compute_tp_fp_fn(
input: torch.Tensor,
target: torch.Tensor,
reduce_axis: list[int],
ord: int,
soft_label: bool,
decoupled: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Adapted from:
https://github.com/zifuwanggg/JDTLosses
"""
if torch.unique(target).shape[0] > 2 and not soft_label:
warnings.warn("soft labels are used, but `soft_label == False`.")

# the original implementation that is erroneous with soft labels
if ord == 1 and not soft_label:
tp = torch.sum(input * target, dim=reduce_axis)
# the original implementation of Dice and Jaccard loss
if decoupled:
fp = torch.sum(input, dim=reduce_axis) - tp
fn = torch.sum(target, dim=reduce_axis) - tp
# the original implementation of Tversky loss
else:
fp = torch.sum(input * (1 - target), dim=reduce_axis)
fn = torch.sum((1 - input) * target, dim=reduce_axis)
else:
pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)

if ord > 1:
pred_o = torch.pow(pred_o, exponent=ord)
ground_o = torch.pow(ground_o, exponent=ord)
difference = torch.pow(difference, exponent=ord)

tp = (pred_o + ground_o - difference) / 2
fp = pred_o - tp
fn = ground_o - tp

return tp, fp, fn

0 comments on commit 3f74183

Please sign in to comment.