Skip to content

Commit

Permalink
skip for old sklearn versions
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed May 2, 2024
1 parent b82bc0e commit e26e748
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,6 @@
_MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic")
_IPADIC_AVAILABLE = RequirementCache("ipadic")
_SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece")
_SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
multilabel_sensitivity_at_specificity,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from torchmetrics.utilities.imports import _SKLEARN_GREATER_EQUAL_1_3, _TORCH_GREATER_EQUAL_1_11

from unittests import NUM_CLASSES
from unittests._helpers import seed_all
Expand Down Expand Up @@ -83,6 +83,7 @@ def _reference_sklearn_sensitivity_at_specificity_binary(preds, target, min_spec
return _sensitivity_at_specificity_x_multilabel(preds, target, min_specificity)


@pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3")
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11")
@pytest.mark.parametrize("inputs", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5]))
class TestBinarySensitivityAtSpecificity(MetricTester):
Expand Down Expand Up @@ -209,6 +210,7 @@ def _reference_sklearn_sensitivity_at_specificity_multiclass(preds, target, min_
return sensitivity, thresholds


@pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3")
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11")
@pytest.mark.parametrize(
"inputs", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5])
Expand Down Expand Up @@ -340,6 +342,7 @@ def _reference_sklearn_sensitivity_at_specificity_multilabel(preds, target, min_
return sensitivity, thresholds


@pytest.mark.skipif(not _SKLEARN_GREATER_EQUAL_1_3, reason="metric does not support scikit-learn versions below 1.3")
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_11, reason="metric does not support torch versions below 1.11")
@pytest.mark.parametrize(
"inputs", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5])
Expand Down

0 comments on commit e26e748

Please sign in to comment.