From d8d4b29bd71a9738a87632d7e005b9f1ef803b55 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Tue, 24 Dec 2024 07:26:57 +0000 Subject: [PATCH] fix: Ensure consistent default values for average argument in classification metrics --- src/torchmetrics/classification/stat_scores.py | 4 ++-- src/torchmetrics/functional/classification/accuracy.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 96d797fd5d6..814b942de49 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -309,7 +309,7 @@ def __init__( self, num_classes: int, top_k: int = 1, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, @@ -461,7 +461,7 @@ def __init__( self, num_labels: int, threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 2c02484924e..63293180354 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -167,7 +167,7 @@ def multiclass_accuracy( preds: Tensor, target: Tensor, num_classes: int, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", top_k: int = 1, multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, @@ -276,7 +276,7 @@ def multilabel_accuracy( target: Tensor, num_labels: int, threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True,