diff --git a/.github/actions/pull-caches/action.yml b/.github/actions/pull-caches/action.yml
index a5cf7cafe2d..fec93a58bde 100644
--- a/.github/actions/pull-caches/action.yml
+++ b/.github/actions/pull-caches/action.yml
@@ -90,5 +90,5 @@ runs:

   - name: Restored References
     continue-on-error: true
-    run: ls -lh tests/_cache-references/
+    run: py-tree tests/_cache-references/ --show_hidden
    shell: bash
diff --git a/.github/actions/push-caches/action.yml b/.github/actions/push-caches/action.yml
index da757f09b5d..8f5db36b6dd 100644
--- a/.github/actions/push-caches/action.yml
+++ b/.github/actions/push-caches/action.yml
@@ -99,5 +99,5 @@ runs:
       key: cache-references

   - name: Post References
-    run: ls -lh tests/_cache-references/
+    run: py-tree tests/_cache-references/ --show_hidden
    shell: bash
diff --git a/.gitignore b/.gitignore
index 6f45b493e3c..cbe31a9b316 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,7 +40,7 @@ pip-delete-this-directory.txt
 # Unit test / coverage reports
 tests/_data/
 data.zip
-tests/_reference-cache/
+tests/_cache-references/
 htmlcov/
 .coverage
 .coverage.*
diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py
index 085269dd4d9..6e80411f5c1 100644
--- a/src/torchmetrics/utilities/imports.py
+++ b/src/torchmetrics/utilities/imports.py
@@ -17,10 +17,8 @@
 import sys

 from lightning_utilities.core.imports import RequirementCache
-from packaging.version import Version, parse

 _PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
-_PYTHON_LOWER_3_8 = parse(_PYTHON_VERSION) < Version("3.8")
 _TORCH_LOWER_2_0 = RequirementCache("torch<2.0.0")
 _TORCH_GREATER_EQUAL_1_11 = RequirementCache("torch>=1.11.0")
 _TORCH_GREATER_EQUAL_1_12 = RequirementCache("torch>=1.12.0")
@@ -29,7 +27,6 @@
 _TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0")
 _TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0")

-_JIWER_AVAILABLE = RequirementCache("jiwer")
 _NLTK_AVAILABLE = RequirementCache("nltk")
 _ROUGE_SCORE_AVAILABLE = RequirementCache("rouge_score")
 _BERTSCORE_AVAILABLE = RequirementCache("bert_score")
@@ -49,7 +46,6 @@
 _GAMMATONE_AVAILABLE = RequirementCache("gammatone")
 _TORCHAUDIO_AVAILABLE = RequirementCache("torchaudio")
 _TORCHAUDIO_GREATER_EQUAL_0_10 = RequirementCache("torchaudio>=0.10.0")
-_SACREBLEU_AVAILABLE = RequirementCache("sacrebleu")
 _REGEX_AVAILABLE = RequirementCache("regex")
 _PYSTOI_AVAILABLE = RequirementCache("pystoi")
 _FAST_BSS_EVAL_AVAILABLE = RequirementCache("fast_bss_eval")
diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py
index 811e2a55ab9..4d76b9301dd 100644
--- a/tests/unittests/classification/test_group_fairness.py
+++ b/tests/unittests/classification/test_group_fairness.py
@@ -26,7 +26,6 @@
 from torchmetrics import Metric
 from torchmetrics.classification.group_fairness import BinaryFairness
 from torchmetrics.functional.classification.group_fairness import binary_fairness
-from torchmetrics.utilities.imports import _PYTHON_LOWER_3_8

 from unittests import THRESHOLD
 from unittests.classification._inputs import _group_cases
