diff --git a/src/torchmetrics/functional/segmentation/mean_iou.py b/src/torchmetrics/functional/segmentation/mean_iou.py index 9cfed0fa1bf..0dfc7a5e692 100644 --- a/src/torchmetrics/functional/segmentation/mean_iou.py +++ b/src/torchmetrics/functional/segmentation/mean_iou.py @@ -66,11 +66,9 @@ def _mean_iou_update( def _mean_iou_compute( intersection: Tensor, union: Tensor, - per_class: bool = False, ) -> Tensor: """Compute the mean IoU metric.""" - val = _safe_divide(intersection, union) - return val if per_class else torch.mean(val, 1) + return _safe_divide(intersection, union) def mean_iou( @@ -111,4 +109,7 @@ def mean_iou( """ _mean_iou_validate_args(num_classes, include_background, per_class, input_format) intersection, union = _mean_iou_update(preds, target, num_classes, include_background, input_format) - return _mean_iou_compute(intersection, union, per_class=per_class) + score = _mean_iou_compute(intersection, union) + valid_classes = union > 0 + score = score * valid_classes + return score if per_class else score.sum(dim=-1) / valid_classes.sum(dim=-1) diff --git a/src/torchmetrics/segmentation/mean_iou.py b/src/torchmetrics/segmentation/mean_iou.py index ae8dd3d2aea..0fe1fe699ed 100644 --- a/src/torchmetrics/segmentation/mean_iou.py +++ b/src/torchmetrics/segmentation/mean_iou.py @@ -111,21 +111,23 @@ def __init__( self.input_format = input_format num_classes = num_classes - 1 if not include_background else num_classes - self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") - self.add_state("num_batches", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("score", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("num_batches", default=torch.zeros(num_classes), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with the new data.""" intersection, union = _mean_iou_update( preds, target, self.num_classes, self.include_background, self.input_format ) - score = _mean_iou_compute(intersection, union, per_class=self.per_class) - self.score += score.mean(0) if self.per_class else score.mean() - self.num_batches += 1 + score = _mean_iou_compute(intersection, union) + # only update for classes that are present (i.e. union > 0) + valid_classes = union > 0 + self.score += (score * valid_classes).sum(dim=0) + self.num_batches += valid_classes.sum(dim=0) def compute(self) -> Tensor: """Compute the final Mean Intersection over Union (mIoU).""" - return self.score / self.num_batches + return self.score / self.num_batches if self.per_class else (self.score / self.num_batches).mean() def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric.