
Commit

Remove cuda version focal loss (#3431)
jaegukhyun authored Apr 30, 2024
1 parent 950bec0 commit 8766676
Showing 2 changed files with 5 additions and 61 deletions.
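Background for the change: the CUDA-specific focal loss kernel (mmcv's sigmoid_focal_loss) is removed, so the pure-PyTorch path is taken in every case. The sketch below shows the standard sigmoid focal loss computation that py_sigmoid_focal_loss is understood to implement; it is illustrative only (the helper name and argument names are not the repository's code) and leaves reduction to the caller.

import torch
import torch.nn.functional as F


def focal_loss_sketch(pred: torch.Tensor, target: torch.Tensor,
                      gamma: float = 2.0, alpha: float = 0.25) -> torch.Tensor:
    # pred: (N, C) logits; target: (N, C) one-hot labels of the same shape.
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # Probability mass on the wrong side of each label; drives the (1 - p_t)^gamma modulation.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(gamma)
    # Element-wise BCE on logits, scaled by the focal weight; no reduction applied here.
    return F.binary_cross_entropy_with_logits(pred, target, reduction="none") * focal_weight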
13 changes: 5 additions & 8 deletions src/otx/algo/detection/losses/cross_focal_loss.py
@@ -11,7 +11,7 @@
 from torch import Tensor, nn
 from torch.cuda.amp import custom_fwd

-from otx.algo.detection.losses.focal_loss import py_sigmoid_focal_loss, sigmoid_focal_loss
+from otx.algo.detection.losses.focal_loss import py_sigmoid_focal_loss


 def cross_sigmoid_focal_loss(
@@ -36,13 +36,10 @@ def cross_sigmoid_focal_loss(
         avg_factor: average factors.
         valid_label_mask: ignore label mask.
     """
-    if torch.cuda.is_available() and inputs.is_cuda:
-        calculate_loss_func = sigmoid_focal_loss
-    else:
-        inputs_size = inputs.size(1)
-        targets = torch.nn.functional.one_hot(targets, num_classes=inputs_size + 1)
-        targets = targets[:, :inputs_size]
-        calculate_loss_func = py_sigmoid_focal_loss
+    inputs_size = inputs.size(1)
+    targets = torch.nn.functional.one_hot(targets, num_classes=inputs_size + 1)
+    targets = targets[:, :inputs_size]
+    calculate_loss_func = py_sigmoid_focal_loss

     loss = calculate_loss_func(
         inputs,
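With the CUDA branch removed, cross_sigmoid_focal_loss now always one-hot encodes the integer targets before calling py_sigmoid_focal_loss, on CPU and GPU alike; the extra column produced by num_classes + 1 (presumably the background / "no object" index) is sliced off, leaving an all-zero row for such samples. A small stand-alone illustration (tensor values are made up):

import torch

num_classes = 4
labels = torch.tensor([0, 2, 4])  # 4 == num_classes, i.e. the extra "background" index
one_hot = torch.nn.functional.one_hot(labels, num_classes=num_classes + 1)
one_hot = one_hot[:, :num_classes]  # drop the extra column; background becomes an all-zero row
print(one_hot)
# tensor([[1, 0, 0, 0],
#         [0, 0, 1, 0],
#         [0, 0, 0, 0]])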
53 changes: 0 additions & 53 deletions src/otx/algo/detection/losses/focal_loss.py
@@ -12,10 +12,6 @@
 import torch
 import torch.nn.functional

-# TODO(Eugene): replace mmcv.sigmoid_focal_loss with torchvision
-# https://github.com/openvinotoolkit/training_extensions/pull/3281
-from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
-
 from otx.algo.detection.losses.weighted_loss import weight_reduce_loss

 if TYPE_CHECKING:
@@ -74,52 +70,3 @@ def py_sigmoid_focal_loss(
             msg = "The number of dimensions in weight should be equal to the number of dimensions in loss."
             raise ValueError(msg)
     return weight_reduce_loss(loss, weight, reduction, avg_factor)
-
-
-def sigmoid_focal_loss(
-    pred: Tensor,
-    target: Tensor,
-    weight: None | Tensor = None,
-    gamma: float = 2.0,
-    alpha: float = 0.25,
-    reduction: str = "mean",
-    avg_factor: int | None = None,
-) -> torch.Tensor:
-    r"""A wrapper of cuda version `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
-
-    Args:
-        pred (torch.Tensor): The prediction with shape (N, C), C is the number
-            of classes.
-        target (torch.Tensor): The learning label of the prediction.
-        weight (torch.Tensor, optional): Sample-wise loss weight.
-        gamma (float, optional): The gamma for calculating the modulating
-            factor. Defaults to 2.0.
-        alpha (float, optional): A balanced form for Focal Loss.
-            Defaults to 0.25.
-        reduction (str, optional): The method used to reduce the loss into
-            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
-        avg_factor (int, optional): Average factor that is used to average
-            the loss. Defaults to None.
-    """
-    # Function.apply does not accept keyword arguments, so the decorator
-    # "weighted_loss" is not applicable
-    loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, alpha, None, "none")
-    if weight is not None:
-        if weight.shape != loss.shape:
-            if weight.size(0) == loss.size(0):
-                # For most cases, weight is of shape (num_priors, ),
-                # which means it does not have the second axis num_class
-                weight = weight.view(-1, 1)
-            else:
-                # Sometimes, weight per anchor per class is also needed. e.g.
-                # in FSAF. But it may be flattened of shape
-                # (num_priors x num_class, ), while loss is still of shape
-                # (num_priors, num_class).
-                if weight.numel() != loss.numel():
-                    msg = "The number of elements in weight should be equal to the number of elements in loss."
-                    raise ValueError(msg)
-                weight = weight.view(loss.size(0), -1)
-        if weight.ndim != loss.ndim:
-            msg = "The number of dimensions in weight should be equal to the number of dimensions in loss."
-            raise ValueError(msg)
-    return weight_reduce_loss(loss, weight, reduction, avg_factor)
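The TODO removed in the first hunk pointed at torchvision as the eventual replacement for the mmcv CUDA op. For reference only (not part of this commit), a minimal sketch of calling torchvision's built-in focal loss, which expects float one-hot targets with the same shape as the logits:

import torch
from torchvision.ops import sigmoid_focal_loss

logits = torch.randn(8, 4)                     # (N, C) raw predictions
targets = torch.randint(0, 2, (8, 4)).float()  # one-hot style float targets, same shape
loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean")
print(loss)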