@@ -222,7 +221,6 @@ def run_precision_test_gpu(

 @mock.patch("unittests.helpers.testers._assert_tensor", _assert_tensor)
 @mock.patch("unittests.helpers.testers._assert_allclose", _assert_allclose)
-@pytest.mark.skipif(_PYTHON_LOWER_3_8, reason="`TestBinaryFairness` requires `python>=3.8`.")
 @pytest.mark.parametrize("inputs", _group_cases)
 class TestBinaryFairness(BinaryFairnessTester):
     """Test class for `BinaryFairness` metric."""
diff --git a/tests/unittests/image/test_fid.py b/tests/unittests/image/test_fid.py
index 83b243200da..252f0d0ebba 100644
--- a/tests/unittests/image/test_fid.py
+++ b/tests/unittests/image/test_fid.py
@@ -34,7 +34,7 @@ def test_no_train_network_missing_torch_fidelity():
         NoTrainInceptionV3(name="inception-v3-compat", features_list=["2048"])


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""

@@ -52,7 +52,7 @@ def forward(self, x):
     assert not model.metric.inception.training, "FID metric was changed to training mode which should not happen"


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_fid_pickle():
     """Assert that we can initialize the metric and pickle it."""
     metric = FrechetInceptionDistance()
@@ -80,7 +80,7 @@ def test_fid_raises_errors_and_warnings():
         _ = FrechetInceptionDistance(feature=[1, 2])


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.parametrize("feature", [64, 192, 768, 2048])
 def test_fid_same_input(feature):
     """If real and fake are update on the same data the fid score should be 0."""
@@ -111,7 +111,7 @@ def __len__(self) -> int:


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.parametrize("equal_size", [False, True])
 def test_compare_fid(tmpdir, equal_size, feature=768):
     """Check that the hole pipeline give the same result as torch-fidelity."""
diff --git a/tests/unittests/image/test_inception.py b/tests/unittests/image/test_inception.py
index 552180cbbcc..627e6a4a57a 100644
--- a/tests/unittests/image/test_inception.py
+++ b/tests/unittests/image/test_inception.py
@@ -24,7 +24,7 @@
 torch.manual_seed(42)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""

@@ -44,7 +44,7 @@ def forward(self, x):
     ), "InceptionScore metric was changed to training mode which should not happen"


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_is_pickle():
     """Assert that we can initialize the metric and pickle it."""
     metric = InceptionScore()
@@ -79,7 +79,7 @@ def test_is_raises_errors_and_warnings():
         InceptionScore(feature=[1, 2])


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_is_update_compute():
     """Test that inception score works as expected."""
     metric = InceptionScore()
@@ -105,7 +105,7 @@ def __len__(self) -> int:


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.parametrize("compute_on_cpu", [True, False])
 def test_compare_is(tmpdir, compute_on_cpu):
     """Check that the hole pipeline give the same result as torch-fidelity."""
diff --git a/tests/unittests/image/test_kid.py b/tests/unittests/image/test_kid.py
index 34d223e24af..a754768003c 100644
--- a/tests/unittests/image/test_kid.py
+++ b/tests/unittests/image/test_kid.py
@@ -24,7 +24,7 @@
 torch.manual_seed(42)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""

@@ -42,7 +42,7 @@ def forward(self, x):
     assert not model.metric.inception.training, "FID metric was changed to training mode which should not happen"


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_kid_pickle():
     """Assert that we can initialize the metric and pickle it."""
     metric = KernelInceptionDistance()
@@ -83,7 +83,7 @@ def test_kid_raises_errors_and_warnings():
         m.compute()


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_kid_extra_parameters():
     """Test that the different input arguments raises expected errors if wrong."""
     with pytest.raises(ValueError, match="Argument `subsets` expected to be integer larger than 0"):
@@ -102,7 +102,7 @@ def test_kid_extra_parameters():
         KernelInceptionDistance(coef=-1)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.parametrize("feature", [64, 192, 768, 2048])
 def test_kid_same_input(feature):
     """Test that the metric works."""
@@ -132,7 +132,7 @@ def __len__(self) -> int:


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_compare_kid(tmpdir, feature=2048):
     """Check that the hole pipeline give the same result as torch-fidelity."""
     from torch_fidelity import calculate_metrics
diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py
index 0a7171ab996..026c2b91770 100644
--- a/tests/unittests/image/test_lpips.py
+++ b/tests/unittests/image/test_lpips.py
@@ -16,11 +16,10 @@
 import pytest
 import torch
