
tests: prefer cache for missing config (#2414)
move import inside ref metric

(cherry picked from commit a8c11b4)
Borda committed Mar 18, 2024
1 parent 416779e commit 0f26306
Showing 27 changed files with 107 additions and 120 deletions.
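
The recurring pattern behind "move import inside ref metric" is visible across the test files below: instead of a module-level availability flag plus a `@pytest.mark.skipif` decorator, the optional reference package is imported inside the reference function itself, and the test is skipped at call time if the import fails. A minimal sketch of that pattern (not code from this commit), using a hypothetical `some_reference_package` as a stand-in for real optional dependencies such as `lpips` or `dython`:

import pytest
from torch import Tensor


def _reference_metric(preds: Tensor, target: Tensor) -> float:
    """Compute the reference value with an optional third-party package."""
    try:
        # hypothetical optional dependency, standing in for e.g. `lpips` or `dython`
        from some_reference_package import reference_fn
    except ImportError:
        pytest.skip("test requires `some_reference_package` to be installed")
    return float(reference_fn(preds.numpy(), target.numpy()))

With this shape, the test module still imports cleanly when the package is absent, so test collection no longer depends on every optional reference implementation being installed.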
2 changes: 1 addition & 1 deletion .github/actions/pull-caches/action.yml
@@ -90,5 +90,5 @@ runs:

- name: Restored References
continue-on-error: true
run: ls -lh tests/_cache-references/
run: py-tree tests/_cache-references/ --show_hidden
shell: bash
2 changes: 1 addition & 1 deletion .github/actions/push-caches/action.yml
@@ -99,5 +99,5 @@ runs:
key: cache-references

- name: Post References
run: ls -lh tests/_cache-references/
run: py-tree tests/_cache-references/ --show_hidden
shell: bash
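
Both CI actions swap the flat `ls -lh` listing for a tree listing of the reference cache. Assuming `py-tree --show_hidden` behaves like a recursive directory walk that includes dot-files, a rough Python equivalent of what that step prints is:

from pathlib import Path

# walk tests/_cache-references/ recursively, including hidden files, and print an indented tree
root = Path("tests/_cache-references")
for path in sorted(root.rglob("*")):
    depth = len(path.relative_to(root).parts) - 1
    print("  " * depth + path.name)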
2 changes: 1 addition & 1 deletion .gitignore
@@ -40,7 +40,7 @@ pip-delete-this-directory.txt
# Unit test / coverage reports
tests/_data/
data.zip
tests/_reference-cache/
tests/_cache-references/
htmlcov/
.coverage
.coverage.*
4 changes: 0 additions & 4 deletions src/torchmetrics/utilities/imports.py
@@ -17,10 +17,8 @@
import sys

from lightning_utilities.core.imports import RequirementCache
from packaging.version import Version, parse

_PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
_PYTHON_LOWER_3_8 = parse(_PYTHON_VERSION) < Version("3.8")
_TORCH_LOWER_2_0 = RequirementCache("torch<2.0.0")
_TORCH_GREATER_EQUAL_1_11 = RequirementCache("torch>=1.11.0")
_TORCH_GREATER_EQUAL_1_12 = RequirementCache("torch>=1.12.0")
@@ -29,7 +27,6 @@
_TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0")
_TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0")

_JIWER_AVAILABLE = RequirementCache("jiwer")
_NLTK_AVAILABLE = RequirementCache("nltk")
_ROUGE_SCORE_AVAILABLE = RequirementCache("rouge_score")
_BERTSCORE_AVAILABLE = RequirementCache("bert_score")
@@ -49,7 +46,6 @@
_GAMMATONE_AVAILABLE = RequirementCache("gammatone")
_TORCHAUDIO_AVAILABLE = RequirementCache("torchaudio")
_TORCHAUDIO_GREATER_EQUAL_0_10 = RequirementCache("torchaudio>=0.10.0")
_SACREBLEU_AVAILABLE = RequirementCache("sacrebleu")
_REGEX_AVAILABLE = RequirementCache("regex")
_PYSTOI_AVAILABLE = RequirementCache("pystoi")
_FAST_BSS_EVAL_AVAILABLE = RequirementCache("fast_bss_eval")
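
For context, the three symbols deleted above (`_PYTHON_LOWER_3_8`, `_JIWER_AVAILABLE`, `_SACREBLEU_AVAILABLE`) followed the same `RequirementCache` pattern as the flags that remain, and are dropped presumably because nothing gates on them once the reference imports move inside the reference functions. An illustrative sketch (not code from this commit) of how such a flag is defined and consumed, using the `_NLTK_AVAILABLE` flag that stays in the file:

import pytest
from lightning_utilities.core.imports import RequirementCache

# the requirement string is checked against the installed distribution; the result is cached
_NLTK_AVAILABLE = RequirementCache("nltk")


@pytest.mark.skipif(not _NLTK_AVAILABLE, reason="test requires nltk")
def test_tokenize_with_nltk():
    import nltk  # safe: the skipif above only lets this run when nltk is installed

    assert nltk.__name__ == "nltk"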
2 changes: 0 additions & 2 deletions tests/unittests/classification/test_group_fairness.py
@@ -26,7 +26,6 @@
from torchmetrics import Metric
from torchmetrics.classification.group_fairness import BinaryFairness
from torchmetrics.functional.classification.group_fairness import binary_fairness
from torchmetrics.utilities.imports import _PYTHON_LOWER_3_8

from unittests import THRESHOLD
from unittests.classification._inputs import _group_cases
@@ -222,7 +221,6 @@ def run_precision_test_gpu(

@mock.patch("unittests.helpers.testers._assert_tensor", _assert_tensor)
@mock.patch("unittests.helpers.testers._assert_allclose", _assert_allclose)
@pytest.mark.skipif(_PYTHON_LOWER_3_8, reason="`TestBinaryFairness` requires `python>=3.8`.")
@pytest.mark.parametrize("inputs", _group_cases)
class TestBinaryFairness(BinaryFairnessTester):
"""Test class for `BinaryFairness` metric."""
8 changes: 4 additions & 4 deletions tests/unittests/image/test_fid.py
@@ -34,7 +34,7 @@ def test_no_train_network_missing_torch_fidelity():
NoTrainInceptionV3(name="inception-v3-compat", features_list=["2048"])


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_no_train():
"""Assert that metric never leaves evaluation mode."""

@@ -52,7 +52,7 @@ def forward(self, x):
assert not model.metric.inception.training, "FID metric was changed to training mode which should not happen"


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_fid_pickle():
"""Assert that we can initialize the metric and pickle it."""
metric = FrechetInceptionDistance()
@@ -80,7 +80,7 @@ def test_fid_raises_errors_and_warnings():
_ = FrechetInceptionDistance(feature=[1, 2])


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_fid_same_input(feature):
"""If real and fake are update on the same data the fid score should be 0."""
@@ -111,7 +111,7 @@ def __len__(self) -> int:


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("equal_size", [False, True])
def test_compare_fid(tmpdir, equal_size, feature=768):
"""Check that the hole pipeline give the same result as torch-fidelity."""
8 changes: 4 additions & 4 deletions tests/unittests/image/test_inception.py
@@ -24,7 +24,7 @@
torch.manual_seed(42)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_no_train():
"""Assert that metric never leaves evaluation mode."""

@@ -44,7 +44,7 @@ def forward(self, x):
), "InceptionScore metric was changed to training mode which should not happen"


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_is_pickle():
"""Assert that we can initialize the metric and pickle it."""
metric = InceptionScore()
@@ -79,7 +79,7 @@ def test_is_raises_errors_and_warnings():
InceptionScore(feature=[1, 2])


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_is_update_compute():
"""Test that inception score works as expected."""
metric = InceptionScore()
@@ -105,7 +105,7 @@ def __len__(self) -> int:


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("compute_on_cpu", [True, False])
def test_compare_is(tmpdir, compute_on_cpu):
"""Check that the hole pipeline give the same result as torch-fidelity."""
10 changes: 5 additions & 5 deletions tests/unittests/image/test_kid.py
@@ -24,7 +24,7 @@
torch.manual_seed(42)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_no_train():
"""Assert that metric never leaves evaluation mode."""

@@ -42,7 +42,7 @@ def forward(self, x):
assert not model.metric.inception.training, "FID metric was changed to training mode which should not happen"


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_kid_pickle():
"""Assert that we can initialize the metric and pickle it."""
metric = KernelInceptionDistance()
@@ -83,7 +83,7 @@ def test_kid_raises_errors_and_warnings():
m.compute()


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_kid_extra_parameters():
"""Test that the different input arguments raises expected errors if wrong."""
with pytest.raises(ValueError, match="Argument `subsets` expected to be integer larger than 0"):
@@ -102,7 +102,7 @@ def test_kid_extra_parameters():
KernelInceptionDistance(coef=-1)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_kid_same_input(feature):
"""Test that the metric works."""
@@ -132,7 +132,7 @@ def __len__(self) -> int:


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_compare_kid(tmpdir, feature=2048):
"""Check that the hole pipeline give the same result as torch-fidelity."""
from torch_fidelity import calculate_metrics
13 changes: 7 additions & 6 deletions tests/unittests/image/test_lpips.py
@@ -16,11 +16,10 @@

import pytest
import torch
from lpips import LPIPS as LPIPS_reference # noqa: N811
from torch import Tensor
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE

from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
@@ -43,15 +42,19 @@ def _reference_lpips(
img1: Tensor, img2: Tensor, net_type: str, normalize: bool = False, reduction: str = "mean"
) -> Tensor:
"""Comparison function for tm implementation."""
ref = LPIPS_reference(net=net_type)
try:
from lpips import LPIPS
except ImportError:
pytest.skip("test requires lpips package to be installed")

ref = LPIPS(net=net_type)
res = ref(img1, img2, normalize=normalize).detach().cpu().numpy()
if reduction == "mean":
return res.mean()
return res.sum()


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
class TestLPIPS(MetricTester):
"""Test class for `LearnedPerceptualImagePatchSimilarity` metric."""

@@ -109,7 +112,6 @@ def test_normalize_arg(normalize):


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""
with pytest.raises(ValueError, match="Argument `net_type` must be one .*"):
@@ -120,7 +122,6 @@ def test_error_on_wrong_init():


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
@pytest.mark.parametrize(
("inp1", "inp2"),
[
6 changes: 3 additions & 3 deletions tests/unittests/image/test_mifid.py
@@ -98,7 +98,7 @@ def calculate_mifid(m1, s1, features1, m2, s2, features2):
return fid_private / (distance_private_thresholded + 1e-15)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
def test_no_train():
"""Assert that metric never leaves evaluation mode."""

@@ -139,7 +139,7 @@ def test_mifid_raises_errors_and_warnings():
_ = MemorizationInformedFrechetInceptionDistance(cosine_distance_eps=1.1)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("feature", [64, 192, 768, 2048])
def test_fid_same_input(feature):
"""If real and fake are update on the same data the fid score should be 0."""
@@ -157,7 +157,7 @@ def test_fid_same_input(feature):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch-fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("equal_size", [False, True])
def test_compare_mifid(equal_size):
"""Check that our implementation of MIFID is correct by comparing it to the original implementation."""
8 changes: 4 additions & 4 deletions tests/unittests/image/test_perceptual_path_length.py
@@ -29,7 +29,7 @@
seed_all(42)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize("interpolation_method", ["lerp", "slerp_any", "slerp_unit"])
def test_interpolation_methods(interpolation_method):
"""Test that interpolation method works as expected."""
@@ -41,7 +41,7 @@ def test_interpolation_methods(interpolation_method):
assert torch.allclose(res1, res2)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@skip_on_running_out_of_memory()
def test_sim_net():
"""Check that the similarity network is the same as the one used in torch_fidelity."""
@@ -100,7 +100,7 @@ def sample(self, num_samples):
return torch.randn(num_samples, self.z_size)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.parametrize(
("argument", "match"),
[
@@ -174,7 +174,7 @@ def test_raises_error_on_wrong_generator(generator, errortype, match):
ppl.update(generator=generator)


@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="test requires torch_fidelity")
@pytest.mark.skipif(not _TORCH_FIDELITY_AVAILABLE, reason="metric requires torch-fidelity")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@skip_on_running_out_of_memory()
def test_compare():
10 changes: 5 additions & 5 deletions tests/unittests/multimodal/test_clip_iqa.py
@@ -71,7 +71,7 @@ def _reference_clip_iqa(preds, target, reduce=False):
return res.sum() if reduce else res


@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
class TestCLIPIQA(MetricTester):
"""Test clip iqa metric."""
@@ -104,7 +104,7 @@ def test_clip_iqa_functional(self, shapes):


@skip_on_connection_issues()
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
@pytest.mark.skipif(not os.path.isfile(_SAMPLE_IMAGE), reason="test image not found")
def test_for_correctness_sample_images():
@@ -121,7 +121,7 @@ def test_for_correctness_sample_images():


@skip_on_connection_issues()
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
@pytest.mark.parametrize(
"model",
@@ -148,7 +148,7 @@ def test_other_models(model):


@skip_on_connection_issues()
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
@pytest.mark.parametrize(
"prompts",
@@ -200,7 +200,7 @@ def test_prompt(prompts):


@skip_on_connection_issues()
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="test requires piq>=0.8")
@pytest.mark.skipif(not _PIQ_GREATER_EQUAL_0_8, reason="metric requires piq>=0.8")
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_10, reason="test requires transformers>=4.10")
def test_plot_method():
"""Test the plot method of CLIPScore separately in this file due to the skipping conditions."""
12 changes: 6 additions & 6 deletions tests/unittests/nominal/test_cramers.py
@@ -12,13 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import operator
from functools import partial

import pytest
import torch
from dython.nominal import cramers_v as dython_cramers_v
from lightning_utilities.core.imports import compare_version
from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
from torchmetrics.nominal.cramers import CramersV

@@ -63,10 +60,15 @@ def cramers_matrix_input():


def _reference_dython_cramers_v(preds, target, bias_correction, nan_strategy, nan_replace_value):
try:
from dython.nominal import cramers_v
except ImportError:
pytest.skip("This test requires `dython` package to be installed.")

preds = preds.argmax(1) if preds.ndim == 2 else preds
target = target.argmax(1) if target.ndim == 2 else target

v = dython_cramers_v(
v = cramers_v(
preds.numpy(),
target.numpy(),
bias_correction=bias_correction,
@@ -87,7 +89,6 @@ def _dython_cramers_v_matrix(matrix, bias_correction, nan_strategy, nan_replace_
return cramers_v_matrix_value


@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
@pytest.mark.parametrize(
"preds, target",
[
@@ -161,7 +162,6 @@ def test_cramers_v_differentiability(self, preds, target, bias_correction, nan_s
)


@pytest.mark.skipif(compare_version("pandas", operator.lt, "1.3.2"), reason="`dython` package requires `pandas>=1.3.2`")
@pytest.mark.parametrize("bias_correction", [False, True])
@pytest.mark.parametrize(("nan_strategy", "nan_replace_value"), [("replace", 1.0), ("drop", None)])
def test_cramers_v_matrix(cramers_matrix_input, bias_correction, nan_strategy, nan_replace_value):
(Diffs for the remaining changed files are not shown.)

