From d03deda73e41ac05d8625f24fcbed1f3ddd1e203 Mon Sep 17 00:00:00 2001
From: rittik9
Date: Tue, 24 Dec 2024 18:36:10 +0530
Subject: [PATCH] fix: Ensure consistent default values ('micro') for average
 argument in classification metrics

---
 docs/source/pages/overview.rst                         | 2 +-
 src/torchmetrics/classification/accuracy.py            | 4 ++--
 src/torchmetrics/classification/hamming.py             | 2 +-
 .../classification/negative_predictive_value.py        | 4 ++--
 src/torchmetrics/classification/precision_recall.py    | 6 +++---
 src/torchmetrics/classification/specificity.py         | 2 +-
 src/torchmetrics/classification/stat_scores.py         | 4 ++--
 src/torchmetrics/functional/classification/accuracy.py | 4 ++--
 8 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst
index 933d4254cfe..9fa08fbbc99 100644
--- a/docs/source/pages/overview.rst
+++ b/docs/source/pages/overview.rst
@@ -453,7 +453,7 @@ of metrics
 e.g. computation of confidence intervals by resampling of input data.
 
     .. testoutput::
         :options: +NORMALIZE_WHITESPACE
 
-        {'mean': tensor(0.1333), 'std': tensor(0.1554)}
+        {'mean': tensor(0.1069), 'std': tensor(0.1180)}
 
 You can see all implemented wrappers under the wrapper section of the API docs.
diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py
index bc1a8bb5e36..73a9e6d190c 100644
--- a/src/torchmetrics/classification/accuracy.py
+++ b/src/torchmetrics/classification/accuracy.py
@@ -214,7 +214,7 @@ class MulticlassAccuracy(MulticlassStatScores):
         >>> preds = tensor([2, 1, 0, 1])
         >>> metric = MulticlassAccuracy(num_classes=3)
         >>> metric(preds, target)
-        tensor(0.8333)
+        tensor(0.7500)
         >>> mca = MulticlassAccuracy(num_classes=3, average=None)
         >>> mca(preds, target)
         tensor([0.5000, 1.0000, 1.0000])
@@ -228,7 +228,7 @@ class MulticlassAccuracy(MulticlassStatScores):
         ...                 [0.05, 0.82, 0.13]])
         >>> metric = MulticlassAccuracy(num_classes=3)
         >>> metric(preds, target)
-        tensor(0.8333)
+        tensor(0.7500)
         >>> mca = MulticlassAccuracy(num_classes=3, average=None)
         >>> mca(preds, target)
         tensor([0.5000, 1.0000, 1.0000])
diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py
index 183af336ae8..eb7df756d84 100644
--- a/src/torchmetrics/classification/hamming.py
+++ b/src/torchmetrics/classification/hamming.py
@@ -224,7 +224,7 @@ class MulticlassHammingDistance(MulticlassStatScores):
         >>> preds = tensor([2, 1, 0, 1])
         >>> metric = MulticlassHammingDistance(num_classes=3)
         >>> metric(preds, target)
-        tensor(0.1667)
+        tensor(0.2500)
         >>> mchd = MulticlassHammingDistance(num_classes=3, average=None)
         >>> mchd(preds, target)
         tensor([0.5000, 0.0000, 0.0000])
diff --git a/src/torchmetrics/classification/negative_predictive_value.py b/src/torchmetrics/classification/negative_predictive_value.py
index cdff97f86e2..764168c7c70 100644
--- a/src/torchmetrics/classification/negative_predictive_value.py
+++ b/src/torchmetrics/classification/negative_predictive_value.py
@@ -220,7 +220,7 @@ class MulticlassNegativePredictiveValue(MulticlassStatScores):
         >>> preds = tensor([2, 1, 0, 1])
         >>> metric = MulticlassNegativePredictiveValue(num_classes=3)
         >>> metric(preds, target)
-        tensor(0.8889)
+        tensor(0.8750)
         >>> metric = MulticlassNegativePredictiveValue(num_classes=3, average=None)
         >>> metric(preds, target)
         tensor([0.6667, 1.0000, 1.0000])
@@ -371,7 +371,7 @@ class MultilabelNegativePredictiveValue(MultilabelStatScores):
         >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
         >>> metric = MultilabelNegativePredictiveValue(num_labels=3)
         >>> metric(preds, target)
-        tensor(0.5000)
+        tensor(0.6667)
         >>> mls = MultilabelNegativePredictiveValue(num_labels=3, average=None)
         >>> mls(preds, target)
         tensor([1.0000, 0.5000, 0.0000])
diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py
index 0bd6f8b0d99..bb2f592b24e 100644
--- a/src/torchmetrics/classification/precision_recall.py
+++ b/src/torchmetrics/classification/precision_recall.py
@@ -237,7 +237,7 @@ class MulticlassPrecision(MulticlassStatScores):
         >>> preds = tensor([2, 1, 0, 1])
         >>> metric = MulticlassPrecision(num_classes=3)
         >>> metric(preds, target)
-        tensor(0.8333)
+        tensor(0.7500)
         >>> mcp = MulticlassPrecision(num_classes=3, average=None)
         >>> mcp(preds, target)
         tensor([1.0000, 0.5000, 1.0000])
@@ -402,7 +402,7 @@ class MultilabelPrecision(MultilabelStatScores):
         >>> preds = tensor([[0, 0, 1], [1, 0, 1]])
         >>> metric = MultilabelPrecision(num_labels=3)
         >>> metric(preds, target)
-        tensor(0.5000)
+        tensor(0.6667)
         >>> mlp = MultilabelPrecision(num_labels=3, average=None)
         >>> mlp(preds, target)
         tensor([1.0000, 0.0000, 0.5000])
@@ -696,7 +696,7 @@ class MulticlassRecall(MulticlassStatScores):
         >>> preds = tensor([2, 1, 0, 1])
         >>> metric = MulticlassRecall(num_classes=3)
         >>> metric(preds, target)
-        tensor(0.8333)
+        tensor(0.7500)
         >>> mcr = MulticlassRecall(num_classes=3, average=None)
         >>> mcr(preds, target)
         tensor([0.5000, 1.0000, 1.0000])
diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py
index dab5fde8a60..d15927cf590 100644
--- a/src/torchmetrics/classification/specificity.py
+++ b/src/torchmetrics/classification/specificity.py
@@ -214,7 +214,7 @@ class MulticlassSpecificity(MulticlassStatScores):
         >>> preds = tensor([2, 1, 0, 1])
         >>> metric = MulticlassSpecificity(num_classes=3)
         >>> metric(preds, target)
-        tensor(0.8889)
+        tensor(0.8750)
         >>> mcs = MulticlassSpecificity(num_classes=3, average=None)
         >>> mcs(preds, target)
         tensor([1.0000, 0.6667, 1.0000])
diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py
index 3c55c431fa2..58a695cd491 100644
--- a/src/torchmetrics/classification/stat_scores.py
+++ b/src/torchmetrics/classification/stat_scores.py
@@ -309,7 +309,7 @@ def __init__(
         self,
         num_classes: int,
         top_k: int = 1,
-        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         multidim_average: Literal["global", "samplewise"] = "global",
         ignore_index: Optional[int] = None,
         validate_args: bool = True,
@@ -461,7 +461,7 @@ def __init__(
         self,
         num_labels: int,
         threshold: float = 0.5,
-        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         multidim_average: Literal["global", "samplewise"] = "global",
         ignore_index: Optional[int] = None,
         validate_args: bool = True,
diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py
index 2c02484924e..63293180354 100644
--- a/src/torchmetrics/functional/classification/accuracy.py
+++ b/src/torchmetrics/functional/classification/accuracy.py
@@ -167,7 +167,7 @@ def multiclass_accuracy(
     preds: Tensor,
     target: Tensor,
     num_classes: int,
-    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
     top_k: int = 1,
     multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: Optional[int] = None,
@@ -276,7 +276,7 @@ def multilabel_accuracy(
     target: Tensor,
     num_labels: int,
     threshold: float = 0.5,
-    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
     multidim_average: Literal["global", "samplewise"] = "global",
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
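
Reviewer note (not part of the patch): a minimal sketch for checking the updated
multiclass doctest values. The docstring line defining `target` falls outside the
hunk context above, so the tensor([2, 1, 0, 0]) used here is an assumption; it is
consistent with every per-class output quoted in the diff (accuracy
tensor([0.5000, 1.0000, 1.0000]), precision tensor([1.0000, 0.5000, 1.0000]), etc.).

# Sanity check for the macro -> micro default change.
# NOTE: `target` is assumed -- its defining line is outside the hunk context.
from torch import tensor
from torchmetrics.classification import MulticlassAccuracy

target = tensor([2, 1, 0, 0])  # assumed, see note above
preds = tensor([2, 1, 0, 1])

# Micro averaging pools all samples before dividing: 3 of 4 predictions correct.
print(MulticlassAccuracy(num_classes=3, average="micro")(preds, target))  # tensor(0.7500)

# Macro averaging takes the unweighted mean of per-class scores:
# (0.5 + 1.0 + 1.0) / 3 = 0.8333 -- the old doctest value.
print(MulticlassAccuracy(num_classes=3, average="macro")(preds, target))  # tensor(0.8333)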
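
The multilabel values move in the opposite direction (0.5000 -> 0.6667 for both
precision and negative predictive value) because micro averaging sums true and
false positives across labels before dividing, so a label with no positive
predictions no longer contributes a hard 0.0 to the mean. A bare-torch sketch of
the precision case, again with an assumed `target` that the hunk context elides:

# Micro-averaged multilabel precision computed by hand.
# NOTE: `target` is assumed; it matches the per-label output
# tensor([1.0000, 0.0000, 0.5000]) quoted in the diff.
from torch import tensor

target = tensor([[0, 1, 0], [1, 0, 1]])  # assumed
preds = tensor([[0, 0, 1], [1, 0, 1]])

tp = ((preds == 1) & (target == 1)).sum()  # 2 true positives across all labels
fp = ((preds == 1) & (target == 0)).sum()  # 1 false positive (sample 0, label 2)
print(tp / (tp + fp))  # tensor(0.6667), vs. macro (1.0 + 0.0 + 0.5) / 3 = 0.5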