Skip to content

Commit

Permalink
Fix multilabel_accuracy of MixedHLabelAccuracy (#4042)
Browse files Browse the repository at this point in the history
* Fix metric for multi-label

* Fix1

* Add CHANGELOG
  • Loading branch information
harimkang authored Oct 17, 2024
1 parent 93adf0d commit 08de045
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4014>)
- Fix out_features in HierarchicalCBAMClsHead
(<https://github.com/openvinotoolkit/training_extensions/pull/4016>)
- Fix multilabel_accuracy of MixedHLabelAccuracy
(<https://github.com/openvinotoolkit/training_extensions/pull/4042>)

## \[v2.1.0\]

Expand Down
13 changes: 9 additions & 4 deletions src/otx/core/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,17 @@ def __init__(
]

# Multilabel classification accuracy metrics
if self.num_multilabel_classes > 0:
# https://github.com/Lightning-AI/torchmetrics/blob/6377aa5b6fe2863761839e6b8b5a857ef1b8acfa/src/torchmetrics/functional/classification/stat_scores.py#L583-L584
# MultilabelAccuracy is available when num_multilabel_classes is greater than 2.
self.multilabel_accuracy = None
if self.num_multilabel_classes > 1:
self.multilabel_accuracy = TorchmetricMultilabelAcc(
num_labels=self.num_multilabel_classes,
threshold=0.5,
average="macro",
)
elif self.num_multilabel_classes == 1:
self.multilabel_accuracy = TorchmetricAcc(task="binary", num_classes=self.num_multilabel_classes)

def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> nn.Module:
self.multiclass_head_accuracy = [
Expand All @@ -303,7 +308,7 @@ def _apply(self, fn: Callable, exclude_state: Sequence[str] = "") -> nn.Module:
)
for acc in self.multiclass_head_accuracy
]
if self.num_multilabel_classes > 0:
if self.multilabel_accuracy is not None:
self.multilabel_accuracy = self.multilabel_accuracy._apply(fn, exclude_state) # noqa: SLF001
return self

Expand All @@ -322,7 +327,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
target_multiclass[multiclass_mask],
)

if self.num_multilabel_classes > 0:
if self.multilabel_accuracy is not None:
# Split preds into multiclass and multilabel parts
preds_multilabel = preds[:, self.num_multiclass_heads :]
target_multilabel = target[:, self.num_multiclass_heads :]
Expand All @@ -337,7 +342,7 @@ def compute(self) -> torch.Tensor:
),
)

if self.num_multilabel_classes > 0:
if self.multilabel_accuracy is not None:
multilabel_acc = self.multilabel_accuracy.compute()

return (multiclass_accs + multilabel_acc) / 2
Expand Down
27 changes: 26 additions & 1 deletion tests/unit/core/metrics/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
MultilabelAccuracywithLabelGroup,
)
from otx.core.types.label import HLabelInfo, LabelInfo
from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy
from torchmetrics.classification.accuracy import BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy


class TestAccuracy:
Expand Down Expand Up @@ -120,3 +120,28 @@ def test_multilabel_only(self) -> None:
head_logits_info={"head1": (0, 5), "head2": (5, 10)},
threshold_multilabel=0.5,
)

def test_multilabel_accuracy(self, hlabel_accuracy) -> None:
# Normal Case: num_multilabel_classes > 1 -> MultilabelAccuracy
assert hlabel_accuracy.num_multilabel_classes == 3
assert isinstance(hlabel_accuracy.multilabel_accuracy, MultilabelAccuracy)

# Edge Case: num_multilabel_classes = 1 -> BinaryAccuracy
acc = MixedHLabelAccuracy(
num_multiclass_heads=2,
num_multilabel_classes=1,
head_logits_info={"head1": (0, 5), "head2": (5, 10)},
threshold_multilabel=0.5,
)
assert acc.num_multilabel_classes == 1
assert isinstance(acc.multilabel_accuracy, BinaryAccuracy)

# None Case: num_multilabel_classes = 0 -> None
acc = MixedHLabelAccuracy(
num_multiclass_heads=2,
num_multilabel_classes=0,
head_logits_info={"head1": (0, 5), "head2": (5, 10)},
threshold_multilabel=0.5,
)
assert acc.num_multilabel_classes == 0
assert acc.multilabel_accuracy is None

0 comments on commit 08de045

Please sign in to comment.