diff --git a/CHANGELOG.md b/CHANGELOG.md index 444ab2b12e3..62d7ffc4c5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Removed `Device2Host` synchronization caused by communication between device and host ([#2840](https://github.com/PyTorchLightning/metrics/pull/2840)) + + - Fixed issue with shared state in metric collection when using dice score ([#2848](https://github.com/PyTorchLightning/metrics/pull/2848)) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index ca55bb2f79b..22e529588a0 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -23,6 +23,7 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) +from torchmetrics.utilities.compute import normalize_logits_if_needed from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel @@ -239,8 +240,7 @@ def _multiclass_calibration_error_update( preds: Tensor, target: Tensor, ) -> tuple[Tensor, Tensor]: - if not torch.all((preds >= 0) * (preds <= 1)): - preds = preds.softmax(1) + preds = normalize_logits_if_needed(preds, "softmax") confidences, predictions = preds.max(dim=1) accuracies = predictions.eq(target) return confidences.float(), accuracies.float() diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 92059072490..e6443c00d74 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -18,6 +18,7 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.compute import normalize_logits_if_needed from torchmetrics.utilities.data import _bincount from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -137,9 +138,7 @@ def _binary_confusion_matrix_format( target = target[idx] if preds.is_floating_point(): - if not torch.all((preds >= 0) * (preds <= 1)): - # preds is logits, convert with sigmoid - preds = preds.sigmoid() + preds = normalize_logits_if_needed(preds, "sigmoid") if convert_to_labels: preds = preds > threshold @@ -491,8 +490,7 @@ def _multilabel_confusion_matrix_format( """ if preds.is_floating_point(): - if not torch.all((preds >= 0) * (preds <= 1)): - preds = preds.sigmoid() + preds = normalize_logits_if_needed(preds, "sigmoid") if should_threshold: preds = preds > threshold preds = torch.movedim(preds, 1, -1).reshape(-1, num_labels) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 8fe7cf840b8..2e6b8740886 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -23,6 +23,7 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) +from torchmetrics.utilities.compute import normalize_logits_if_needed from torchmetrics.utilities.data import to_onehot from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel @@ -153,9 +154,7 @@ def _multiclass_hinge_loss_update( squared: bool, multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", ) -> tuple[Tensor, Tensor]: - if not torch.all((preds >= 0) * 
(preds <= 1)): - preds = preds.softmax(1) - + preds = normalize_logits_if_needed(preds, "softmax") target = to_onehot(target, max(2, preds.shape[1])).bool() if multiclass_mode == "crammer-singer": margin = preds[target] diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 3c5a840efa1..b00c1975606 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -21,7 +21,7 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.compute import _safe_divide, interp +from torchmetrics.utilities.compute import _safe_divide, interp, normalize_logits_if_needed from torchmetrics.utilities.data import _bincount, _cumsum from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -182,8 +182,7 @@ def _binary_precision_recall_curve_format( preds = preds[idx] target = target[idx] - if not torch.all((preds >= 0) * (preds <= 1)): - preds = preds.sigmoid() + preds = normalize_logits_if_needed(preds, "sigmoid") thresholds = _adjust_threshold_arg(thresholds, preds.device) return preds, target, thresholds @@ -452,8 +451,7 @@ def _multiclass_precision_recall_curve_format( preds = preds[idx] target = target[idx] - if not torch.all((preds >= 0) * (preds <= 1)): - preds = preds.softmax(1) + preds = normalize_logits_if_needed(preds, "softmax") if average == "micro": preds = preds.flatten() @@ -761,8 +759,8 @@ def _multilabel_precision_recall_curve_format( """ preds = preds.transpose(0, 1).reshape(num_labels, -1).T target = target.transpose(0, 1).reshape(num_labels, -1).T - if not torch.all((preds >= 0) * (preds <= 1)): - preds = preds.sigmoid() + + preds = normalize_logits_if_needed(preds, "sigmoid") thresholds = _adjust_threshold_arg(thresholds, preds.device) if ignore_index is not None and thresholds is not None: diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index fbb6098db40..d111b5459bb 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -18,6 +18,7 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification +from torchmetrics.utilities.compute import normalize_logits_if_needed from torchmetrics.utilities.data import _bincount, select_topk from torchmetrics.utilities.enums import AverageMethod, ClassificationTask, DataType, MDMCAverageMethod @@ -105,9 +106,7 @@ def _binary_stat_scores_format( """ if preds.is_floating_point(): - if not torch.all((preds >= 0) * (preds <= 1)): - # preds is logits, convert with sigmoid - preds = preds.sigmoid() + preds = normalize_logits_if_needed(preds, "sigmoid") preds = preds > threshold preds = preds.reshape(preds.shape[0], -1) @@ -659,8 +658,7 @@ def _multilabel_stat_scores_format( """ if preds.is_floating_point(): - if not torch.all((preds >= 0) * (preds <= 1)): - preds = preds.sigmoid() + preds = normalize_logits_if_needed(preds, "sigmoid") preds = preds > threshold preds = preds.reshape(*preds.shape[:2], -1) target = target.reshape(*target.shape[:2], -1) diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index cbb648a8844..5a46993d86b 100644 --- 
a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal def _safe_matmul(x: Tensor, y: Tensor) -> Tensor: @@ -184,3 +185,45 @@ def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: indices = torch.clamp(indices, 0, len(m) - 1) return m[indices] * x + b[indices] + + +def normalize_logits_if_needed(tensor: Tensor, normalization: Literal["sigmoid", "softmax"]) -> Tensor: + """Normalize logits if needed. + + If the input tensor has values outside the [0, 1] interval, we assume that logits are provided and apply the normalization. + On non-CPU devices, torch.where is used to prevent a device-host sync. + + Args: + tensor: input tensor that may be logits or probabilities + normalization: normalization method, either 'sigmoid' or 'softmax' + + Returns: + normalized tensor if the input was outside [0, 1], otherwise the unchanged input tensor + + Example: + >>> import torch + >>> tensor = torch.tensor([-1.0, 0.0, 1.0]) + >>> normalize_logits_if_needed(tensor, normalization="sigmoid") + tensor([0.2689, 0.5000, 0.7311]) + >>> tensor = torch.tensor([[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]]) + >>> normalize_logits_if_needed(tensor, normalization="softmax") + tensor([[0.0900, 0.2447, 0.6652], + [0.6652, 0.2447, 0.0900]]) + >>> tensor = torch.tensor([0.0, 0.5, 1.0]) + >>> normalize_logits_if_needed(tensor, normalization="sigmoid") + tensor([0.0000, 0.5000, 1.0000]) + + """ + # on cpu the bounds check triggers no device-host sync, so only normalize when actually needed + if tensor.device == torch.device("cpu"): + if not torch.all((tensor >= 0) * (tensor <= 1)): + tensor = tensor.sigmoid() if normalization == "sigmoid" else torch.softmax(tensor, dim=1) + return tensor + + # on accelerators, branch with torch.where instead of python control flow to avoid a device-host sync + condition = ((tensor < 0) | (tensor > 1)).any() + return torch.where( + condition, + torch.sigmoid(tensor) if normalization == "sigmoid" else torch.softmax(tensor, dim=1), + tensor, + ) diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 65e42c00b07..59b466d33c1 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -27,6 +27,7 @@ multilabel_accuracy, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -153,8 +154,8 @@ def test_binary_accuracy_half_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -310,8 +311,8 @@ def test_multiclass_accuracy_half_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -585,8 +586,8 @@ def test_multilabel_accuracy_half_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds 
< 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 30d4acb470c..c7fdb54d6c1 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -24,6 +24,7 @@ from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc from torchmetrics.functional.classification.roc import binary_roc from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -102,8 +103,8 @@ def test_binary_auroc_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index da0dc2f56b6..cf37360e832 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -33,6 +33,7 @@ ) from torchmetrics.functional.classification.precision_recall_curve import binary_precision_recall_curve from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -106,8 +107,8 @@ def test_binary_average_precision_differentiability(self, inputs): def test_binary_average_precision_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index b8b6bfc1646..8e2556c0533 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -29,6 +29,7 @@ multiclass_calibration_error, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -112,8 +113,8 @@ def test_binary_calibration_error_differentiability(self, inputs): def test_binary_calibration_error_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support 
cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index 1f2585372bd..4c4a411aab7 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -21,6 +21,7 @@ from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -103,8 +104,8 @@ def test_binary_cohen_kappa_dtypes_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -206,8 +207,8 @@ def test_multiclass_cohen_kappa_dtypes_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 4d27dfc2069..23777dbdc2a 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -30,6 +30,7 @@ multilabel_confusion_matrix, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -114,8 +115,8 @@ def test_binary_confusion_matrix_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -367,8 +368,8 @@ def test_multilabel_confusion_matrix_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") 
self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index 5afd4c00e40..3cb8caa2061 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -20,6 +20,7 @@ from torchmetrics.classification.exact_match import ExactMatch, MulticlassExactMatch, MultilabelExactMatch from torchmetrics.functional.classification.exact_match import multiclass_exact_match, multilabel_exact_match from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -121,8 +122,8 @@ def test_multiclass_exact_match_half_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -250,8 +251,8 @@ def test_multilabel_exact_match_half_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 075e37cc699..3c3e429f232 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -40,6 +40,7 @@ multilabel_fbeta_score, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -171,8 +172,8 @@ def test_binary_fbeta_score_half_cpu(self, inputs, module, functional, compare, """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -349,8 +350,8 @@ def test_multiclass_fbeta_score_half_cpu(self, inputs, module, functional, compa """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -608,8 +609,8 @@ def test_multilabel_fbeta_score_half_cpu(self, inputs, module, functional, compa """Test dtype support of the metric on CPU.""" preds, target = inputs 
- if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index 5f627676f09..d1899831d7e 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -26,6 +26,7 @@ from torchmetrics import Metric from torchmetrics.classification.group_fairness import BinaryFairness from torchmetrics.functional.classification.group_fairness import binary_fairness +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import THRESHOLD from unittests._helpers import seed_all @@ -282,8 +283,8 @@ def test_binary_fairness_half_cpu(self, inputs, dtype): """Test class implementation of metric.""" preds, target, groups = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index a7a42db61b0..6d4f0f824cc 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -31,6 +31,7 @@ multilabel_hamming_distance, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -138,8 +139,8 @@ def test_binary_hamming_distance_differentiability(self, inputs): def test_binary_hamming_distance_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -301,8 +302,8 @@ def test_multiclass_hamming_distance_differentiability(self, inputs): def test_multiclass_hamming_distance_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -485,8 +486,8 @@ def test_multilabel_hamming_distance_differentiability(self, inputs): def test_multilabel_hamming_distance_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - 
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 0a20a2e458a..606825f7e71 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -32,6 +32,7 @@ multilabel_jaccard_index, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index @@ -108,8 +109,8 @@ def test_binary_jaccard_index_differentiability(self, inputs): def test_binary_jaccard_index_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -355,8 +356,8 @@ def test_multilabel_jaccard_index_differentiability(self, inputs): def test_multilabel_jaccard_index_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_logauc.py b/tests/unittests/classification/test_logauc.py index 26cb395f45e..6494ac72372 100644 --- a/tests/unittests/classification/test_logauc.py +++ b/tests/unittests/classification/test_logauc.py @@ -27,6 +27,7 @@ from torchmetrics.functional.classification.logauc import binary_logauc, multiclass_logauc, multilabel_logauc from torchmetrics.functional.classification.roc import binary_roc from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -105,8 +106,8 @@ def test_binary_logauc_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index b340db8d713..fc4d762384b 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -30,6 +30,7 @@ multilabel_matthews_corrcoef, ) from 
torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -105,8 +106,8 @@ def test_binary_matthews_corrcoef_differentiability(self, inputs): def test_binary_matthews_corrcoef_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -287,8 +288,8 @@ def test_multilabel_matthews_corrcoef_differentiability(self, inputs): def test_multilabel_matthews_corrcoef_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_negative_predictive_value.py b/tests/unittests/classification/test_negative_predictive_value.py index 464884ca82a..2fb352bc74f 100644 --- a/tests/unittests/classification/test_negative_predictive_value.py +++ b/tests/unittests/classification/test_negative_predictive_value.py @@ -31,6 +31,7 @@ multilabel_negative_predictive_value, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -155,8 +156,8 @@ def test_binary_negative_predictive_value_differentiability(self, inputs): def test_binary_negative_predictive_value_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -332,8 +333,8 @@ def test_multiclass_negative_predictive_value_differentiability(self, inputs): def test_multiclass_negative_predictive_value_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -538,8 +539,8 @@ def test_multilabel_negative_predictive_value_differentiability(self, inputs): def test_multilabel_negative_predictive_value_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not 
support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index b6649ad869d..03c8ee7654f 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py @@ -32,6 +32,7 @@ multilabel_precision_at_fixed_recall, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -125,8 +126,8 @@ def test_binary_precision_at_fixed_recall_differentiability(self, inputs): def test_binary_precision_at_fixed_recall_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 7717ffa5b0d..56d87ebf073 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -40,6 +40,7 @@ multilabel_recall, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -172,8 +173,8 @@ def test_binary_precision_recall_differentiability(self, inputs, module, functio def test_binary_precision_recall_half_cpu(self, inputs, module, functional, compare, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -355,8 +356,8 @@ def test_multiclass_precision_recall_differentiability(self, inputs, module, fun def test_multiclass_precision_recall_half_cpu(self, inputs, module, functional, compare, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -612,8 +613,8 @@ def test_multilabel_precision_recall_differentiability(self, inputs, module, fun def test_multilabel_precision_recall_half_cpu(self, inputs, module, functional, compare, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - 
pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 7c034c528e6..d6a79b9b5cb 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -32,6 +32,7 @@ multilabel_precision_recall_curve, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -105,8 +106,8 @@ def test_binary_precision_recall_curve_differentiability(self, inputs): def test_binary_precision_recall_curve_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index e85d38cde05..4727ab882a1 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -30,6 +30,7 @@ multilabel_ranking_average_precision, multilabel_ranking_loss, ) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -117,8 +118,8 @@ def test_multilabel_ranking_differentiability(self, inputs, metric, functional_m def test_multilabel_ranking_dtype_cpu(self, inputs, metric, functional_metric, ref_metric, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") if dtype == torch.half and functional_metric == multilabel_ranking_average_precision: pytest.xfail( reason="multilabel_ranking_average_precision requires torch.unique which is not implemented for half" diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index 2d73d64f264..5bbf2e55e58 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -32,6 +32,7 @@ multilabel_recall_at_fixed_precision, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -129,8 +130,8 @@ def test_binary_recall_at_fixed_precision_differentiability(self, inputs): def test_binary_recall_at_fixed_precision_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 
0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index f6cbd173128..5ad6dee35fa 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -23,6 +23,7 @@ from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -98,8 +99,8 @@ def test_binary_roc_differentiability(self, inputs): def test_binary_roc_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 47884bae2a3..cc85f6e4e28 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -33,7 +33,7 @@ multilabel_sensitivity_at_specificity, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3 +from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3, _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -153,8 +153,8 @@ def test_binary_sensitivity_at_specificity_differentiability(self, inputs): def test_binary_sensitivity_at_specificity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index fe5fd8977a8..437d9e07af9 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -31,6 +31,7 @@ multilabel_specificity, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -151,8 +152,8 @@ def test_binary_specificity_differentiability(self, inputs): def test_binary_specificity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 
0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -328,8 +329,8 @@ def test_multiclass_specificity_differentiability(self, inputs): def test_multiclass_specificity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -532,8 +533,8 @@ def test_multilabel_specificity_differentiability(self, inputs): def test_multilabel_specificity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index 934d669678a..9e866dbabd9 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -33,6 +33,7 @@ multilabel_specificity_at_sensitivity, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all @@ -148,8 +149,8 @@ def test_binary_specificity_at_sensitivity_differentiability(self, inputs): def test_binary_specificity_at_sensitivity_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 2a5a53bb8aa..fee079011be 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -31,6 +31,7 @@ multilabel_stat_scores, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all @@ -135,8 +136,8 @@ def test_binary_stat_scores_differentiability(self, inputs): def test_binary_stat_scores_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + 
half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -300,8 +301,8 @@ def test_multiclass_stat_scores_differentiability(self, inputs): def test_multiclass_stat_scores_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target, @@ -552,8 +553,8 @@ def test_multilabel_stat_scores_differentiability(self, inputs): def test_multilabel_stat_scores_dtype_cpu(self, inputs, dtype): """Test dtype support of the metric on CPU.""" preds, target = inputs - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") self.run_precision_test_cpu( preds=preds, target=target,