Skip to content

Commit

Permalink
feat: add an optional class balancer to the custom losses
Browse files Browse the repository at this point in the history
  • Loading branch information
torms3 authored and supersergiy committed Dec 13, 2023
1 parent 9da6f33 commit 7e46b87
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions zetta_utils/segmentation/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@

@builder.register("LossWithMask")
@typechecked
class LossWithMask(nn.Module):
class LossWithMask(nn.Module): # pragma: no cover
def __init__(
self,
criterion: Callable[..., nn.Module],
reduction: Literal["mean", "sum", "none"] = "sum",
balancer: nn.Module | None = None,
) -> None:
super().__init__()
try:
Expand All @@ -25,6 +26,8 @@ def __init__(
self.criterion = criterion()
assert self.criterion.reduction == "none"
self.reduction = reduction
self.balancer = balancer
self.balanced = False

def forward(
self,
Expand All @@ -36,6 +39,10 @@ def forward(
if nmsk.item() == 0:
return None

# Optional class balancing
if (not self.balanced) and (self.balancer is not None):
mask = self.balancer(trgt, mask)

loss = mask * self.criterion(pred, trgt)
if self.reduction == "none":
return loss
Expand All @@ -56,10 +63,11 @@ def __init__(
self,
criterion: Callable[..., nn.Module],
reduction: Literal["mean", "sum", "none"] = "sum",
balancer: nn.Module | None = None,
margin: float = 0,
logits: bool = False,
) -> None:
super().__init__(criterion, reduction)
super().__init__(criterion, reduction, balancer)
self.margin = np.clip(margin, 0, 1)
self.logits = logits

Expand All @@ -69,6 +77,11 @@ def forward(
trgt: torch.Tensor,
mask: torch.Tensor,
) -> torch.Tensor | None:
# Optional class balancing
if self.balancer is not None:
mask = self.balancer(trgt, mask)
self.balanced = True

high = 1 - self.margin
low = self.margin
activ = torch.sigmoid(pred) if self.logits else pred
Expand All @@ -85,9 +98,10 @@ def __init__(
self,
criterion: Callable[..., nn.Module],
reduction: Literal["mean", "sum", "none"] = "sum",
balancer: nn.Module | None = None,
margin: float = 0,
) -> None:
super().__init__(criterion, reduction)
super().__init__(criterion, reduction, balancer)
self.margin = np.clip(margin, 0, 1)

def forward(
Expand All @@ -96,6 +110,11 @@ def forward(
trgt: torch.Tensor,
mask: torch.Tensor,
) -> torch.Tensor | None:
# Optional class balancing
if self.balancer is not None:
mask = self.balancer(trgt, mask)
self.balanced = True

trgt[torch.eq(trgt, 1)] = 1 - self.margin
trgt[torch.eq(trgt, 0)] = self.margin
return super().forward(pred, trgt, mask)

0 comments on commit 7e46b87

Please sign in to comment.