Remove device-to-host (`Device2Host`) synchronization caused by communication between device and host (#2840)
* async host/device
* unittest
* chlog
* Apply suggestions from code review

---------

Co-authored-by: zhaozheng09 <[email protected]>
Co-authored-by: meng song <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
8 people authored Dec 21, 2024
1 parent b9ab4bc commit 3ff199c
Showing 31 changed files with 171 additions and 109 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed
 
+- Removed device-to-host (`Device2Host`) synchronization caused by communication between device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840))
+
+
 - Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848))


@@ -23,6 +23,7 @@
     _multiclass_confusion_matrix_format,
     _multiclass_confusion_matrix_tensor_validation,
 )
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel


@@ -239,8 +240,7 @@ def _multiclass_calibration_error_update(
     preds: Tensor,
     target: Tensor,
 ) -> tuple[Tensor, Tensor]:
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
+    preds = normalize_logits_if_needed(preds, "softmax")
     confidences, predictions = preds.max(dim=1)
     accuracies = predictions.eq(target)
     return confidences.float(), accuracies.float()
@@ -18,6 +18,7 @@
 from typing_extensions import Literal
 
 from torchmetrics.utilities.checks import _check_same_shape
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.prints import rank_zero_warn
@@ -137,9 +138,7 @@ def _binary_confusion_matrix_format(
         target = target[idx]
 
     if preds.is_floating_point():
-        if not torch.all((preds >= 0) * (preds <= 1)):
-            # preds is logits, convert with sigmoid
-            preds = preds.sigmoid()
+        preds = normalize_logits_if_needed(preds, "sigmoid")
         if convert_to_labels:
             preds = preds > threshold

@@ -491,8 +490,7 @@ def _multilabel_confusion_matrix_format(
"""
if preds.is_floating_point():
if not torch.all((preds >= 0) * (preds <= 1)):
preds = preds.sigmoid()
preds = normalize_logits_if_needed(preds, "sigmoid")
if should_threshold:
preds = preds > threshold
preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels)
5 changes: 2 additions & 3 deletions src/torchmetrics/functional/classification/hinge.py
@@ -23,6 +23,7 @@
     _multiclass_confusion_matrix_format,
     _multiclass_confusion_matrix_tensor_validation,
 )
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import to_onehot
 from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel

@@ -153,9 +154,7 @@ def _multiclass_hinge_loss_update(
     squared: bool,
     multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer",
 ) -> tuple[Tensor, Tensor]:
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
-
+    preds = normalize_logits_if_needed(preds, "softmax")
     target = to_onehot(target, max(2, preds.shape[1])).bool()
     if multiclass_mode == "crammer-singer":
         margin = preds[target]
@@ -21,7 +21,7 @@
 from typing_extensions import Literal
 
 from torchmetrics.utilities.checks import _check_same_shape
-from torchmetrics.utilities.compute import _safe_divide, interp
+from torchmetrics.utilities.compute import _safe_divide, interp, normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount, _cumsum
 from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.prints import rank_zero_warn
@@ -182,8 +181,7 @@ def _binary_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]
 
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.sigmoid()
+    preds = normalize_logits_if_needed(preds, "sigmoid")
 
     thresholds = _adjust_threshold_arg(thresholds, preds.device)
     return preds, target, thresholds
@@ -452,8 +451,7 @@ def _multiclass_precision_recall_curve_format(
         preds = preds[idx]
         target = target[idx]
 
-    if not torch.all((preds >= 0) * (preds <= 1)):
-        preds = preds.softmax(1)
+    preds = normalize_logits_if_needed(preds, "softmax")
 
     if average == "micro":
         preds = preds.flatten()
@@ -761,8 +759,8 @@ def _multilabel_precision_recall_curve_format(
"""
preds = preds.transpose(0, 1).reshape(num_labels, -1).T
target = target.transpose(0, 1).reshape(num_labels, -1).T
if not torch.all((preds >= 0) * (preds <= 1)):
preds = preds.sigmoid()

preds = normalize_logits_if_needed(preds, "sigmoid")

thresholds = _adjust_threshold_arg(thresholds, preds.device)
if ignore_index is not None and thresholds is not None:
8 changes: 3 additions & 5 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -18,6 +18,7 @@
 from typing_extensions import Literal
 
 from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification
+from torchmetrics.utilities.compute import normalize_logits_if_needed
 from torchmetrics.utilities.data import _bincount, select_topk
 from torchmetrics.utilities.enums import AverageMethod, ClassificationTask, DataType, MDMCAverageMethod

@@ -105,9 +106,7 @@ def _binary_stat_scores_format(
"""
if preds.is_floating_point():
if not torch.all((preds >= 0) * (preds <= 1)):
# preds is logits, convert with sigmoid
preds = preds.sigmoid()
preds = normalize_logits_if_needed(preds, "sigmoid")
preds = preds > threshold

preds = preds.reshape(preds.shape[0], -1)
@@ -659,8 +658,7 @@ def _multilabel_stat_scores_format(
"""
if preds.is_floating_point():
if not torch.all((preds >= 0) * (preds <= 1)):
preds = preds.sigmoid()
preds = normalize_logits_if_needed(preds, "sigmoid")
preds = preds > threshold
preds = preds.reshape(*preds.shape[:2], -1)
target = target.reshape(*target.shape[:2], -1)
43 changes: 43 additions & 0 deletions src/torchmetrics/utilities/compute.py
@@ -15,6 +15,7 @@

 import torch
 from torch import Tensor
+from typing_extensions import Literal
 
 
 def _safe_matmul(x: Tensor, y: Tensor) -> Tensor:
@@ -184,3 +185,45 @@ def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor:
     indices = torch.clamp(indices, 0, len(m) - 1)
 
     return m[indices] * x + b[indices]
+
+
+def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid", "softmax"]) -> Tensor:
+    """Normalize logits if needed.
+
+    If the input tensor has values outside the [0, 1] range, we assume that logits were provided and apply the
+    normalization. Uses ``torch.where`` to prevent a device-host sync.
+
+    Args:
+        tensor: input tensor that may contain logits or probabilities
+        normalization: normalization method, either ``"sigmoid"`` or ``"softmax"``
+
+    Returns:
+        normalized tensor, if normalization was needed
+
+    Example:
+        >>> import torch
+        >>> tensor = torch.tensor([-1.0, 0.0, 1.0])
+        >>> normalize_logits_if_needed(tensor, normalization="sigmoid")
+        tensor([0.2689, 0.5000, 0.7311])
+        >>> tensor = torch.tensor([[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]])
+        >>> normalize_logits_if_needed(tensor, normalization="softmax")
+        tensor([[0.0900, 0.2447, 0.6652],
+                [0.6652, 0.2447, 0.0900]])
+        >>> tensor = torch.tensor([0.0, 0.5, 1.0])
+        >>> normalize_logits_if_needed(tensor, normalization="sigmoid")
+        tensor([0.0000, 0.5000, 1.0000])
+
+    """
+    # on CPU a data-dependent branch is cheap, so check eagerly and skip the normalization when possible
+    if tensor.device == torch.device("cpu"):
+        if not torch.all((tensor >= 0) * (tensor <= 1)):
+            tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1)
+        return tensor
+
+    # on accelerators, keep the condition on the device via `torch.where` to avoid a device-host sync
+    condition = ((tensor < 0) | (tensor > 1)).any()
+    return torch.where(
+        condition,
+        torch.sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1),
+        tensor,
+    )
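
For readers skimming the diff, a minimal sketch (not part of the change; the tensor values are illustrative) of why the `torch.where` formulation helps: evaluating a tensor condition in a Python `if` calls `bool()` on it, which on an accelerator blocks and copies the result to the host, whereas `torch.where` keeps both the condition and the selection on the device.

import torch

from torchmetrics.utilities.compute import normalize_logits_if_needed

preds = torch.randn(8, 3)  # move to "cuda" to observe the effect; CPU is synchronous anyway

# old pattern: an `if` on a tensor condition triggers bool(...), a blocking device-to-host copy
probs_eager = preds
if not torch.all((preds >= 0) * (preds <= 1)):  # host sync happens here on accelerators
    probs_eager = preds.softmax(1)

# new pattern: the condition and the selection both stay on the device
probs_lazy = normalize_logits_if_needed(preds, "softmax")

assert torch.allclose(probs_eager, probs_lazy)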
13 changes: 7 additions & 6 deletions tests/unittests/classification/test_accuracy.py
@@ -27,6 +27,7 @@
     multilabel_accuracy,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -153,8 +154,8 @@ def test_binary_accuracy_half_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
@@ -310,8 +311,8 @@ def test_multiclass_accuracy_half_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
@@ -585,8 +586,8 @@ def test_multilabel_accuracy_half_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
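
The same guard relaxation repeats across the remaining test files: per the xfail reason string, CPU half-precision sigmoid only works for these metrics from torch 2.1 onward, so the xfail is now limited to older torch. A one-line sketch of what the guard protects against (hedged; the exact error text varies by torch version):

import torch

logits = torch.tensor([-2.0, 0.5], dtype=torch.half)  # half-precision logits on CPU
# per the xfail reason above, this raised on torch < 2.1; on >= 2.1 it returns probabilities
probs = logits.sigmoid()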
5 changes: 3 additions & 2 deletions tests/unittests/classification/test_auroc.py
@@ -24,6 +24,7 @@
 from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc
 from torchmetrics.functional.classification.roc import binary_roc
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
@@ -102,8 +103,8 @@ def test_binary_auroc_dtype_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
5 changes: 3 additions & 2 deletions tests/unittests/classification/test_average_precision.py
@@ -33,6 +33,7 @@
 )
 from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
@@ -106,8 +107,8 @@ def test_binary_average_precision_differentiability(self, inputs):
     def test_binary_average_precision_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
5 changes: 3 additions & 2 deletions tests/unittests/classification/test_calibration_error.py
@@ -29,6 +29,7 @@
     multiclass_calibration_error,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
@@ -112,8 +113,8 @@ def test_binary_calibration_error_differentiability(self, inputs):
     def test_binary_calibration_error_dtype_cpu(self, inputs, dtype):
         """Test dtype support of the metric on CPU."""
         preds, target = inputs
-        if (preds < 0).any() and dtype == torch.half:
-            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
         self.run_precision_test_cpu(
             preds=preds,
             target=target,
9 changes: 5 additions & 4 deletions tests/unittests/classification/test_cohen_kappa.py
@@ -21,6 +21,7 @@
 from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa
 from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -103,8 +104,8 @@ def test_binary_cohen_kappa_dtypes_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
@@ -206,8 +207,8 @@ def test_multiclass_cohen_kappa_dtypes_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
9 changes: 5 additions & 4 deletions tests/unittests/classification/test_confusion_matrix.py
@@ -30,6 +30,7 @@
     multilabel_confusion_matrix,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -114,8 +115,8 @@ def test_binary_confusion_matrix_dtype_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
@@ -367,8 +368,8 @@ def test_multilabel_confusion_matrix_dtype_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
9 changes: 5 additions & 4 deletions tests/unittests/classification/test_exact_match.py
@@ -20,6 +20,7 @@
 from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch
 from torchmetrics.functional.classification.exact_match import multiclass_exact_match, multilabel_exact_match
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 
 from unittests import NUM_CLASSES, THRESHOLD
 from unittests._helpers import seed_all
@@ -121,8 +122,8 @@ def test_multiclass_exact_match_half_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,
@@ -250,8 +251,8 @@ def test_multilabel_exact_match_half_cpu(self, inputs, dtype):
"""Test dtype support of the metric on CPU."""
preds, target = inputs

if (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
self.run_precision_test_cpu(
preds=preds,
target=target,