From 01a424cdb3cd31de9a19d3b39974e40ca5207304 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Mon, 12 Feb 2024 21:52:26 +0100
Subject: [PATCH 1/2] tests: group wrappers & make them optional

---
 tests/unittests/__init__.py                   |  2 -
 tests/unittests/conftest.py                   | 20 --------
 tests/unittests/helpers/__init__.py           |  5 ++
 tests/unittests/helpers/wrappers.py           | 51 +++++++++++++++++++
 .../image/test_perceptual_path_length.py      |  3 +-
 tests/unittests/multimodal/test_clip_iqa.py   |  2 +-
 tests/unittests/multimodal/test_clip_score.py |  3 +-
 tests/unittests/text/helpers.py               | 25 +--------
 tests/unittests/text/test_bertscore.py        |  3 +-
 tests/unittests/text/test_infolm.py           |  3 +-
 tests/unittests/text/test_rouge.py            |  3 +-
 11 files changed, 66 insertions(+), 54 deletions(-)
 create mode 100644 tests/unittests/helpers/wrappers.py

diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py
index 64a08805e34..c3764c083f1 100644
--- a/tests/unittests/__init__.py
+++ b/tests/unittests/__init__.py
@@ -13,7 +13,6 @@
     NUM_PROCESSES,
     THRESHOLD,
     setup_ddp,
-    skip_on_running_out_of_memory,
 )

 # adding compatibility for numpy >= 1.24
@@ -50,5 +49,4 @@ class _GroupInput(NamedTuple):
     "NUM_PROCESSES",
     "THRESHOLD",
     "setup_ddp",
-    "skip_on_running_out_of_memory",
 ]
diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py
index aa866372539..d66ca24225c 100644
--- a/tests/unittests/conftest.py
+++ b/tests/unittests/conftest.py
@@ -14,8 +14,6 @@
 import contextlib
 import os
 import sys
-from functools import wraps
-from typing import Any, Callable, Optional

 import pytest
 import torch
@@ -84,21 +82,3 @@ def pytest_sessionfinish():
         return
     pytest.pool.close()
     pytest.pool.join()
-
-
-def skip_on_running_out_of_memory(reason: str = "Skipping test as it ran out of memory."):
-    """Handle tests that sometimes runs out of memory, by simply skipping them."""
-
-    def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
-        @wraps(function)
-        def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
-            try:
-                return function(*args, **kwargs)
-            except RuntimeError as ex:
-                if "DefaultCPUAllocator: not enough memory:" not in str(ex):
-                    raise ex
-                pytest.skip(reason)
-
-        return run_test
-
-    return test_decorator
diff --git a/tests/unittests/helpers/__init__.py b/tests/unittests/helpers/__init__.py
index b1903f18b5e..c4c5be42c8b 100644
--- a/tests/unittests/helpers/__init__.py
+++ b/tests/unittests/helpers/__init__.py
@@ -16,6 +16,8 @@
 import numpy
 import torch

+from unittests.helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory
+

 def seed_all(seed):
     """Set the seed of all computational frameworks."""
@@ -23,3 +25,6 @@ def seed_all(seed):
     numpy.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+
+__all__ = ["seed_all", "skip_on_connection_issues", "skip_on_running_out_of_memory"]
diff --git a/tests/unittests/helpers/wrappers.py b/tests/unittests/helpers/wrappers.py
new file mode 100644
index 00000000000..4029f001b2d
--- /dev/null
+++ b/tests/unittests/helpers/wrappers.py
@@ -0,0 +1,51 @@
+import os
+from functools import wraps
+from typing import Any, Callable, Optional
+
+import pytest
+
+ALLOW_SKIP_IF_OUT_OF_MEMORY = os.getenv("ALLOW_SKIP_IF_OUT_OF_MEMORY", "0") == "1"
+ALLOW_SKIP_IF_BAD_CONNECTION = os.getenv("ALLOW_SKIP_IF_BAD_CONNECTION", "0") == "1"
+
+
+def skip_on_running_out_of_memory(reason: str = "Skipping test as it ran out of memory."):
+    """Handle tests that sometimes runs out of memory, by simply skipping them."""
+
+    def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
+        @wraps(function)
+        def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
+            try:
+                return function(*args, **kwargs)
+            except RuntimeError as ex:
+                if "DefaultCPUAllocator: not enough memory:" not in str(ex):
+                    raise ex
+                if ALLOW_SKIP_IF_OUT_OF_MEMORY:
+                    pytest.skip(reason)
+
+        return run_test
+
+    return test_decorator
+
+
+def skip_on_connection_issues(reason: str = "Unable to load checkpoints from HuggingFace `transformers`."):
+    """Handle download related tests if they fail due to connection issues.
+
+    The tests run normally if no connection issue arises, and they're marked as skipped otherwise.
+
+    """
+    _error_msg_starts = ["We couldn't connect to", "Connection error", "Can't load", "`nltk` resource `punkt` is"]
+
+    def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
+        @wraps(function)
+        def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
+            try:
+                return function(*args, **kwargs)
+            except (OSError, ValueError) as ex:
+                if all(msg_start not in str(ex) for msg_start in _error_msg_starts):
+                    raise ex
+                if ALLOW_SKIP_IF_BAD_CONNECTION:
+                    pytest.skip(reason)
+
+        return run_test
+
+    return test_decorator
diff --git a/tests/unittests/image/test_perceptual_path_length.py b/tests/unittests/image/test_perceptual_path_length.py
index dd08bac4b22..dfdd5cfde96 100644
--- a/tests/unittests/image/test_perceptual_path_length.py
+++ b/tests/unittests/image/test_perceptual_path_length.py
@@ -24,8 +24,7 @@
 from torchmetrics.image.perceptual_path_length import PerceptualPathLength
 from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE

-from unittests import skip_on_running_out_of_memory
-from unittests.helpers import seed_all
+from unittests.helpers import seed_all, skip_on_running_out_of_memory

 seed_all(42)

diff --git a/tests/unittests/multimodal/test_clip_iqa.py b/tests/unittests/multimodal/test_clip_iqa.py
index 403aef0cd96..05421dc55b1 100644
--- a/tests/unittests/multimodal/test_clip_iqa.py
+++ b/tests/unittests/multimodal/test_clip_iqa.py
@@ -26,9 +26,9 @@
 from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10
 from torchvision.transforms import PILToTensor

+from unittests.helpers import skip_on_connection_issues
 from unittests.helpers.testers import MetricTester
 from unittests.image import _SAMPLE_IMAGE
-from unittests.text.helpers import skip_on_connection_issues


 @pytest.mark.parametrize(
diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py
index d3212fef4a1..e506dc89d74 100644
--- a/tests/unittests/multimodal/test_clip_score.py
+++ b/tests/unittests/multimodal/test_clip_score.py
@@ -25,9 +25,8 @@
 from transformers import CLIPModel as _CLIPModel
 from transformers import CLIPProcessor as _CLIPProcessor

-from unittests.helpers import seed_all
+from unittests.helpers import seed_all, skip_on_connection_issues
 from unittests.helpers.testers import MetricTester
-from unittests.text.helpers import skip_on_connection_issues

 seed_all(42)

diff --git a/tests/unittests/text/helpers.py b/tests/unittests/text/helpers.py
index cebad163413..41643fb70c0 100644
--- a/tests/unittests/text/helpers.py
+++ b/tests/unittests/text/helpers.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import pickle
 import sys
-from functools import partial, wraps
+from functools import partial
 from typing import Any, Callable, Dict, Optional, Sequence, Union

 import numpy as np
@@ -477,26 +477,3 @@ def run_differentiability_test(
         if metric.is_differentiable:
             # check for numerical correctness
             assert torch.autograd.gradcheck(partial(metric_functional, **metric_args), (preds[0], targets[0]))
-
-
-def skip_on_connection_issues(reason: str = "Unable to load checkpoints from HuggingFace `transformers`."):
-    """Handle download related tests if they fail due to connection issues.
-
-    The tests run normally if no connection issue arises, and they're marked as skipped otherwise.
-
-    """
-    _error_msg_starts = ["We couldn't connect to", "Connection error", "Can't load", "`nltk` resource `punkt` is"]
-
-    def test_decorator(function: Callable, *args: Any, **kwargs: Any) -> Optional[Callable]:
-        @wraps(function)
-        def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
-            try:
-                return function(*args, **kwargs)
-            except (OSError, ValueError) as ex:
-                if all(msg_start not in str(ex) for msg_start in _error_msg_starts):
-                    raise ex
-                pytest.skip(reason)
-
-        return run_test
-
-    return test_decorator
diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py
index d4819458b1a..632a902f993 100644
--- a/tests/unittests/text/test_bertscore.py
+++ b/tests/unittests/text/test_bertscore.py
@@ -22,7 +22,8 @@
 from torchmetrics.utilities.imports import _BERTSCORE_AVAILABLE, _TRANSFORMERS_GREATER_EQUAL_4_4
 from typing_extensions import Literal

-from unittests.text.helpers import TextTester, skip_on_connection_issues
+from unittests.helpers import skip_on_connection_issues
+from unittests.text.helpers import TextTester
 from unittests.text.inputs import _inputs_single_reference

 if _BERTSCORE_AVAILABLE:
diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py
index f3a54b23d07..9455fe0c867 100644
--- a/tests/unittests/text/test_infolm.py
+++ b/tests/unittests/text/test_infolm.py
@@ -19,7 +19,8 @@
 from torchmetrics.text.infolm import InfoLM
 from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4

-from unittests.text.helpers import TextTester, skip_on_connection_issues
+from unittests.helpers import skip_on_connection_issues
+from unittests.text.helpers import TextTester
 from unittests.text.inputs import HYPOTHESIS_A, HYPOTHESIS_C, _inputs_single_reference

 # Small bert model with 2 layers, 2 attention heads and hidden dim of 128
diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py
index 280525bd047..113b22d3e1a 100644
--- a/tests/unittests/text/test_rouge.py
+++ b/tests/unittests/text/test_rouge.py
@@ -24,7 +24,8 @@
 from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE
 from typing_extensions import Literal

-from unittests.text.helpers import TextTester, skip_on_connection_issues
+from unittests.helpers import skip_on_connection_issues
+from unittests.text.helpers import TextTester
 from unittests.text.inputs import _Input, _inputs_multiple_references, _inputs_single_sentence_single_reference

 if _ROUGE_SCORE_AVAILABLE:

From 880e7ef5154ff8a4a8b3525a823841bf0fa4fac4 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Tue, 13 Feb 2024 15:26:57 +0100
Subject: [PATCH 2/2] ci

---
 .azure/gpu-unittests.yml       | 5 +++++
 .github/workflows/ci-tests.yml | 6 ++++++
 2 files changed, 11 insertions(+)

diff --git a/.azure/gpu-unittests.yml b/.azure/gpu-unittests.yml
index 432dc38560a..974f710ac1b 100644
--- a/.azure/gpu-unittests.yml
+++ b/.azure/gpu-unittests.yml
@@ -65,6 +65,11 @@ jobs:
           echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$CUDA_version_mm"
           echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${CUDA_version_mm}/torch_stable.html"
         displayName: "set Env. vars"
+      - bash: |
+          echo "##vso[task.setvariable variable=ALLOW_SKIP_IF_OUT_OF_MEMORY]1"
+          echo "##vso[task.setvariable variable=ALLOW_SKIP_IF_BAD_CONNECTION]1"
+        condition: eq(variables['Build.Reason'], 'PullRequest')
+        displayName: "set Env. vars for PRs"

       - bash: |
           printf "PR: $PR_NUMBER \n"
diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
index 63c4d4416dc..d9f6d6a3fea 100644
--- a/.github/workflows/ci-tests.yml
+++ b/.github/workflows/ci-tests.yml
@@ -123,6 +123,12 @@ jobs:
             --find-links $PYTORCH_URL -f $PYPI_CACHE
           pip list

+      - name: set special vars for PR
+        if: ${{ github.event_name == 'pull_request' }}
+        run: |
+          echo 'ALLOW_SKIP_IF_OUT_OF_MEMORY=1' >> $GITHUB_ENV
+          echo 'ALLOW_SKIP_IF_BAD_CONNECTION=1' >> $GITHUB_ENV
+
       - name: Sanity check
         id: info
         run: |
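
Usage sketch (for context only, not part of the patches above): how a test module might opt into the regrouped wrappers after this change. The test names, bodies, and reason strings are hypothetical; only the unittests.helpers import path and the ALLOW_SKIP_IF_* environment variables come from the changes above.

import torch
from transformers import AutoModel

from unittests.helpers import skip_on_connection_issues, skip_on_running_out_of_memory


@skip_on_running_out_of_memory(reason="not enough CPU memory on the PR runner")
def test_huge_tensor_allocation():
    # Hypothetical test that may exhaust CPU RAM; with ALLOW_SKIP_IF_OUT_OF_MEMORY=1
    # (as the PR-only CI steps above set it), an allocator failure is reported as a skip.
    big = torch.randn(64, 3, 2048, 2048)
    assert big.ndim == 4


@skip_on_connection_issues()
def test_pretrained_checkpoint_download():
    # Hypothetical test that pulls a checkpoint from the network; with
    # ALLOW_SKIP_IF_BAD_CONNECTION=1 a HuggingFace connection error becomes a skip.
    model = AutoModel.from_pretrained("albert-base-v2")
    assert model is not None

Locally, the same opt-in behaviour can be reproduced by exporting the variables before running pytest, e.g. ALLOW_SKIP_IF_BAD_CONNECTION=1 ALLOW_SKIP_IF_OUT_OF_MEMORY=1 pytest tests/unittests.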