From c5f70e004943641b2e007fecaf48d578ae8e3bb5 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Sun, 19 Jan 2025 20:24:22 +0100 Subject: [PATCH] Update dice weight constraints for flexible loss weighting (#839) --- micro_sam/training/semantic_sam_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index 61259482..84b421ad 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -70,8 +70,8 @@ def __init__(self, convert_inputs, num_classes: int, dice_weight: Optional[float self.compute_ce_loss = nn.CrossEntropyLoss() self.dice_weight = dice_weight - if self.dice_weight is not None: - assert self.dice_weight > 0 and self.dice_weight < 1, "The weight factor should lie between 0 and 1." + if self.dice_weight is not None and (self.dice_weight < 0 or self.dice_weight > 1): + raise ValueError("The weight factor should lie between 0 and 1.") self._kwargs = kwargs