diff --git a/.azure/gpu-unittests.yml b/.azure/gpu-unittests.yml index 2013ab871b7..62084985b1c 100644 --- a/.azure/gpu-unittests.yml +++ b/.azure/gpu-unittests.yml @@ -49,7 +49,6 @@ jobs: TEST_DIRS: "unittests" # todo: consider unfreeze for master too FREEZE_REQUIREMENTS: 1 - PYTEST_REFERENCE_CACHE: "/var/tmp/cache-references" container: image: "$(docker-image)" @@ -127,7 +126,13 @@ jobs: pip install -q py-tree py-tree /var/tmp/torch py-tree /var/tmp/hf - py-tree $(PYTEST_REFERENCE_CACHE) --show_hidden + # this gives more the 60k lines and takes a few minutes to run + #py-tree $(PYTEST_REFERENCE_CACHE) --show_hidden + # make sure the cache exists even it is empty + mkdir -p /var/tmp/cached-references + # copy the cache to the tests folder to be used in the next steps + cp -r /var/tmp/cached-references tests/_cache-references + du -h --max-depth=1 tests/ displayName: "Show caches" - bash: | @@ -156,6 +161,7 @@ jobs: workingDirectory: tests # skip for PR if there is nothing to test, note that outside PR there is default 'unittests' condition: and(succeeded(), ne(variables['TEST_DIRS'], '')) + timeoutInMinutes: "60" displayName: "UnitTesting common" - bash: | @@ -167,8 +173,16 @@ jobs: workingDirectory: tests # skip for PR if there is nothing to test, note that outside PR there is default 'unittests' condition: and(succeeded(), ne(variables['TEST_DIRS'], '')) + timeoutInMinutes: "60" displayName: "UnitTesting DDP" + - bash: | + du -h --max-depth=1 tests/ + # copy potentially updated cache to the machine filesystem to be reused with next jobs + cp -r --update tests/_cache-references /var/tmp/cached-references + # set as extra step to not pollute general cache when jobs fails or crashes + displayName: "Update cached refs" + - bash: | python -m coverage report python -m coverage xml diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 89e3a04a5bc..bd80c9868a4 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -169,7 +169,7 @@ jobs: --reruns-delay 1 \ -m "not DDP" \ -n auto \ - --dist=loadfile \ + --dist=load \ ${{ env.UNITTEST_TIMEOUT }} - name: Unittests DDP diff --git a/requirements/_doctest.txt b/requirements/_doctest.txt index 521f204df48..f8738145a5f 100644 --- a/requirements/_doctest.txt +++ b/requirements/_doctest.txt @@ -1,6 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -pytest >=8.0.0, <8.1.0 -pytest-doctestplus >1.0, <1.3 -pytest-rerunfailures >10.0, <14.0 +pytest >=8.0, <9.0 +pytest-doctestplus >=1.0, <1.3 +pytest-rerunfailures >=10.0, <14.0 diff --git a/requirements/_tests.txt b/requirements/_tests.txt index d739a420297..0dccca202fe 100644 --- a/requirements/_tests.txt +++ b/requirements/_tests.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment coverage ==7.4.3 -pytest ==8.0.0 +pytest ==8.1.1 pytest-cov ==4.1.0 pytest-doctestplus ==1.2.1 pytest-rerunfailures ==13.0 diff --git a/tests/unittests/helpers/__init__.py b/tests/unittests/_helpers/__init__.py similarity index 90% rename from tests/unittests/helpers/__init__.py rename to tests/unittests/_helpers/__init__.py index c4c5be42c8b..5103904838d 100644 --- a/tests/unittests/helpers/__init__.py +++ b/tests/unittests/_helpers/__init__.py @@ -16,7 +16,7 @@ import numpy import torch -from unittests.helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory +from unittests._helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory def seed_all(seed): diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/_helpers/testers.py similarity index 98% rename from tests/unittests/helpers/testers.py rename to tests/unittests/_helpers/testers.py index a8885f831bb..deb4c12324e 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -337,6 +337,18 @@ def _assert_dtype_support( _assert_tensor(metric_functional(y_hat, y, **kwargs_update)) +def _select_rand_best_device() -> str: + """Select the best device to run tests on.""" + nb_gpus = torch.cuda.device_count() + # todo: debug the eventual device checks/assets + # if nb_gpus > 1: + # from random import randrange + # return f"cuda:{randrange(nb_gpus)}" + if nb_gpus: + return "cuda" + return "cpu" + + class MetricTester: """Test class for all metrics. @@ -371,8 +383,6 @@ def run_functional_metric_test( target when running update on the metric. """ - device = "cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu" - _functional_test( preds=preds, target=target, @@ -380,7 +390,7 @@ def run_functional_metric_test( reference_metric=reference_metric, metric_args=metric_args, atol=self.atol, - device=device, + device=_select_rand_best_device(), fragment_kwargs=fragment_kwargs, **kwargs_update, ) @@ -431,7 +441,7 @@ def run_class_metric_test( "reference_metric": reference_metric, "metric_args": metric_args or {}, "atol": atol or self.atol, - "device": "cuda" if torch.cuda.is_available() else "cpu", + "device": _select_rand_best_device(), "dist_sync_on_step": dist_sync_on_step, "check_dist_sync_on_step": check_dist_sync_on_step, "check_batch": check_batch, diff --git a/tests/unittests/helpers/wrappers.py b/tests/unittests/_helpers/wrappers.py similarity index 100% rename from tests/unittests/helpers/wrappers.py rename to tests/unittests/_helpers/wrappers.py diff --git a/tests/unittests/audio/test_c_si_snr.py b/tests/unittests/audio/test_c_si_snr.py index aed96ea6285..2ed148aef65 100644 --- a/tests/unittests/audio/test_c_si_snr.py +++ b/tests/unittests/audio/test_c_si_snr.py @@ -19,9 +19,9 @@ from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio from unittests import BATCH_SIZE, NUM_BATCHES, _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index 348f99c13d6..dd0e3caba9c 100644 --- a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -22,9 +22,9 @@ from torchmetrics.functional.audio import perceptual_evaluation_speech_quality from unittests import _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index 107a775e728..85baab5e045 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -31,9 +31,9 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES, _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.audio import _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_sa_sdr.py b/tests/unittests/audio/test_sa_sdr.py index d6e7178b8ad..3de3c4900cf 100644 --- a/tests/unittests/audio/test_sa_sdr.py +++ b/tests/unittests/audio/test_sa_sdr.py @@ -24,8 +24,8 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_sdr.py b/tests/unittests/audio/test_sdr.py index ce8756d8ec7..61257588606 100644 --- a/tests/unittests/audio/test_sdr.py +++ b/tests/unittests/audio/test_sdr.py @@ -24,9 +24,9 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from unittests import _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _SAMPLE_NUMPY_ISSUE_895 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_si_sdr.py b/tests/unittests/audio/test_si_sdr.py index 6f014f828eb..d8b8f78c5cc 100644 --- a/tests/unittests/audio/test_si_sdr.py +++ b/tests/unittests/audio/test_si_sdr.py @@ -21,9 +21,9 @@ from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio from unittests import BATCH_SIZE, NUM_BATCHES, _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.audio import _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_si_snr.py b/tests/unittests/audio/test_si_snr.py index f6f6c7f52f7..d7065e0dc88 100644 --- a/tests/unittests/audio/test_si_snr.py +++ b/tests/unittests/audio/test_si_snr.py @@ -21,8 +21,8 @@ from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_snr.py b/tests/unittests/audio/test_snr.py index d513360851a..332707028ff 100644 --- a/tests/unittests/audio/test_snr.py +++ b/tests/unittests/audio/test_snr.py @@ -21,9 +21,9 @@ from torchmetrics.functional.audio import signal_noise_ratio from unittests import _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.audio import _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_srmr.py b/tests/unittests/audio/test_srmr.py index fa0ce989a61..3b1b07862ce 100644 --- a/tests/unittests/audio/test_srmr.py +++ b/tests/unittests/audio/test_srmr.py @@ -22,8 +22,8 @@ from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio from torchmetrics.utilities.imports import _TORCHAUDIO_GREATER_EQUAL_0_10 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_stoi.py b/tests/unittests/audio/test_stoi.py index d05824a1380..54374098779 100644 --- a/tests/unittests/audio/test_stoi.py +++ b/tests/unittests/audio/test_stoi.py @@ -22,9 +22,9 @@ from torchmetrics.functional.audio import short_time_objective_intelligibility from unittests import _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/bases/test_aggregation.py b/tests/unittests/bases/test_aggregation.py index 25ff2233076..0b593208e54 100644 --- a/tests/unittests/bases/test_aggregation.py +++ b/tests/unittests/bases/test_aggregation.py @@ -5,7 +5,7 @@ from torchmetrics.collections import MetricCollection from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester def compare_mean(values, weights): diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 16c95fc879a..0e124125509 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -34,8 +34,8 @@ ) from torchmetrics.utilities.checks import _allclose_recursive -from unittests.helpers import seed_all -from unittests.helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum +from unittests._helpers import seed_all +from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum seed_all(42) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index f613f6f20e9..1a44a2145ba 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -24,8 +24,8 @@ from torchmetrics.utilities.exceptions import TorchMetricsUserError from unittests import NUM_PROCESSES -from unittests.helpers import seed_all -from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum +from unittests._helpers import seed_all +from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum seed_all(42) diff --git a/tests/unittests/bases/test_hashing.py b/tests/unittests/bases/test_hashing.py index b8bb28987e7..1e4dd474f02 100644 --- a/tests/unittests/bases/test_hashing.py +++ b/tests/unittests/bases/test_hashing.py @@ -1,6 +1,6 @@ import pytest -from unittests.helpers.testers import DummyListMetric, DummyMetric +from unittests._helpers.testers import DummyListMetric, DummyMetric @pytest.mark.parametrize( diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index a928b1a753f..22257c0d0c9 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -27,8 +27,8 @@ from torchmetrics.classification import BinaryAccuracy from torchmetrics.regression import PearsonCorrCoef -from unittests.helpers import seed_all -from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum +from unittests._helpers import seed_all +from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum seed_all(42) diff --git a/tests/unittests/classification/_inputs.py b/tests/unittests/classification/_inputs.py index c660e625214..9a4981f041c 100644 --- a/tests/unittests/classification/_inputs.py +++ b/tests/unittests/classification/_inputs.py @@ -18,7 +18,7 @@ from torch import Tensor from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _GroupInput, _Input -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(1) diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index f5e5f9857d7..db497cdb197 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -29,9 +29,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 54374bd56e1..3691c7305b7 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -26,9 +26,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index e9427623162..2ff0274c93c 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -35,9 +35,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 5660a9042f0..a2000cc984e 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -32,9 +32,9 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_13 from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index 6b21b9be3ff..40ebbc028bd 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -23,9 +23,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 6a1c1850d4f..666e9f0fc05 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -32,9 +32,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_dice.py b/tests/unittests/classification/test_dice.py index d3737b255a2..6854265d3d9 100644 --- a/tests/unittests/classification/test_dice.py +++ b/tests/unittests/classification/test_dice.py @@ -23,6 +23,8 @@ from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.classification._inputs import _input_binary, _input_binary_logits, _input_binary_prob from unittests.classification._inputs import _input_multiclass as _input_mcls from unittests.classification._inputs import _input_multiclass_logits as _input_mcls_logits @@ -33,8 +35,6 @@ from unittests.classification._inputs import _input_multilabel_multidim as _input_mlmd from unittests.classification._inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from unittests.classification._inputs import _input_multilabel_prob as _input_mlb_prob -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index 048003c1699..5afd4c00e40 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -22,9 +22,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index from unittests.classification._inputs import _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index a6cfc5f71b8..3a334708485 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -42,9 +42,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index 4d76b9301dd..9e89d041fd7 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -28,17 +28,17 @@ from torchmetrics.functional.classification.group_fairness import binary_fairness from unittests import THRESHOLD -from unittests.classification._inputs import _group_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import ( +from unittests._helpers import seed_all +from unittests._helpers.testers import ( MetricTester, _assert_dtype_support, inject_ignore_index, remove_ignore_index_groups, ) -from unittests.helpers.testers import _assert_allclose as _core_assert_allclose -from unittests.helpers.testers import _assert_requires_grad as _core_assert_requires_grad -from unittests.helpers.testers import _assert_tensor as _core_assert_tensor +from unittests._helpers.testers import _assert_allclose as _core_assert_allclose +from unittests._helpers.testers import _assert_requires_grad as _core_assert_requires_grad +from unittests._helpers.testers import _assert_tensor as _core_assert_tensor +from unittests.classification._inputs import _group_cases seed_all(42) @@ -219,8 +219,8 @@ def run_precision_test_gpu( ) -@mock.patch("unittests.helpers.testers._assert_tensor", _assert_tensor) -@mock.patch("unittests.helpers.testers._assert_allclose", _assert_allclose) +@mock.patch("unittests._helpers.testers._assert_tensor", _assert_tensor) +@mock.patch("unittests._helpers.testers._assert_allclose", _assert_allclose) @pytest.mark.parametrize("inputs", _group_cases) class TestBinaryFairness(BinaryFairnessTester): """Test class for `BinaryFairness` metric.""" diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 8ccbbc9e1fb..f7f3686c73b 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -33,9 +33,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index 6b9eaca1abd..8f285794d15 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -25,8 +25,8 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index torch.manual_seed(42) diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index c1e6354d57a..8fa17ca1d32 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -33,8 +33,8 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None): diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index f8c0801b5ad..03f649bc0ac 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -32,9 +32,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index 70c44df1109..f320d2cf1e9 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py @@ -34,9 +34,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 7cf4ce9e474..86fbe262aea 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -42,9 +42,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 825e6887584..6f78438007e 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -34,9 +34,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index f322ce44442..e85d38cde05 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -32,9 +32,9 @@ ) from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index from unittests.classification._inputs import _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index fb23e36f759..9bdca8950bd 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -34,9 +34,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index cc12bd6c133..167ad4876f0 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -25,9 +25,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 3ffbb8ac5cb..d629c86583a 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -36,9 +36,9 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 3aa2dcf6cba..fe5fd8977a8 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -33,9 +33,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index 9fc918d1e1d..0bafdfe55ea 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -35,9 +35,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index b1e4d36e1ed..86c793f8c83 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -32,9 +32,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/clustering/_inputs.py b/tests/unittests/clustering/_inputs.py index b13bc0c0947..b4f6ce0b97c 100644 --- a/tests/unittests/clustering/_inputs.py +++ b/tests/unittests/clustering/_inputs.py @@ -18,7 +18,7 @@ from torch import Tensor from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _Input -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/clustering/test_adjusted_mutual_info_score.py b/tests/unittests/clustering/test_adjusted_mutual_info_score.py index 304ee18bddc..474e221d6a5 100644 --- a/tests/unittests/clustering/test_adjusted_mutual_info_score.py +++ b/tests/unittests/clustering/test_adjusted_mutual_info_score.py @@ -20,9 +20,9 @@ from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_adjusted_rand_score.py b/tests/unittests/clustering/test_adjusted_rand_score.py index 54e8c1b4577..b98536aad15 100644 --- a/tests/unittests/clustering/test_adjusted_rand_score.py +++ b/tests/unittests/clustering/test_adjusted_rand_score.py @@ -17,8 +17,8 @@ from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers.testers import MetricTester @pytest.mark.parametrize( diff --git a/tests/unittests/clustering/test_calinski_harabasz_score.py b/tests/unittests/clustering/test_calinski_harabasz_score.py index 6071767364e..f81da592389 100644 --- a/tests/unittests/clustering/test_calinski_harabasz_score.py +++ b/tests/unittests/clustering/test_calinski_harabasz_score.py @@ -16,9 +16,9 @@ from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_davies_bouldin_score.py b/tests/unittests/clustering/test_davies_bouldin_score.py index 8f3b4800a01..bea2018c2cc 100644 --- a/tests/unittests/clustering/test_davies_bouldin_score.py +++ b/tests/unittests/clustering/test_davies_bouldin_score.py @@ -16,9 +16,9 @@ from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_dunn_index.py b/tests/unittests/clustering/test_dunn_index.py index fc1500a0fd8..c2e6adcd2cb 100644 --- a/tests/unittests/clustering/test_dunn_index.py +++ b/tests/unittests/clustering/test_dunn_index.py @@ -19,12 +19,12 @@ from torchmetrics.clustering.dunn_index import DunnIndex from torchmetrics.functional.clustering.dunn_index import dunn_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import ( _single_target_intrinsic1, _single_target_intrinsic2, ) -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_fowlkes_mallows_index.py b/tests/unittests/clustering/test_fowlkes_mallows_index.py index f880791454f..6e5674ae337 100644 --- a/tests/unittests/clustering/test_fowlkes_mallows_index.py +++ b/tests/unittests/clustering/test_fowlkes_mallows_index.py @@ -16,9 +16,9 @@ from torchmetrics.clustering import FowlkesMallowsIndex from torchmetrics.functional.clustering import fowlkes_mallows_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py index db8224f2f5e..dd716182b4b 100644 --- a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py +++ b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py @@ -28,9 +28,9 @@ v_measure_score, ) +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 2a5fd2af1ad..ab9222b8082 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -18,9 +18,9 @@ from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_normalized_mutual_info_score.py b/tests/unittests/clustering/test_normalized_mutual_info_score.py index e40b807958a..07109771b0f 100644 --- a/tests/unittests/clustering/test_normalized_mutual_info_score.py +++ b/tests/unittests/clustering/test_normalized_mutual_info_score.py @@ -20,9 +20,9 @@ from torchmetrics.functional.clustering import normalized_mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index 9a2abf1a736..824f5e90fbd 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -17,9 +17,9 @@ from torchmetrics.clustering.rand_score import RandScore from torchmetrics.functional.clustering.rand_score import rand_score +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py index e6ffe222b46..2b1e4dbc755 100644 --- a/tests/unittests/clustering/test_utils.py +++ b/tests/unittests/clustering/test_utils.py @@ -27,7 +27,7 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index 03f10242f09..9d99d2d55c5 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -35,7 +35,7 @@ else: tv_iou, tv_ciou, tv_diou, tv_giou = ..., ..., ..., ... -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester def _tv_wrapper(preds, target, base_fn, aggregate=True, iou_threshold=None): diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 46d9db5a839..f0dcdff52f2 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -32,8 +32,8 @@ _TORCHVISION_GREATER_EQUAL_0_8, ) +from unittests._helpers.testers import MetricTester from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL -from unittests.helpers.testers import MetricTester _pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index 96b02d16930..4c864d0e9af 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -20,8 +20,8 @@ from torchmetrics.functional.detection import modified_panoptic_quality from unittests import _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index 9a2b801e0f4..a8263de61e0 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -20,8 +20,8 @@ from torchmetrics.functional.detection.panoptic_qualities import panoptic_quality from unittests import _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_csi.py b/tests/unittests/image/test_csi.py index 9b338167889..85e370c008f 100644 --- a/tests/unittests/image/test_csi.py +++ b/tests/unittests/image/test_csi.py @@ -21,8 +21,8 @@ from torchmetrics.regression.csi import CriticalSuccessIndex from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_d_lambda.py b/tests/unittests/image/test_d_lambda.py index be1ec3ffdf1..a4eb684b5ac 100644 --- a/tests/unittests/image/test_d_lambda.py +++ b/tests/unittests/image/test_d_lambda.py @@ -23,8 +23,8 @@ from torchmetrics.image.d_lambda import SpectralDistortionIndex from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index fb2a0fdf2db..e14cb2c96a0 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -26,8 +26,8 @@ from torchmetrics.image.d_s import SpatialDistortionIndex from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_ergas.py b/tests/unittests/image/test_ergas.py index 110fa6bf4bc..58bd999c205 100644 --- a/tests/unittests/image/test_ergas.py +++ b/tests/unittests/image/test_ergas.py @@ -22,8 +22,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index 026c2b91770..5e148a0d984 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -21,8 +21,8 @@ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index 8d71617be53..8201e877332 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -19,8 +19,8 @@ from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure from unittests import NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_perceptual_path_length.py b/tests/unittests/image/test_perceptual_path_length.py index 1eb486c6ce3..26f48be51ae 100644 --- a/tests/unittests/image/test_perceptual_path_length.py +++ b/tests/unittests/image/test_perceptual_path_length.py @@ -24,7 +24,7 @@ from torchmetrics.image.perceptual_path_length import PerceptualPathLength from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE -from unittests.helpers import seed_all, skip_on_running_out_of_memory +from unittests._helpers import seed_all, skip_on_running_out_of_memory seed_all(42) diff --git a/tests/unittests/image/test_psnr.py b/tests/unittests/image/test_psnr.py index 0cfe9546017..66724af1f4c 100644 --- a/tests/unittests/image/test_psnr.py +++ b/tests/unittests/image/test_psnr.py @@ -23,8 +23,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_psnrb.py b/tests/unittests/image/test_psnrb.py index 077af420601..2d59efa1f79 100644 --- a/tests/unittests/image/test_psnrb.py +++ b/tests/unittests/image/test_psnrb.py @@ -21,8 +21,8 @@ from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_qnr.py b/tests/unittests/image/test_qnr.py index e1e680beed6..4cb42cf36be 100644 --- a/tests/unittests/image/test_qnr.py +++ b/tests/unittests/image/test_qnr.py @@ -22,8 +22,8 @@ from torchmetrics.image.qnr import QualityWithNoReference from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.image.test_d_lambda import _baseline_d_lambda from unittests.image.test_d_s import _reference_d_s diff --git a/tests/unittests/image/test_rase.py b/tests/unittests/image/test_rase.py index 8015153fd19..f9227285e87 100644 --- a/tests/unittests/image/test_rase.py +++ b/tests/unittests/image/test_rase.py @@ -23,7 +23,7 @@ from torchmetrics.image import RelativeAverageSpectralError from unittests import BATCH_SIZE -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester class _InputWindowSized(NamedTuple): diff --git a/tests/unittests/image/test_rmse_sw.py b/tests/unittests/image/test_rmse_sw.py index a989de625a7..307d66ac9b1 100644 --- a/tests/unittests/image/test_rmse_sw.py +++ b/tests/unittests/image/test_rmse_sw.py @@ -22,7 +22,7 @@ from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester class _InputWindowSized(NamedTuple): diff --git a/tests/unittests/image/test_sam.py b/tests/unittests/image/test_sam.py index 9cf83196866..e71b0b230e3 100644 --- a/tests/unittests/image/test_sam.py +++ b/tests/unittests/image/test_sam.py @@ -22,8 +22,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index fa1854adc60..ecf6c355677 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -21,8 +21,8 @@ from torchmetrics.image import SpatialCorrelationCoefficient from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index a84f2d83468..6b464f2a97b 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -23,8 +23,8 @@ from torchmetrics.image import StructuralSimilarityIndexMeasure from unittests import NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 1842b693046..add144897ae 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -21,8 +21,8 @@ from torchmetrics.image.tv import TotalVariation from unittests import _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_uqi.py b/tests/unittests/image/test_uqi.py index 22ff2b7479b..0539eb4d0d1 100644 --- a/tests/unittests/image/test_uqi.py +++ b/tests/unittests/image/test_uqi.py @@ -22,8 +22,8 @@ from torchmetrics.image.uqi import UniversalImageQualityIndex from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_vif.py b/tests/unittests/image/test_vif.py index bde00c969b5..e1ea8eb1401 100644 --- a/tests/unittests/image/test_vif.py +++ b/tests/unittests/image/test_vif.py @@ -20,8 +20,8 @@ from torchmetrics.image.vif import VisualInformationFidelity from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/multimodal/test_clip_iqa.py b/tests/unittests/multimodal/test_clip_iqa.py index 314ff0013b8..2e88840e81b 100644 --- a/tests/unittests/multimodal/test_clip_iqa.py +++ b/tests/unittests/multimodal/test_clip_iqa.py @@ -26,8 +26,8 @@ from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10 from torchvision.transforms import PILToTensor -from unittests.helpers import skip_on_connection_issues -from unittests.helpers.testers import MetricTester +from unittests._helpers import skip_on_connection_issues +from unittests._helpers.testers import MetricTester from unittests.image import _SAMPLE_IMAGE diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index 110266e6525..e2804ecebb9 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -25,8 +25,8 @@ from transformers import CLIPModel as _CLIPModel from transformers import CLIPProcessor as _CLIPProcessor -from unittests.helpers import seed_all, skip_on_connection_issues -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all, skip_on_connection_issues +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/nominal/test_cramers.py b/tests/unittests/nominal/test_cramers.py index 42b735ef510..78df22d21b9 100644 --- a/tests/unittests/nominal/test_cramers.py +++ b/tests/unittests/nominal/test_cramers.py @@ -20,7 +20,7 @@ from torchmetrics.nominal.cramers import CramersV from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_CLASSES = 4 diff --git a/tests/unittests/nominal/test_fleiss_kappa.py b/tests/unittests/nominal/test_fleiss_kappa.py index 1538b116ab2..911f2814ad5 100644 --- a/tests/unittests/nominal/test_fleiss_kappa.py +++ b/tests/unittests/nominal/test_fleiss_kappa.py @@ -21,7 +21,7 @@ from torchmetrics.nominal.fleiss_kappa import FleissKappa from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_RATERS = 20 NUM_CATEGORIES = NUM_CLASSES diff --git a/tests/unittests/nominal/test_pearson.py b/tests/unittests/nominal/test_pearson.py index 5bec1cd8121..2f59461d276 100644 --- a/tests/unittests/nominal/test_pearson.py +++ b/tests/unittests/nominal/test_pearson.py @@ -23,7 +23,7 @@ from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_CLASSES = 4 diff --git a/tests/unittests/nominal/test_theils_u.py b/tests/unittests/nominal/test_theils_u.py index c06c6b9bcd2..a8bc2bc9952 100644 --- a/tests/unittests/nominal/test_theils_u.py +++ b/tests/unittests/nominal/test_theils_u.py @@ -20,7 +20,7 @@ from torchmetrics.nominal import TheilsU from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_CLASSES = 4 diff --git a/tests/unittests/nominal/test_tschuprows.py b/tests/unittests/nominal/test_tschuprows.py index 91798d88d82..a6a8c1d2b39 100644 --- a/tests/unittests/nominal/test_tschuprows.py +++ b/tests/unittests/nominal/test_tschuprows.py @@ -20,7 +20,7 @@ from torchmetrics.nominal.tschuprows import TschuprowsT from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_CLASSES = 4 diff --git a/tests/unittests/pairwise/test_pairwise_distance.py b/tests/unittests/pairwise/test_pairwise_distance.py index 3777ee63276..6538423c592 100644 --- a/tests/unittests/pairwise/test_pairwise_distance.py +++ b/tests/unittests/pairwise/test_pairwise_distance.py @@ -33,8 +33,8 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_concordance.py b/tests/unittests/regression/test_concordance.py index 69668772021..06493af0e0d 100644 --- a/tests/unittests/regression/test_concordance.py +++ b/tests/unittests/regression/test_concordance.py @@ -22,8 +22,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_cosine_similarity.py b/tests/unittests/regression/test_cosine_similarity.py index 184676526c7..de689cabb45 100644 --- a/tests/unittests/regression/test_cosine_similarity.py +++ b/tests/unittests/regression/test_cosine_similarity.py @@ -21,8 +21,8 @@ from torchmetrics.regression.cosine_similarity import CosineSimilarity from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_explained_variance.py b/tests/unittests/regression/test_explained_variance.py index 629e1ea7932..4838efa5df1 100644 --- a/tests/unittests/regression/test_explained_variance.py +++ b/tests/unittests/regression/test_explained_variance.py @@ -20,8 +20,8 @@ from torchmetrics.regression import ExplainedVariance from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index c7e5747a0ba..017179069e0 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -24,8 +24,8 @@ from torchmetrics.utilities.imports import _SCIPY_GREATER_EQUAL_1_8, _TORCH_LOWER_2_0 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_kl_divergence.py b/tests/unittests/regression/test_kl_divergence.py index 6af23cd2f44..53a3817fa82 100644 --- a/tests/unittests/regression/test_kl_divergence.py +++ b/tests/unittests/regression/test_kl_divergence.py @@ -24,8 +24,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_log_cosh_error.py b/tests/unittests/regression/test_log_cosh_error.py index 74b6214719b..9931ec91c22 100644 --- a/tests/unittests/regression/test_log_cosh_error.py +++ b/tests/unittests/regression/test_log_cosh_error.py @@ -20,8 +20,8 @@ from torchmetrics.regression.log_cosh import LogCoshError from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index c25882c3f37..631d9e6afe4 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -42,8 +42,8 @@ from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_minkowski_distance.py b/tests/unittests/regression/test_minkowski_distance.py index e00ccdaf7c4..9c0a27b6e54 100644 --- a/tests/unittests/regression/test_minkowski_distance.py +++ b/tests/unittests/regression/test_minkowski_distance.py @@ -8,8 +8,8 @@ from torchmetrics.utilities.exceptions import TorchMetricsUserError from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index d7ed5b27bfd..0d23507aeed 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -20,8 +20,8 @@ from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_r2.py b/tests/unittests/regression/test_r2.py index adcdc8a2dc8..32ce2554ed7 100644 --- a/tests/unittests/regression/test_r2.py +++ b/tests/unittests/regression/test_r2.py @@ -20,8 +20,8 @@ from torchmetrics.regression import R2Score from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_rse.py b/tests/unittests/regression/test_rse.py index 0f127ed796c..4ec677aa9ff 100644 --- a/tests/unittests/regression/test_rse.py +++ b/tests/unittests/regression/test_rse.py @@ -21,8 +21,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_spearman.py b/tests/unittests/regression/test_spearman.py index 4f3bd2ab790..b8d096e3d0e 100644 --- a/tests/unittests/regression/test_spearman.py +++ b/tests/unittests/regression/test_spearman.py @@ -21,8 +21,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_tweedie_deviance.py b/tests/unittests/regression/test_tweedie_deviance.py index 00324dc041f..ec45b8ceb49 100644 --- a/tests/unittests/regression/test_tweedie_deviance.py +++ b/tests/unittests/regression/test_tweedie_deviance.py @@ -21,8 +21,8 @@ from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index 748c8f993b0..03d5429ba47 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -22,8 +22,8 @@ from torch import Tensor, tensor from typing_extensions import Literal -from unittests.helpers import seed_all -from unittests.helpers.testers import Metric, MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import Metric, MetricTester from unittests.retrieval._inputs import _input_retrieval_scores as _irs from unittests.retrieval._inputs import _input_retrieval_scores_all_target as _irs_all from unittests.retrieval._inputs import _input_retrieval_scores_empty as _irs_empty diff --git a/tests/unittests/retrieval/test_auroc.py b/tests/unittests/retrieval/test_auroc.py index 9a4385d1e06..26dbf62a9aa 100644 --- a/tests/unittests/retrieval/test_auroc.py +++ b/tests/unittests/retrieval/test_auroc.py @@ -21,7 +21,7 @@ from torchmetrics.retrieval.auroc import RetrievalAUROC from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_fallout.py b/tests/unittests/retrieval/test_fallout.py index 0ca5cb1db4b..9b0d8a3ebe0 100644 --- a/tests/unittests/retrieval/test_fallout.py +++ b/tests/unittests/retrieval/test_fallout.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.fall_out import RetrievalFallOut from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_hit_rate.py b/tests/unittests/retrieval/test_hit_rate.py index 1e3805aa32a..377c304ae5e 100644 --- a/tests/unittests/retrieval/test_hit_rate.py +++ b/tests/unittests/retrieval/test_hit_rate.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.hit_rate import RetrievalHitRate from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_map.py b/tests/unittests/retrieval/test_map.py index 8aec9a491cc..f3ac6b9989d 100644 --- a/tests/unittests/retrieval/test_map.py +++ b/tests/unittests/retrieval/test_map.py @@ -21,7 +21,7 @@ from torchmetrics.retrieval.average_precision import RetrievalMAP from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_mrr.py b/tests/unittests/retrieval/test_mrr.py index e5a1af2e73d..22cc946e8a8 100644 --- a/tests/unittests/retrieval/test_mrr.py +++ b/tests/unittests/retrieval/test_mrr.py @@ -21,7 +21,7 @@ from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py index 9eb29f79889..48e0c679195 100644 --- a/tests/unittests/retrieval/test_ndcg.py +++ b/tests/unittests/retrieval/test_ndcg.py @@ -22,7 +22,7 @@ from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_precision.py b/tests/unittests/retrieval/test_precision.py index 61f096aa8fb..a6e18756cab 100644 --- a/tests/unittests/retrieval/test_precision.py +++ b/tests/unittests/retrieval/test_precision.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.precision import RetrievalPrecision from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_precision_recall_curve.py b/tests/unittests/retrieval/test_precision_recall_curve.py index 7ed7f12c5ee..691334d135e 100644 --- a/tests/unittests/retrieval/test_precision_recall_curve.py +++ b/tests/unittests/retrieval/test_precision_recall_curve.py @@ -23,8 +23,8 @@ from torchmetrics.retrieval.base import _retrieval_aggregate from typing_extensions import Literal -from unittests.helpers import seed_all -from unittests.helpers.testers import Metric, MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import Metric, MetricTester from unittests.retrieval.helpers import _custom_aggregate_fn, _default_metric_class_input_arguments, get_group_indexes from unittests.retrieval.test_precision import _precision_at_k from unittests.retrieval.test_recall import _recall_at_k diff --git a/tests/unittests/retrieval/test_r_precision.py b/tests/unittests/retrieval/test_r_precision.py index 6343de74ea3..50a0384a58d 100644 --- a/tests/unittests/retrieval/test_r_precision.py +++ b/tests/unittests/retrieval/test_r_precision.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.r_precision import RetrievalRPrecision from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_recall.py b/tests/unittests/retrieval/test_recall.py index c5de7e6600d..24ff4b6a756 100644 --- a/tests/unittests/retrieval/test_recall.py +++ b/tests/unittests/retrieval/test_recall.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.recall import RetrievalRecall from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/text/helpers.py b/tests/unittests/text/_helpers.py similarity index 98% rename from tests/unittests/text/helpers.py rename to tests/unittests/text/_helpers.py index e86607518ea..1f237d3e96f 100644 --- a/tests/unittests/text/helpers.py +++ b/tests/unittests/text/_helpers.py @@ -23,8 +23,14 @@ from torchmetrics import Metric from unittests import NUM_PROCESSES, _reference_cachier -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, _assert_allclose, _assert_requires_grad, _assert_tensor +from unittests._helpers import seed_all +from unittests._helpers.testers import ( + MetricTester, + _assert_allclose, + _assert_requires_grad, + _assert_tensor, + _select_rand_best_device, +) TEXT_METRIC_INPUT = Union[Sequence[str], Sequence[Sequence[str]], Sequence[Sequence[Sequence[str]]]] NUM_BATCHES = 2 @@ -288,7 +294,6 @@ def run_functional_metric_test( """ seed_all(42) - device = "cuda" if torch.cuda.device_count() > 0 else "cpu" _functional_test( preds=preds, @@ -297,7 +302,7 @@ def run_functional_metric_test( reference_metric=reference_metric, metric_args=metric_args, atol=self.atol, - device=device, + device=_select_rand_best_device(), fragment_kwargs=fragment_kwargs, key=key, **kwargs_update, @@ -352,7 +357,7 @@ def run_class_metric_test( "reference_metric": reference_metric, "metric_args": metric_args or {}, "atol": self.atol, - "device": "cuda" if torch.cuda.is_available() else "cpu", + "device": _select_rand_best_device(), "dist_sync_on_step": dist_sync_on_step, "check_dist_sync_on_step": check_dist_sync_on_step, "check_batch": check_batch, diff --git a/tests/unittests/text/_inputs.py b/tests/unittests/text/_inputs.py index 9d976864623..595b07cd87c 100644 --- a/tests/unittests/text/_inputs.py +++ b/tests/unittests/text/_inputs.py @@ -17,7 +17,7 @@ from torch import Tensor from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _Input -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(1) diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index b651ecfddb1..0812808426f 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -22,9 +22,9 @@ from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 from typing_extensions import Literal -from unittests.helpers import skip_on_connection_issues +from unittests._helpers import skip_on_connection_issues +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_single_reference -from unittests.text.helpers import TextTester _METRIC_KEY_TO_IDX = { "precision": 0, diff --git a/tests/unittests/text/test_bleu.py b/tests/unittests/text/test_bleu.py index 3c271cf10e6..03ce0faba02 100644 --- a/tests/unittests/text/test_bleu.py +++ b/tests/unittests/text/test_bleu.py @@ -20,8 +20,8 @@ from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.text.bleu import BLEUScore +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references -from unittests.text.helpers import TextTester # https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction smooth_func = SmoothingFunction().method2 diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index ab09f3e5334..6ef3f7390be 100644 --- a/tests/unittests/text/test_cer.py +++ b/tests/unittests/text/test_cer.py @@ -17,8 +17,8 @@ from torchmetrics.functional.text.cer import char_error_rate from torchmetrics.text.cer import CharErrorRate +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -from unittests.text.helpers import TextTester def _reference_jiwer_cer(preds: Union[str, List[str]], target: Union[str, List[str]]): diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py index 6dd328d2c77..233c9451381 100644 --- a/tests/unittests/text/test_chrf.py +++ b/tests/unittests/text/test_chrf.py @@ -19,8 +19,8 @@ from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.text.chrf import CHRFScore +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references -from unittests.text.helpers import TextTester def _reference_sacrebleu_chrf( diff --git a/tests/unittests/text/test_edit.py b/tests/unittests/text/test_edit.py index a7d4029cef6..457bcfa18ad 100644 --- a/tests/unittests/text/test_edit.py +++ b/tests/unittests/text/test_edit.py @@ -18,8 +18,8 @@ from torchmetrics.functional.text.edit import edit_distance from torchmetrics.text.edit import EditDistance +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_single_reference -from unittests.text.helpers import TextTester @pytest.mark.parametrize( diff --git a/tests/unittests/text/test_eed.py b/tests/unittests/text/test_eed.py index 964df16d3d1..a9c30d384de 100644 --- a/tests/unittests/text/test_eed.py +++ b/tests/unittests/text/test_eed.py @@ -19,8 +19,8 @@ from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.text.eed import ExtendedEditDistance +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_single_reference, _inputs_single_sentence_multiple_references -from unittests.text.helpers import TextTester def _reference_rwth_manual(preds, targets) -> Tensor: diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index d8611695ff3..1ee45cde02e 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -19,9 +19,9 @@ from torchmetrics.text.infolm import InfoLM from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 -from unittests.helpers import skip_on_connection_issues +from unittests._helpers import skip_on_connection_issues +from unittests.text._helpers import TextTester from unittests.text._inputs import HYPOTHESIS_A, HYPOTHESIS_C, _inputs_single_reference -from unittests.text.helpers import TextTester # Small bert model with 2 layers, 2 attention heads and hidden dim of 128 MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2" diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index e6f5222c3b1..69e595465a7 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -17,9 +17,9 @@ from torchmetrics.functional.text.mer import match_error_rate from torchmetrics.text.mer import MatchErrorRate -from unittests.helpers import seed_all +from unittests._helpers import seed_all +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -from unittests.text.helpers import TextTester seed_all(42) diff --git a/tests/unittests/text/test_perplexity.py b/tests/unittests/text/test_perplexity.py index 3673da47647..42930ec21ff 100644 --- a/tests/unittests/text/test_perplexity.py +++ b/tests/unittests/text/test_perplexity.py @@ -20,7 +20,7 @@ from torchmetrics.text.perplexity import Perplexity from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2 -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester from unittests.text._inputs import ( MASK_INDEX, _logits_inputs_fp32, diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index c9eec8a055a..a40885587e8 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -24,9 +24,9 @@ from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE from typing_extensions import Literal -from unittests.helpers import skip_on_connection_issues +from unittests._helpers import skip_on_connection_issues +from unittests.text._helpers import TextTester from unittests.text._inputs import _Input, _inputs_multiple_references, _inputs_single_sentence_single_reference -from unittests.text.helpers import TextTester if _ROUGE_SCORE_AVAILABLE: from rouge_score.rouge_scorer import RougeScorer diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index d74d032597d..1be9387989c 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -20,8 +20,8 @@ from torchmetrics.functional.text.sacre_bleu import AVAILABLE_TOKENIZERS, _TokenizersLiteral, sacre_bleu_score from torchmetrics.text.sacre_bleu import SacreBLEUScore +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references -from unittests.text.helpers import TextTester def _reference_sacre_bleu( diff --git a/tests/unittests/text/test_squad.py b/tests/unittests/text/test_squad.py index 5bb4cc0c7fa..8a3d26af8ab 100644 --- a/tests/unittests/text/test_squad.py +++ b/tests/unittests/text/test_squad.py @@ -20,7 +20,7 @@ from torchmetrics.functional.text import squad from torchmetrics.text.squad import SQuAD -from unittests.helpers.testers import _assert_allclose, _assert_tensor +from unittests._helpers.testers import _assert_allclose, _assert_tensor from unittests.text._inputs import _inputs_squad_batch_match, _inputs_squad_exact_match, _inputs_squad_exact_mismatch diff --git a/tests/unittests/text/test_ter.py b/tests/unittests/text/test_ter.py index eb63451cf36..861a7a77723 100644 --- a/tests/unittests/text/test_ter.py +++ b/tests/unittests/text/test_ter.py @@ -19,8 +19,8 @@ from torchmetrics.functional.text.ter import translation_edit_rate from torchmetrics.text.ter import TranslationEditRate +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references -from unittests.text.helpers import TextTester def _reference_sacrebleu_ter( diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index 6aee783d411..16b03849f84 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -17,8 +17,8 @@ from torchmetrics.functional.text.wer import word_error_rate from torchmetrics.text.wer import WordErrorRate +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -from unittests.text.helpers import TextTester def _reference_jiwer_wer(preds: Union[str, List[str]], target: Union[str, List[str]]): diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index 08ecad16284..37278b829f1 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -17,8 +17,8 @@ from torchmetrics.functional.text.wil import word_information_lost from torchmetrics.text.wil import WordInfoLost +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -from unittests.text.helpers import TextTester def _reference_jiwer_wil(preds: Union[str, List[str]], target: Union[str, List[str]]): diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index 1900f7182b2..a6523babd67 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -17,8 +17,8 @@ from torchmetrics.functional.text.wip import word_information_preserved from torchmetrics.text.wip import WordInfoPreserved +from unittests.text._helpers import TextTester from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 -from unittests.text.helpers import TextTester def _reference_jiwer_wip(preds: Union[str, List[str]], target: Union[str, List[str]]): diff --git a/tests/unittests/utilities/test_auc.py b/tests/unittests/utilities/test_auc.py index 37d9c1105ee..887f4c2d14b 100644 --- a/tests/unittests/utilities/test_auc.py +++ b/tests/unittests/utilities/test_auc.py @@ -20,8 +20,8 @@ from torch import Tensor, tensor from torchmetrics.utilities.compute import auc from unittests import NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/wrappers/test_bootstrapping.py b/tests/unittests/wrappers/test_bootstrapping.py index b02b5034c75..42d890cf728 100644 --- a/tests/unittests/wrappers/test_bootstrapping.py +++ b/tests/unittests/wrappers/test_bootstrapping.py @@ -25,7 +25,7 @@ from torchmetrics.regression import MeanSquaredError from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/wrappers/test_minmax.py b/tests/unittests/wrappers/test_minmax.py index fe406537baf..480fd1f8f7c 100644 --- a/tests/unittests/wrappers/test_minmax.py +++ b/tests/unittests/wrappers/test_minmax.py @@ -23,8 +23,8 @@ from torchmetrics.wrappers import MinMaxMetric from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/wrappers/test_multioutput.py b/tests/unittests/wrappers/test_multioutput.py index 7b807e7741e..3d95b20d69a 100644 --- a/tests/unittests/wrappers/test_multioutput.py +++ b/tests/unittests/wrappers/test_multioutput.py @@ -25,8 +25,8 @@ from torchmetrics.wrappers.multioutput import MultioutputWrapper from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/wrappers/test_multitask.py b/tests/unittests/wrappers/test_multitask.py index 9d8f61323ef..63af6f31b35 100644 --- a/tests/unittests/wrappers/test_multitask.py +++ b/tests/unittests/wrappers/test_multitask.py @@ -22,7 +22,7 @@ from torchmetrics.wrappers import MultitaskWrapper from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py index a6b62db2453..93cdbc452ca 100644 --- a/tests/unittests/wrappers/test_tracker.py +++ b/tests/unittests/wrappers/test_tracker.py @@ -24,7 +24,7 @@ from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError from torchmetrics.wrappers import MetricTracker, MultioutputWrapper -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42)