Skip to content

Commit

Permalink
fix: Ensure consistent default values ('micro') for average argument in classification metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Dec 24, 2024
1 parent 7bda8f7 commit d03deda
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ of metrics e.g. computation of confidence intervals by resampling of input data.
.. testoutput::
:options: +NORMALIZE_WHITESPACE

{'mean': tensor(0.1333), 'std': tensor(0.1554)}
{'mean': tensor(0.1069), 'std': tensor(0.1180)}

You can see all implemented wrappers under the wrapper section of the API docs.

Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class MulticlassAccuracy(MulticlassStatScores):
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassAccuracy(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
tensor(0.7500)
>>> mca = MulticlassAccuracy(num_classes=3, average=None)
>>> mca(preds, target)
tensor([0.5000, 1.0000, 1.0000])
Expand All @@ -228,7 +228,7 @@ class MulticlassAccuracy(MulticlassStatScores):
... [0.05, 0.82, 0.13]])
>>> metric = MulticlassAccuracy(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
tensor(0.7500)
>>> mca = MulticlassAccuracy(num_classes=3, average=None)
>>> mca(preds, target)
tensor([0.5000, 1.0000, 1.0000])
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/hamming.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class MulticlassHammingDistance(MulticlassStatScores):
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassHammingDistance(num_classes=3)
>>> metric(preds, target)
tensor(0.1667)
tensor(0.2500)
>>> mchd = MulticlassHammingDistance(num_classes=3, average=None)
>>> mchd(preds, target)
tensor([0.5000, 0.0000, 0.0000])
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/negative_predictive_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class MulticlassNegativePredictiveValue(MulticlassStatScores):
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassNegativePredictiveValue(num_classes=3)
>>> metric(preds, target)
tensor(0.8889)
tensor(0.8750)
>>> metric = MulticlassNegativePredictiveValue(num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.6667, 1.0000, 1.0000])
Expand Down Expand Up @@ -371,7 +371,7 @@ class MultilabelNegativePredictiveValue(MultilabelStatScores):
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelNegativePredictiveValue(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
tensor(0.6667)
>>> mls = MultilabelNegativePredictiveValue(num_labels=3, average=None)
>>> mls(preds, target)
tensor([1.0000, 0.5000, 0.0000])
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class MulticlassPrecision(MulticlassStatScores):
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassPrecision(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
tensor(0.7500)
>>> mcp = MulticlassPrecision(num_classes=3, average=None)
>>> mcp(preds, target)
tensor([1.0000, 0.5000, 1.0000])
Expand Down Expand Up @@ -402,7 +402,7 @@ class MultilabelPrecision(MultilabelStatScores):
>>> preds = tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelPrecision(num_labels=3)
>>> metric(preds, target)
tensor(0.5000)
tensor(0.6667)
>>> mlp = MultilabelPrecision(num_labels=3, average=None)
>>> mlp(preds, target)
tensor([1.0000, 0.0000, 0.5000])
Expand Down Expand Up @@ -696,7 +696,7 @@ class MulticlassRecall(MulticlassStatScores):
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassRecall(num_classes=3)
>>> metric(preds, target)
tensor(0.8333)
tensor(0.7500)
>>> mcr = MulticlassRecall(num_classes=3, average=None)
>>> mcr(preds, target)
tensor([0.5000, 1.0000, 1.0000])
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/specificity.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class MulticlassSpecificity(MulticlassStatScores):
>>> preds = tensor([2, 1, 0, 1])
>>> metric = MulticlassSpecificity(num_classes=3)
>>> metric(preds, target)
tensor(0.8889)
tensor(0.8750)
>>> mcs = MulticlassSpecificity(num_classes=3, average=None)
>>> mcs(preds, target)
tensor([1.0000, 0.6667, 1.0000])
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def __init__(
self,
num_classes: int,
top_k: int = 1,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
Expand Down Expand Up @@ -461,7 +461,7 @@ def __init__(
self,
num_labels: int,
threshold: float = 0.5,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def multiclass_accuracy(
preds: Tensor,
target: Tensor,
num_classes: int,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
Expand Down Expand Up @@ -276,7 +276,7 @@ def multilabel_accuracy(
target: Tensor,
num_labels: int,
threshold: float = 0.5,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
Expand Down

0 comments on commit d03deda

Please sign in to comment.