Skip to content

Commit

Permalink
Add label smoothing param in DiceCELoss
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7957

Signed-off-by: ytl0623 <[email protected]>
  • Loading branch information
ytl0623 committed Aug 7, 2024
1 parent 6c23fd0 commit 4ef1785
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion monai/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def __init__(
weight: torch.Tensor | None = None,
lambda_dice: float = 1.0,
lambda_ce: float = 1.0,
label_smoothing: float = 0.0,
) -> None:
"""
Args:
Expand Down Expand Up @@ -728,7 +729,7 @@ def __init__(
batch=batch,
weight=dice_weight,
)
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, label_smoothing=label_smoothing)
self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
Expand All @@ -737,6 +738,7 @@ def __init__(
self.lambda_dice = lambda_dice
self.lambda_ce = lambda_ce
self.old_pt_ver = not pytorch_after(1, 10)
self.label_smoothing = label_smoothing

def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand Down

0 comments on commit 4ef1785

Please sign in to comment.