-from lpips import LPIPS as LPIPS_reference  # noqa: N811
 from torch import Tensor

 from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
 from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
-from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCHVISION_AVAILABLE
+from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE

 from unittests.helpers import seed_all
 from unittests.helpers.testers import MetricTester
@@ -43,7 +42,12 @@ def _reference_lpips(
     img1: Tensor, img2: Tensor, net_type: str, normalize: bool = False, reduction: str = "mean"
 ) -> Tensor:
     """Comparison function for tm implementation."""
-    ref = LPIPS_reference(net=net_type)
+    try:
+        from lpips import LPIPS
+    except ImportError:
+        pytest.skip("test requires lpips package to be installed")
+
+    ref = LPIPS(net=net_type)
     res = ref(img1, img2, normalize=normalize).detach().cpu().numpy()
     if reduction == "mean":
         return res.mean()
@@ -51,7 +55,6 @@ def _reference_lpips(


 @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
-@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
 class TestLPIPS(MetricTester):
     """Test class for `LearnedPerceptualImagePatchSimilarity` metric."""

@@ -109,7 +112,6 @@ def test_normalize_arg(normalize):


 @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
-@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
 def test_error_on_wrong_init():
     """Test class raises the expected errors."""
     with pytest.raises(ValueError, match="Argument `net_type` must be one .*"):
@@ -120,7 +122,6 @@ def test_error_on_wrong_init():


 @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
-@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
 @pytest.mark.parametrize(
     ("inp1", "inp2"),
     [
diff --git a/tests/unittests/image/test_mifid.py b/tests/unittests/image/test_mifid.py
index d5bdb95cf68..ae44982b350 100644
--- a/tests/unittests/image/test_mifid.py
+++ b/tests/unittests/image/test_mifid.py
@@ -98,7 +98,7 @@ def calculate_mifid(m1, s1, features1, m2, s2, features2):
     return fid_private / (distance_private_thresholded + 1e-15)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 def test_no_train():
     """Assert that metric never leaves evaluation mode."""

@@ -139,7 +139,7 @@ def test_mifid_raises_errors_and_warnings():
         _ = MemorizationInformedFrechetInceptionDistance(cosine_distance_eps=1.1)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.parametrize("feature", [64, 192, 768, 2048])
 def test_fid_same_input(feature):
     """If real and fake are update on the same data the fid score should be 0."""
@@ -157,7 +157,7 @@ def test_fid_same_input(feature):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.parametrize("equal_size", [False, True])
 def test_compare_mifid(equal_size):
     """Check that our implementation of MIFID is correct by comparing it to the original implementation."""
diff --git a/tests/unittests/image/test_perceptual_path_length.py b/tests/unittests/image/test_perceptual_path_length.py
index dfdd5cfde96..1eb486c6ce3 100644
--- a/tests/unittests/image/test_perceptual_path_length.py
+++ b/tests/unittests/image/test_perceptual_path_length.py
@@ -29,7 +29,7 @@
 seed_all(42)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.parametrize("interpolation_method", ["lerp", "slerp_any", "slerp_unit"])
 def test_interpolation_methods(interpolation_method):
     """Test that interpolation method works as expected."""
@@ -41,7 +41,7 @@ def test_interpolation_methods(interpolation_method):
     assert torch.allclose(res1, res2)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @skip_on_running_out_of_memory()
 def test_sim_net():
     """Check that the similarity network is the same as the one used in torch_fidelity."""
@@ -100,7 +100,7 @@ def sample(self, num_samples):
         return torch.randn(num_samples, self.z_size)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.parametrize(
     ("argument", "match"),
     [
@@ -174,7 +174,7 @@ def test_raises_error_on_wrong_generator(generator, errortype, match):
         ppl.update(generator=generator)


-@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity")
+@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 @skip_on_running_out_of_memory()
 def test_compare():
diff --git a/tests/unittests/multimodal/test_clip_iqa.py b/tests/unittests/multimodal/test_clip_iqa.py
index c7057226759..314ff0013b8 100644
--- a/tests/unittests/multimodal/test_clip_iqa.py
+++ b/tests/unittests/multimodal/test_clip_iqa.py
@@ -71,7 +71,7 @@ def _reference_clip_iqa(preds, target, reduce=False):
     return res.sum() if reduce else res


-@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
+@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
 class TestCLIPIQA(MetricTester):
     """Test clip iqa metric."""
@@ -104,7 +104,7 @@ def test_clip_iqa_functional(self, shapes):


 @skip_on_connection_issues()
-@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
+@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
 @pytest.mark.skipif(not os.path.isfile(_SAMPLE_IMAGE), reason="test image not found")
 def test_for_correctness_sample_images():
@@ -121,7 +121,7 @@ def test_for_correctness_sample_images():

 @skip_on_connection_issues()
-@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
+@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
 @pytest.mark.parametrize(
     "model",
     [
@@ -148,7 +148,7 @@ def test_other_models(model):

 @skip_on_connection_issues()
-@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
+@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
 @pytest.mark.parametrize(
     "prompts",
     [
@@ -200,7 +200,7 @@ def test_prompt(prompts):


 @skip_on_connection_issues()
-@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
+@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
 def test_plot_method():
     """Test the plot method of CLIPScore separately in this file due to the skipping conditions."""
diff --git a/tests/unittests/nominal/test_cramers.py b/tests/unittests/nominal/test_cramers.py
index 4cebac73e05..42b735ef510 100644
--- a/tests/unittests/nominal/test_cramers.py
+++ b/tests/unittests/nominal/test_cramers.py
@@ -12,13 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import operator
 from functools import partial

 import pytest
 import torch
-from dython.nominal import cramers_v as dython_cramers_v
-from lightning_utilities.core.imports import compare_version

 from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
 from torchmetrics.nominal.cramers import CramersV
@@ -63,10 +60,15 @@ def cramers_matrix_input():


 def _reference_dython_cramers_v(preds, target, bias_correction, nan_strategy, nan_replace_value):
+    try:
+        from dython.nominal import cramers_v
+    except ImportError:
+        pytest.skip("This test requires `dython` package to be installed.")
+
     preds = preds.argmax(1) if preds.ndim == 2 else preds
     target = target.argmax(1) if target.ndim == 2 else target

-    v = dython_cramers_v(
+    v = cramers_v(
         preds.numpy(),
         target.numpy(),
         bias_correction=bias_correction,
@@ -87,7 +89,6 @@ def _dython_cramers_v_matrix(matrix, bias_correction, nan_strategy, nan_replace_
     return cramers_v_matrix_value


-@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
 @pytest.mark.parametrize(
     "preds, target",
     [
@@ -161,7 +162,6 @@ def test_cramers_v_differentiability(self, preds, target, bias_correction, nan_s
         )


-@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
 @pytest.mark.parametrize("bias_correction", [False, True])
 @pytest.mark.parametrize(("nan_strategy", "nan_replace_value"), [("replace", 1.0), ("drop", None)])
 def test_cramers_v_matrix(cramers_matrix_input, bias_correction, nan_strategy, nan_replace_value):
diff --git a/tests/unittests/nominal/test_pearson.py b/tests/unittests/nominal/test_pearson.py
index 44bf1c0e415..5bec1cd8121 100644
--- a/tests/unittests/nominal/test_pearson.py
+++ b/tests/unittests/nominal/test_pearson.py
@@ -12,13 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import operator

 import pandas as pd
 import pytest
 import torch
-from lightning_utilities.core.imports import compare_version
-from scipy.stats.contingency import association

 from torchmetrics.functional.nominal.pearson import (
     pearsons_contingency_coefficient,
     pearsons_contingency_coefficient_matrix,
 )
@@ -56,6 +53,10 @@ def pearson_matrix_input():


 def _reference_pd_pearsons_t(preds, target):
+    try:
+        from scipy.stats.contingency import association
+    except ImportError:
+        pytest.skip("test requires scipy package to be installed")
     preds = preds.argmax(1) if preds.ndim == 2 else preds
     target = target.argmax(1) if target.ndim == 2 else target
     preds, target = preds.numpy().astype(int), target.numpy().astype(int)
@@ -74,7 +75,6 @@ def _reference_pd_pearsons_t_matrix(matrix):
     return pearsons_t_matrix_value


-@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
 @pytest.mark.parametrize(
     "preds, target",
     [
@@ -118,7 +118,6 @@ def test_pearsons_t_differentiability(self, preds, target):
         )


-@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
 def test_pearsons_contingency_coefficient_matrix(pearson_matrix_input):
     """Test matrix version of metric works as expected."""
     tm_score = pearsons_contingency_coefficient_matrix(pearson_matrix_input)
diff --git a/tests/unittests/nominal/test_theils_u.py b/tests/unittests/nominal/test_theils_u.py
index b7ae4b29507..c06c6b9bcd2 100644
--- a/tests/unittests/nominal/test_theils_u.py
+++ b/tests/unittests/nominal/test_theils_u.py
@@ -12,13 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import operator
 from functools import partial

 import pytest
 import torch
-from dython.nominal import theils_u as dython_theils_u
-from lightning_utilities.core.imports import compare_version

 from torchmetrics.functional.nominal.theils_u import theils_u, theils_u_matrix
 from torchmetrics.nominal import TheilsU
@@ -63,6 +60,11 @@ def theils_u_matrix_input():


 def _reference_dython_theils_u(preds, target, nan_strategy, nan_replace_value):
+    try:
+        from dython.nominal import theils_u as dython_theils_u
+    except ImportError:
+        pytest.skip("Test requires `dython` package to be installed.")
+
     preds = preds.argmax(1) if preds.ndim == 2 else preds
     target = target.argmax(1) if target.ndim == 2 else target

@@ -85,7 +87,6 @@ def _reference_dython_theils_u_matrix(matrix, nan_strategy, nan_replace_value):
     return theils_u_matrix_value


-@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
 @pytest.mark.parametrize(
     "preds, target",
     [
@@ -153,7 +154,6 @@ def test_theils_u_differentiability(self, preds, target, nan_strategy, nan_repla
         )


-@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
 @pytest.mark.parametrize(("nan_strategy", "nan_replace_value"), [("replace", 1.0), ("drop", None)])
 def test_theils_u_matrix(theils_u_matrix_input, nan_strategy, nan_replace_value):
     """Test matrix version of metric works as expected."""
diff --git a/tests/unittests/nominal/test_tschuprows.py b/tests/unittests/nominal/test_tschuprows.py
index 48102ac6f34..91798d88d82 100644
--- a/tests/unittests/nominal/test_tschuprows.py
+++ b/tests/unittests/nominal/test_tschuprows.py
@@ -12,13 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import operator

 import pandas as pd
 import pytest
 import torch
-from lightning_utilities.core.imports import compare_version
-from scipy.stats.contingency import association

 from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix
 from torchmetrics.nominal.tschuprows import TschuprowsT
@@ -53,6 +50,10 @@ def tschuprows_matrix_input():


 def _reference_pd_tschuprows_t(preds, target):
+    try:
+        from scipy.stats.contingency import association
+    except ImportError:
+        pytest.skip("test requires scipy package to be installed")
     preds = preds.argmax(1) if preds.ndim == 2 else preds
     target = target.argmax(1) if target.ndim == 2 else target
     preds, target = preds.numpy().astype(int), target.numpy().astype(int)
@@ -71,7 +72,6 @@ def _reference_pd_tschuprows_t_matrix(matrix):
     return tschuprows_t_matrix_value


-@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
 @pytest.mark.parametrize(
     "preds, target",
     [
@@ -120,7 +120,6 @@ def test_tschuprows_t_differentiability(self, preds, target):
         )


-@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
 def test_tschuprows_t_matrix(tschuprows_matrix_input):
     """Test matrix version of metric works as expected."""
     tm_score = tschuprows_t_matrix(tschuprows_matrix_input, bias_correction=False)
diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py
index 3b2382a8488..b651ecfddb1 100644
--- a/tests/unittests/text/test_bertscore.py
+++ b/tests/unittests/text/test_bertscore.py
@@ -19,18 +19,13 @@
 from torch import Tensor
 from torchmetrics.functional.text.bert import bert_score
 from torchmetrics.text.bert import BERTScore
-from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE, _TRANSFORMERS_GREATER_EQUAL_4_4
+from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4
 from typing_extensions import Literal

 from unittests.helpers import skip_on_connection_issues
 from unittests.text._inputs import _inputs_single_reference
 from unittests.text.helpers import TextTester

-if _BERTSCORE_AVAILABLE:
-    from bert_score import score as original_bert_score
-else:
-    original_bert_score = None
-
 _METRIC_KEY_TO_IDX = {
     "precision": 0,
     "recall": 1,
@@ -45,7 +40,6 @@

 @skip_on_connection_issues()
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
-@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
 def _reference_bert_score(
     preds: Sequence[str],
     target: Sequence[str],
@@ -55,6 +49,11 @@ def _reference_bert_score(
     rescale_with_baseline: bool,
     metric_key: Literal["f1", "precision", "recall"],
 ) -> Tensor:
+    try:
+        from bert_score import score as original_bert_score
+    except ImportError:
+        pytest.skip("test requires bert_score package to be installed.")
+
     score_tuple = original_bert_score(
         preds,
         target,
@@ -88,7 +87,6 @@
     [(_inputs_single_reference.preds, _inputs_single_reference.target)],
 )
 @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
-@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
 class TestBERTScore(TextTester):
     """Tests for BERTScore."""

diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py
index 34c7a9735a9..ab09f3e5334 100644
--- a/tests/unittests/text/test_cer.py
+++ b/tests/unittests/text/test_cer.py
@@ -11,28 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, List, Union
+from typing import List, Union

 import pytest

 from torchmetrics.functional.text.cer import char_error_rate
 from torchmetrics.text.cer import CharErrorRate
-from torchmetrics.utilities.imports import _JIWER_AVAILABLE

 from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
 from unittests.text.helpers import TextTester

-if _JIWER_AVAILABLE:
-    from jiwer import cer
-
-else:
-    compute_measures = Callable
-

 def _reference_jiwer_cer(preds: Union[str, List[str]], target: Union[str, List[str]]):
+    try:
+        from jiwer import cer
+    except ImportError:
+        pytest.skip("test requires jiwer package to be installed.")
+
     return cer(target, preds)


-@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 @pytest.mark.parametrize(
     ["preds", "targets"],
     [
diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py
index 4df8e5d8c22..6dd328d2c77 100644
--- a/tests/unittests/text/test_chrf.py
+++ b/tests/unittests/text/test_chrf.py
@@ -18,14 +18,10 @@
 from torch import Tensor, tensor
 from torchmetrics.functional.text.chrf import chrf_score
 from torchmetrics.text.chrf import CHRFScore
-from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE

 from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references
 from unittests.text.helpers import TextTester

-if _SACREBLEU_AVAILABLE:
-    from sacrebleu.metrics import CHRF
-

 def _reference_sacrebleu_chrf(
     preds: Sequence[str],
@@ -35,6 +31,11 @@ def _reference_sacrebleu_chrf(
     lowercase: bool,
     whitespace: bool,
 ) -> Tensor:
+    try:
+        from sacrebleu import CHRF
+    except ImportError:
+        pytest.skip("test requires sacrebleu package to be installed")
+
     sacrebleu_chrf = CHRF(
         char_order=char_order, word_order=word_order, lowercase=lowercase, whitespace=whitespace, eps_smoothing=True
     )
@@ -59,7 +60,6 @@
     ["preds", "targets"],
     [(_inputs_multiple_references.preds, _inputs_multiple_references.target)],
 )
-@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu")
 class TestCHRFScore(TextTester):
     """Test class for `CHRFScore` metric."""

diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py
index 1ddd6d9bcfe..d8611695ff3 100644
--- a/tests/unittests/text/test_infolm.py
+++ b/tests/unittests/text/test_infolm.py
@@ -36,10 +36,10 @@ def _reference_infolm_score(preds, target, model_name, information_measure, idf,

     https://github.com/stancld/infolm-docker.
     """
-    if model_name != "google/bert_uncased_L-2_H-128_A-2":
+    allowed_model = "google/bert_uncased_L-2_H-128_A-2"
+    if model_name != allowed_model:
         raise ValueError(
-            "`model_name` is expected to be 'google/bert_uncased_L-2_H-128_A-2' as this model was used for the result "
-            "generation."
+            f"`model_name` is expected to be '{allowed_model}' as this model was used for the result generation."
         )
     precomputed_result = {
         "kl_divergence": torch.tensor([-3.2250, -0.1784, -0.1784, -2.2182]),
diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py
index 9ff0823f173..e6f5222c3b1 100644
--- a/tests/unittests/text/test_mer.py
+++ b/tests/unittests/text/test_mer.py
@@ -11,30 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, List, Union
+from typing import List, Union

 import pytest

 from torchmetrics.functional.text.mer import match_error_rate
 from torchmetrics.text.mer import MatchErrorRate
-from torchmetrics.utilities.imports import _JIWER_AVAILABLE

 from unittests.helpers import seed_all
 from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
 from unittests.text.helpers import TextTester

-if _JIWER_AVAILABLE:
-    from jiwer import compute_measures
-else:
-    compute_measures: Callable
-
 seed_all(42)


 def _reference_jiwer_mer(preds: Union[str, List[str]], target: Union[str, List[str]]):
+    try:
+        from jiwer import compute_measures
+    except ImportError:
+        pytest.skip("test requires jiwer package to be installed")
     return compute_measures(target, preds)["mer"]


-@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 @pytest.mark.parametrize(
     ["preds", "targets"],
     [
diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py
index 8d59929c488..c9eec8a055a 100644
--- a/tests/unittests/text/test_rouge.py
+++ b/tests/unittests/text/test_rouge.py
@@ -91,7 +91,7 @@ def _reference_rouge_score(
     return torch.tensor(rs_result, dtype=torch.float)


-@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
+@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="metric requires nltk")
 @pytest.mark.parametrize(
     ["pl_rouge_metric_key", "use_stemmer"],
     [
diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py
index 362fcacf59e..d74d032597d 100644
--- a/tests/unittests/text/test_sacre_bleu.py
+++ b/tests/unittests/text/test_sacre_bleu.py
@@ -19,18 +19,19 @@
 from torch import Tensor, tensor
 from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _TokenizersLiteral, sacre_bleu_score
 from torchmetrics.text.sacre_bleu import SacreBLEUScore
-from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE

 from unittests.text._inputs import _inputs_multiple_references
 from unittests.text.helpers import TextTester

-if _SACREBLEU_AVAILABLE:
-    from sacrebleu.metrics import BLEU
-

 def _reference_sacre_bleu(
     preds: Sequence[str], targets: Sequence[Sequence[str]], tokenize: str, lowercase: bool
 ) -> Tensor:
+    try:
+        from sacrebleu.metrics import BLEU
+    except ImportError:
+        pytest.skip("test requires sacrebleu package to be installed")
+
     sacrebleu_fn = BLEU(tokenize=tokenize, lowercase=lowercase)
     # Sacrebleu expects different format of input
     targets = [[target[i] for target in targets] for i in range(len(targets[0]))]
@@ -44,7 +45,6 @@
 )
 @pytest.mark.parametrize(["lowercase"], [(False,), (True,)])
 @pytest.mark.parametrize("tokenize", AVAILABLE_TOKENIZERS)
-@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu")
 class TestSacreBLEUScore(TextTester):
     """Test class for `SacreBLEUScore` metric."""

diff --git a/tests/unittests/text/test_ter.py b/tests/unittests/text/test_ter.py
index f6dd90f2c36..eb63451cf36 100644
--- a/tests/unittests/text/test_ter.py
+++ b/tests/unittests/text/test_ter.py
@@ -18,14 +18,10 @@
 from torch import Tensor, tensor
 from torchmetrics.functional.text.ter import translation_edit_rate
 from torchmetrics.text.ter import TranslationEditRate
-from torchmetrics.utilities.imports import _SACREBLEU_AVAILABLE

 from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references
 from unittests.text.helpers import TextTester

-if _SACREBLEU_AVAILABLE:
-    from sacrebleu.metrics import TER as SacreTER  # noqa: N811
-

 def _reference_sacrebleu_ter(
     preds: Sequence[str],
@@ -35,7 +31,12 @@ def _reference_sacrebleu_ter(
     asian_support: bool,
     case_sensitive: bool,
 ) -> Tensor:
-    sacrebleu_ter = SacreTER(
+    try:
+        from sacrebleu.metrics import TER
+    except ImportError:
+        pytest.skip("test requires sacrebleu package to be installed")
+
+    sacrebleu_ter = TER(
         normalized=normalized, no_punct=no_punct, asian_support=asian_support, case_sensitive=case_sensitive
     )
     # Sacrebleu CHRF expects different format of input
@@ -59,7 +60,6 @@
     ["preds", "targets"],
     [(_inputs_multiple_references.preds, _inputs_multiple_references.target)],
 )
-@pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu")
 class TestTER(TextTester):
     """Test class for `TranslationEditRate` metric."""

diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py
index bb781f8f0ac..6aee783d411 100644
--- a/tests/unittests/text/test_wer.py
+++ b/tests/unittests/text/test_wer.py
@@ -11,27 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, List, Union
+from typing import List, Union

 import pytest

 from torchmetrics.functional.text.wer import word_error_rate
 from torchmetrics.text.wer import WordErrorRate
-from torchmetrics.utilities.imports import _JIWER_AVAILABLE

 from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
 from unittests.text.helpers import TextTester

-if _JIWER_AVAILABLE:
-    from jiwer import compute_measures
-else:
-    compute_measures: Callable
-

 def _reference_jiwer_wer(preds: Union[str, List[str]], target: Union[str, List[str]]):
+    try:
+        from jiwer import compute_measures
+    except ImportError:
+        pytest.skip("test requires jiwer package to be installed")
+
     return compute_measures(target, preds)["wer"]


-@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 @pytest.mark.parametrize(
     ["preds", "targets"],
     [
diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py
index 9b88866071a..08ecad16284 100644
--- a/tests/unittests/text/test_wil.py
+++ b/tests/unittests/text/test_wil.py
@@ -14,20 +14,22 @@
 from typing import List, Union

 import pytest
-from jiwer import wil

 from torchmetrics.functional.text.wil import word_information_lost
 from torchmetrics.text.wil import WordInfoLost
-from torchmetrics.utilities.imports import _JIWER_AVAILABLE

 from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
 from unittests.text.helpers import TextTester


 def _reference_jiwer_wil(preds: Union[str, List[str]], target: Union[str, List[str]]):
+    try:
+        from jiwer import wil
+    except ImportError:
+        pytest.skip("test requires jiwer package to be installed")
+
     return wil(target, preds)


-@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 @pytest.mark.parametrize(
     ["preds", "targets"],
     [
diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py
index c6ce8a89a8e..1900f7182b2 100644
--- a/tests/unittests/text/test_wip.py
+++ b/tests/unittests/text/test_wip.py
@@ -14,20 +14,22 @@
 from typing import List, Union

 import pytest
-from jiwer import wip

 from torchmetrics.functional.text.wip import word_information_preserved
 from torchmetrics.text.wip import WordInfoPreserved
-from torchmetrics.utilities.imports import _JIWER_AVAILABLE

 from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
 from unittests.text.helpers import TextTester


 def _reference_jiwer_wip(preds: Union[str, List[str]], target: Union[str, List[str]]):
+    try:
+        from jiwer import wip
+    except ImportError:
+        pytest.skip("test requires jiwer package to be installed")
+
     return wip(target, preds)


-@pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer")
 @pytest.mark.parametrize(
     ["preds", "targets"],
     [