From bc761931a890c1dd1a5a26ce6043ae49564f614e Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 13 Mar 2024 16:49:44 +0100 Subject: [PATCH] _select_rand_best_device() --- .../unittests/{helpers => _helpers}/__init__.py | 2 +- .../unittests/{helpers => _helpers}/testers.py | 17 +++++++++++++---- .../unittests/{helpers => _helpers}/wrappers.py | 0 tests/unittests/audio/test_c_si_snr.py | 4 ++-- tests/unittests/audio/test_pesq.py | 4 ++-- tests/unittests/audio/test_pit.py | 4 ++-- tests/unittests/audio/test_sa_sdr.py | 4 ++-- tests/unittests/audio/test_sdr.py | 4 ++-- tests/unittests/audio/test_si_sdr.py | 4 ++-- tests/unittests/audio/test_si_snr.py | 4 ++-- tests/unittests/audio/test_snr.py | 4 ++-- tests/unittests/audio/test_srmr.py | 4 ++-- tests/unittests/audio/test_stoi.py | 4 ++-- tests/unittests/bases/test_aggregation.py | 2 +- tests/unittests/bases/test_collections.py | 4 ++-- tests/unittests/bases/test_ddp.py | 4 ++-- tests/unittests/bases/test_hashing.py | 2 +- tests/unittests/bases/test_metric.py | 4 ++-- tests/unittests/classification/_inputs.py | 2 +- tests/unittests/classification/test_accuracy.py | 4 ++-- tests/unittests/classification/test_auroc.py | 4 ++-- .../classification/test_average_precision.py | 4 ++-- .../classification/test_calibration_error.py | 4 ++-- .../classification/test_cohen_kappa.py | 4 ++-- .../classification/test_confusion_matrix.py | 4 ++-- tests/unittests/classification/test_dice.py | 4 ++-- .../classification/test_exact_match.py | 4 ++-- tests/unittests/classification/test_f_beta.py | 4 ++-- .../classification/test_group_fairness.py | 10 +++++----- .../classification/test_hamming_distance.py | 4 ++-- tests/unittests/classification/test_hinge.py | 2 +- tests/unittests/classification/test_jaccard.py | 2 +- .../classification/test_matthews_corrcoef.py | 4 ++-- .../test_precision_fixed_recall.py | 4 ++-- .../classification/test_precision_recall.py | 4 ++-- .../test_precision_recall_curve.py | 4 ++-- tests/unittests/classification/test_ranking.py | 4 ++-- .../test_recall_fixed_precision.py | 4 ++-- tests/unittests/classification/test_roc.py | 4 ++-- .../test_sensitivity_specificity.py | 4 ++-- .../classification/test_specificity.py | 4 ++-- .../test_specificity_sensitivity.py | 4 ++-- .../classification/test_stat_scores.py | 4 ++-- tests/unittests/clustering/_inputs.py | 2 +- .../test_adjusted_mutual_info_score.py | 4 ++-- .../clustering/test_adjusted_rand_score.py | 2 +- .../clustering/test_calinski_harabasz_score.py | 4 ++-- .../clustering/test_davies_bouldin_score.py | 4 ++-- tests/unittests/clustering/test_dunn_index.py | 4 ++-- .../clustering/test_fowlkes_mallows_index.py | 4 ++-- .../test_homogeneity_completeness_v_measure.py | 4 ++-- .../clustering/test_mutual_info_score.py | 4 ++-- .../test_normalized_mutual_info_score.py | 4 ++-- tests/unittests/clustering/test_rand_score.py | 4 ++-- tests/unittests/clustering/test_utils.py | 2 +- tests/unittests/detection/test_intersection.py | 2 +- tests/unittests/detection/test_map.py | 2 +- .../detection/test_modified_panoptic_quality.py | 4 ++-- .../detection/test_panoptic_quality.py | 4 ++-- tests/unittests/image/test_csi.py | 4 ++-- tests/unittests/image/test_d_lambda.py | 4 ++-- tests/unittests/image/test_d_s.py | 4 ++-- tests/unittests/image/test_ergas.py | 4 ++-- tests/unittests/image/test_lpips.py | 4 ++-- tests/unittests/image/test_ms_ssim.py | 4 ++-- .../image/test_perceptual_path_length.py | 2 +- tests/unittests/image/test_psnr.py | 4 ++-- tests/unittests/image/test_psnrb.py | 4 ++-- tests/unittests/image/test_qnr.py | 4 ++-- tests/unittests/image/test_rase.py | 2 +- tests/unittests/image/test_rmse_sw.py | 2 +- tests/unittests/image/test_sam.py | 4 ++-- tests/unittests/image/test_scc.py | 4 ++-- tests/unittests/image/test_ssim.py | 4 ++-- tests/unittests/image/test_tv.py | 4 ++-- tests/unittests/image/test_uqi.py | 4 ++-- tests/unittests/image/test_vif.py | 4 ++-- tests/unittests/multimodal/test_clip_iqa.py | 4 ++-- tests/unittests/multimodal/test_clip_score.py | 4 ++-- tests/unittests/nominal/test_cramers.py | 2 +- tests/unittests/nominal/test_fleiss_kappa.py | 2 +- tests/unittests/nominal/test_pearson.py | 2 +- tests/unittests/nominal/test_theils_u.py | 2 +- tests/unittests/nominal/test_tschuprows.py | 2 +- .../pairwise/test_pairwise_distance.py | 4 ++-- tests/unittests/regression/test_concordance.py | 4 ++-- .../regression/test_cosine_similarity.py | 4 ++-- .../regression/test_explained_variance.py | 4 ++-- tests/unittests/regression/test_kendall.py | 4 ++-- .../unittests/regression/test_kl_divergence.py | 4 ++-- .../unittests/regression/test_log_cosh_error.py | 4 ++-- tests/unittests/regression/test_mean_error.py | 4 ++-- .../regression/test_minkowski_distance.py | 4 ++-- tests/unittests/regression/test_pearson.py | 4 ++-- tests/unittests/regression/test_r2.py | 4 ++-- tests/unittests/regression/test_rse.py | 4 ++-- tests/unittests/regression/test_spearman.py | 4 ++-- .../regression/test_tweedie_deviance.py | 4 ++-- tests/unittests/retrieval/helpers.py | 4 ++-- tests/unittests/retrieval/test_auroc.py | 2 +- tests/unittests/retrieval/test_fallout.py | 2 +- tests/unittests/retrieval/test_hit_rate.py | 2 +- tests/unittests/retrieval/test_map.py | 2 +- tests/unittests/retrieval/test_mrr.py | 2 +- tests/unittests/retrieval/test_ndcg.py | 2 +- tests/unittests/retrieval/test_precision.py | 2 +- .../retrieval/test_precision_recall_curve.py | 4 ++-- tests/unittests/retrieval/test_r_precision.py | 2 +- tests/unittests/retrieval/test_recall.py | 2 +- .../unittests/text/{helpers.py => _helpers.py} | 10 +++++----- tests/unittests/text/_inputs.py | 2 +- tests/unittests/text/test_bertscore.py | 2 +- tests/unittests/text/test_infolm.py | 2 +- tests/unittests/text/test_mer.py | 2 +- tests/unittests/text/test_perplexity.py | 2 +- tests/unittests/text/test_rouge.py | 2 +- tests/unittests/text/test_squad.py | 2 +- tests/unittests/utilities/test_auc.py | 4 ++-- tests/unittests/wrappers/test_bootstrapping.py | 2 +- tests/unittests/wrappers/test_minmax.py | 4 ++-- tests/unittests/wrappers/test_multioutput.py | 4 ++-- tests/unittests/wrappers/test_multitask.py | 2 +- tests/unittests/wrappers/test_tracker.py | 2 +- 123 files changed, 223 insertions(+), 214 deletions(-) rename tests/unittests/{helpers => _helpers}/__init__.py (90%) rename tests/unittests/{helpers => _helpers}/testers.py (98%) rename tests/unittests/{helpers => _helpers}/wrappers.py (100%) rename tests/unittests/text/{helpers.py => _helpers.py} (98%) diff --git a/tests/unittests/helpers/__init__.py b/tests/unittests/_helpers/__init__.py similarity index 90% rename from tests/unittests/helpers/__init__.py rename to tests/unittests/_helpers/__init__.py index c4c5be42c8b..5103904838d 100644 --- a/tests/unittests/helpers/__init__.py +++ b/tests/unittests/_helpers/__init__.py @@ -16,7 +16,7 @@ import numpy import torch -from unittests.helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory +from unittests._helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory def seed_all(seed): diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/_helpers/testers.py similarity index 98% rename from tests/unittests/helpers/testers.py rename to tests/unittests/_helpers/testers.py index a8885f831bb..f2109933500 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -15,6 +15,7 @@ import sys from copy import deepcopy from functools import partial +from random import randrange from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -337,6 +338,16 @@ def _assert_dtype_support( _assert_tensor(metric_functional(y_hat, y, **kwargs_update)) +def _select_rand_best_device() -> str: + """Select the best device to run tests on.""" + nb_gpus = torch.cuda.device_count() + if nb_gpus > 1: + return f"cuda:{randrange(nb_gpus)}" + if nb_gpus: + return "cuda" + return "cpu" + + class MetricTester: """Test class for all metrics. @@ -371,8 +382,6 @@ def run_functional_metric_test( target when running update on the metric. """ - device = "cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu" - _functional_test( preds=preds, target=target, @@ -380,7 +389,7 @@ def run_functional_metric_test( reference_metric=reference_metric, metric_args=metric_args, atol=self.atol, - device=device, + device=_select_rand_best_device(), fragment_kwargs=fragment_kwargs, **kwargs_update, ) @@ -431,7 +440,7 @@ def run_class_metric_test( "reference_metric": reference_metric, "metric_args": metric_args or {}, "atol": atol or self.atol, - "device": "cuda" if torch.cuda.is_available() else "cpu", + "device": _select_rand_best_device(), "dist_sync_on_step": dist_sync_on_step, "check_dist_sync_on_step": check_dist_sync_on_step, "check_batch": check_batch, diff --git a/tests/unittests/helpers/wrappers.py b/tests/unittests/_helpers/wrappers.py similarity index 100% rename from tests/unittests/helpers/wrappers.py rename to tests/unittests/_helpers/wrappers.py diff --git a/tests/unittests/audio/test_c_si_snr.py b/tests/unittests/audio/test_c_si_snr.py index aed96ea6285..eb160410ca4 100644 --- a/tests/unittests/audio/test_c_si_snr.py +++ b/tests/unittests/audio/test_c_si_snr.py @@ -20,8 +20,8 @@ from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index 348f99c13d6..d5124b79fbe 100644 --- a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -23,8 +23,8 @@ from unittests import _Input from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index 107a775e728..920d6082065 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -32,8 +32,8 @@ from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.audio import _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_sa_sdr.py b/tests/unittests/audio/test_sa_sdr.py index d6e7178b8ad..3de3c4900cf 100644 --- a/tests/unittests/audio/test_sa_sdr.py +++ b/tests/unittests/audio/test_sa_sdr.py @@ -24,8 +24,8 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_sdr.py b/tests/unittests/audio/test_sdr.py index ce8756d8ec7..537249da1eb 100644 --- a/tests/unittests/audio/test_sdr.py +++ b/tests/unittests/audio/test_sdr.py @@ -25,8 +25,8 @@ from unittests import _Input from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _SAMPLE_NUMPY_ISSUE_895 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_si_sdr.py b/tests/unittests/audio/test_si_sdr.py index 6f014f828eb..ee9c6ff6d4b 100644 --- a/tests/unittests/audio/test_si_sdr.py +++ b/tests/unittests/audio/test_si_sdr.py @@ -22,8 +22,8 @@ from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.audio import _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_si_snr.py b/tests/unittests/audio/test_si_snr.py index f6f6c7f52f7..d7065e0dc88 100644 --- a/tests/unittests/audio/test_si_snr.py +++ b/tests/unittests/audio/test_si_snr.py @@ -21,8 +21,8 @@ from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_snr.py b/tests/unittests/audio/test_snr.py index d513360851a..917010d5dd4 100644 --- a/tests/unittests/audio/test_snr.py +++ b/tests/unittests/audio/test_snr.py @@ -22,8 +22,8 @@ from unittests import _Input from unittests.audio import _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_srmr.py b/tests/unittests/audio/test_srmr.py index fa0ce989a61..3b1b07862ce 100644 --- a/tests/unittests/audio/test_srmr.py +++ b/tests/unittests/audio/test_srmr.py @@ -22,8 +22,8 @@ from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio from torchmetrics.utilities.imports import _TORCHAUDIO_GREATER_EQUAL_0_10 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/audio/test_stoi.py b/tests/unittests/audio/test_stoi.py index d05824a1380..af2c828d838 100644 --- a/tests/unittests/audio/test_stoi.py +++ b/tests/unittests/audio/test_stoi.py @@ -23,8 +23,8 @@ from unittests import _Input from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/bases/test_aggregation.py b/tests/unittests/bases/test_aggregation.py index 25ff2233076..0b593208e54 100644 --- a/tests/unittests/bases/test_aggregation.py +++ b/tests/unittests/bases/test_aggregation.py @@ -5,7 +5,7 @@ from torchmetrics.collections import MetricCollection from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester def compare_mean(values, weights): diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 16c95fc879a..0e124125509 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -34,8 +34,8 @@ ) from torchmetrics.utilities.checks import _allclose_recursive -from unittests.helpers import seed_all -from unittests.helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum +from unittests._helpers import seed_all +from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum seed_all(42) diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index f613f6f20e9..1a44a2145ba 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -24,8 +24,8 @@ from torchmetrics.utilities.exceptions import TorchMetricsUserError from unittests import NUM_PROCESSES -from unittests.helpers import seed_all -from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum +from unittests._helpers import seed_all +from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum seed_all(42) diff --git a/tests/unittests/bases/test_hashing.py b/tests/unittests/bases/test_hashing.py index b8bb28987e7..1e4dd474f02 100644 --- a/tests/unittests/bases/test_hashing.py +++ b/tests/unittests/bases/test_hashing.py @@ -1,6 +1,6 @@ import pytest -from unittests.helpers.testers import DummyListMetric, DummyMetric +from unittests._helpers.testers import DummyListMetric, DummyMetric @pytest.mark.parametrize( diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index a928b1a753f..22257c0d0c9 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -27,8 +27,8 @@ from torchmetrics.classification import BinaryAccuracy from torchmetrics.regression import PearsonCorrCoef -from unittests.helpers import seed_all -from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum +from unittests._helpers import seed_all +from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum seed_all(42) diff --git a/tests/unittests/classification/_inputs.py b/tests/unittests/classification/_inputs.py index c660e625214..9a4981f041c 100644 --- a/tests/unittests/classification/_inputs.py +++ b/tests/unittests/classification/_inputs.py @@ -18,7 +18,7 @@ from torch import Tensor from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _GroupInput, _Input -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(1) diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index f5e5f9857d7..cf751ae9a57 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -30,8 +30,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 54374bd56e1..a2a60d3ad62 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -27,8 +27,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index e9427623162..eedcd2794df 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -36,8 +36,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 5660a9042f0..0d176ca7fbc 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -33,8 +33,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index 6b21b9be3ff..724e74ebcd9 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -24,8 +24,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 6a1c1850d4f..bcf02d679d4 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -33,8 +33,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_dice.py b/tests/unittests/classification/test_dice.py index d3737b255a2..211c210bb86 100644 --- a/tests/unittests/classification/test_dice.py +++ b/tests/unittests/classification/test_dice.py @@ -33,8 +33,8 @@ from unittests.classification._inputs import _input_multilabel_multidim as _input_mlmd from unittests.classification._inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from unittests.classification._inputs import _input_multilabel_prob as _input_mlb_prob -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index 048003c1699..44e2dbe22e1 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -23,8 +23,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index a6cfc5f71b8..0ab54ac341f 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -43,8 +43,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index 4d76b9301dd..686839166e5 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -29,16 +29,16 @@ from unittests import THRESHOLD from unittests.classification._inputs import _group_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import ( +from unittests._helpers import seed_all +from unittests._helpers.testers import ( MetricTester, _assert_dtype_support, inject_ignore_index, remove_ignore_index_groups, ) -from unittests.helpers.testers import _assert_allclose as _core_assert_allclose -from unittests.helpers.testers import _assert_requires_grad as _core_assert_requires_grad -from unittests.helpers.testers import _assert_tensor as _core_assert_tensor +from unittests._helpers.testers import _assert_allclose as _core_assert_allclose +from unittests._helpers.testers import _assert_requires_grad as _core_assert_requires_grad +from unittests._helpers.testers import _assert_tensor as _core_assert_tensor seed_all(42) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 8ccbbc9e1fb..6b20a45c847 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -34,8 +34,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index 6b9eaca1abd..96841f3231c 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -26,7 +26,7 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index torch.manual_seed(42) diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index c1e6354d57a..a5ed8945456 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -34,7 +34,7 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None): diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index f8c0801b5ad..425dcabd23b 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -33,8 +33,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index 70c44df1109..ddd96500da1 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py @@ -35,8 +35,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 7cf4ce9e474..98d615618e7 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -43,8 +43,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 825e6887584..a990ada5144 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -35,8 +35,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index f322ce44442..9bbcb4f3928 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -33,8 +33,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index fb23e36f759..7f28f10594d 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -35,8 +35,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index cc12bd6c133..4b7c55379e8 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -26,8 +26,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 3ffbb8ac5cb..13dd2ea2763 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -37,8 +37,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 3aa2dcf6cba..7855f95f3cd 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -34,8 +34,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index 9fc918d1e1d..f5ade54d041 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -36,8 +36,8 @@ from unittests import NUM_CLASSES from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index b1e4d36e1ed..e820c31a0bf 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -33,8 +33,8 @@ from unittests import NUM_CLASSES, THRESHOLD from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) diff --git a/tests/unittests/clustering/_inputs.py b/tests/unittests/clustering/_inputs.py index b13bc0c0947..b4f6ce0b97c 100644 --- a/tests/unittests/clustering/_inputs.py +++ b/tests/unittests/clustering/_inputs.py @@ -18,7 +18,7 @@ from torch import Tensor from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _Input -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/clustering/test_adjusted_mutual_info_score.py b/tests/unittests/clustering/test_adjusted_mutual_info_score.py index 304ee18bddc..40a617ad53e 100644 --- a/tests/unittests/clustering/test_adjusted_mutual_info_score.py +++ b/tests/unittests/clustering/test_adjusted_mutual_info_score.py @@ -21,8 +21,8 @@ from unittests import BATCH_SIZE, NUM_CLASSES from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_adjusted_rand_score.py b/tests/unittests/clustering/test_adjusted_rand_score.py index 54e8c1b4577..5232314db2e 100644 --- a/tests/unittests/clustering/test_adjusted_rand_score.py +++ b/tests/unittests/clustering/test_adjusted_rand_score.py @@ -18,7 +18,7 @@ from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester @pytest.mark.parametrize( diff --git a/tests/unittests/clustering/test_calinski_harabasz_score.py b/tests/unittests/clustering/test_calinski_harabasz_score.py index 6071767364e..9bd6639d178 100644 --- a/tests/unittests/clustering/test_calinski_harabasz_score.py +++ b/tests/unittests/clustering/test_calinski_harabasz_score.py @@ -17,8 +17,8 @@ from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_davies_bouldin_score.py b/tests/unittests/clustering/test_davies_bouldin_score.py index 8f3b4800a01..2614789c1d4 100644 --- a/tests/unittests/clustering/test_davies_bouldin_score.py +++ b/tests/unittests/clustering/test_davies_bouldin_score.py @@ -17,8 +17,8 @@ from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_dunn_index.py b/tests/unittests/clustering/test_dunn_index.py index fc1500a0fd8..84a80b0f913 100644 --- a/tests/unittests/clustering/test_dunn_index.py +++ b/tests/unittests/clustering/test_dunn_index.py @@ -23,8 +23,8 @@ _single_target_intrinsic1, _single_target_intrinsic2, ) -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_fowlkes_mallows_index.py b/tests/unittests/clustering/test_fowlkes_mallows_index.py index f880791454f..f4949c74448 100644 --- a/tests/unittests/clustering/test_fowlkes_mallows_index.py +++ b/tests/unittests/clustering/test_fowlkes_mallows_index.py @@ -17,8 +17,8 @@ from torchmetrics.functional.clustering import fowlkes_mallows_index from unittests.clustering._inputs import _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py index db8224f2f5e..3c3e1c4dace 100644 --- a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py +++ b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py @@ -29,8 +29,8 @@ ) from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 2a5fd2af1ad..09b82ee37f3 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -19,8 +19,8 @@ from unittests import BATCH_SIZE, NUM_CLASSES from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_normalized_mutual_info_score.py b/tests/unittests/clustering/test_normalized_mutual_info_score.py index e40b807958a..47696f4b13c 100644 --- a/tests/unittests/clustering/test_normalized_mutual_info_score.py +++ b/tests/unittests/clustering/test_normalized_mutual_info_score.py @@ -21,8 +21,8 @@ from unittests import BATCH_SIZE, NUM_CLASSES from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index 9a2abf1a736..3c1189e5702 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -18,8 +18,8 @@ from torchmetrics.functional.clustering.rand_score import rand_score from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py index e6ffe222b46..2b1e4dbc755 100644 --- a/tests/unittests/clustering/test_utils.py +++ b/tests/unittests/clustering/test_utils.py @@ -27,7 +27,7 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index 03f10242f09..9d99d2d55c5 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -35,7 +35,7 @@ else: tv_iou, tv_ciou, tv_diou, tv_giou = ..., ..., ..., ... -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester def _tv_wrapper(preds, target, base_fn, aggregate=True, iou_threshold=None): diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 46d9db5a839..9cb3fe22dc1 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -33,7 +33,7 @@ ) from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester _pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index 96b02d16930..4c864d0e9af 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -20,8 +20,8 @@ from torchmetrics.functional.detection import modified_panoptic_quality from unittests import _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index 9a2b801e0f4..a8263de61e0 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -20,8 +20,8 @@ from torchmetrics.functional.detection.panoptic_qualities import panoptic_quality from unittests import _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_csi.py b/tests/unittests/image/test_csi.py index 9b338167889..85e370c008f 100644 --- a/tests/unittests/image/test_csi.py +++ b/tests/unittests/image/test_csi.py @@ -21,8 +21,8 @@ from torchmetrics.regression.csi import CriticalSuccessIndex from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_d_lambda.py b/tests/unittests/image/test_d_lambda.py index be1ec3ffdf1..a4eb684b5ac 100644 --- a/tests/unittests/image/test_d_lambda.py +++ b/tests/unittests/image/test_d_lambda.py @@ -23,8 +23,8 @@ from torchmetrics.image.d_lambda import SpectralDistortionIndex from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index fb2a0fdf2db..e14cb2c96a0 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -26,8 +26,8 @@ from torchmetrics.image.d_s import SpatialDistortionIndex from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_ergas.py b/tests/unittests/image/test_ergas.py index 110fa6bf4bc..58bd999c205 100644 --- a/tests/unittests/image/test_ergas.py +++ b/tests/unittests/image/test_ergas.py @@ -22,8 +22,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index 026c2b91770..5e148a0d984 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -21,8 +21,8 @@ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index 8d71617be53..8201e877332 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -19,8 +19,8 @@ from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure from unittests import NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_perceptual_path_length.py b/tests/unittests/image/test_perceptual_path_length.py index 1eb486c6ce3..26f48be51ae 100644 --- a/tests/unittests/image/test_perceptual_path_length.py +++ b/tests/unittests/image/test_perceptual_path_length.py @@ -24,7 +24,7 @@ from torchmetrics.image.perceptual_path_length import PerceptualPathLength from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE -from unittests.helpers import seed_all, skip_on_running_out_of_memory +from unittests._helpers import seed_all, skip_on_running_out_of_memory seed_all(42) diff --git a/tests/unittests/image/test_psnr.py b/tests/unittests/image/test_psnr.py index 0cfe9546017..66724af1f4c 100644 --- a/tests/unittests/image/test_psnr.py +++ b/tests/unittests/image/test_psnr.py @@ -23,8 +23,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_psnrb.py b/tests/unittests/image/test_psnrb.py index 077af420601..2d59efa1f79 100644 --- a/tests/unittests/image/test_psnrb.py +++ b/tests/unittests/image/test_psnrb.py @@ -21,8 +21,8 @@ from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_qnr.py b/tests/unittests/image/test_qnr.py index e1e680beed6..4cb42cf36be 100644 --- a/tests/unittests/image/test_qnr.py +++ b/tests/unittests/image/test_qnr.py @@ -22,8 +22,8 @@ from torchmetrics.image.qnr import QualityWithNoReference from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.image.test_d_lambda import _baseline_d_lambda from unittests.image.test_d_s import _reference_d_s diff --git a/tests/unittests/image/test_rase.py b/tests/unittests/image/test_rase.py index 8015153fd19..f9227285e87 100644 --- a/tests/unittests/image/test_rase.py +++ b/tests/unittests/image/test_rase.py @@ -23,7 +23,7 @@ from torchmetrics.image import RelativeAverageSpectralError from unittests import BATCH_SIZE -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester class _InputWindowSized(NamedTuple): diff --git a/tests/unittests/image/test_rmse_sw.py b/tests/unittests/image/test_rmse_sw.py index a989de625a7..307d66ac9b1 100644 --- a/tests/unittests/image/test_rmse_sw.py +++ b/tests/unittests/image/test_rmse_sw.py @@ -22,7 +22,7 @@ from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester class _InputWindowSized(NamedTuple): diff --git a/tests/unittests/image/test_sam.py b/tests/unittests/image/test_sam.py index 9cf83196866..e71b0b230e3 100644 --- a/tests/unittests/image/test_sam.py +++ b/tests/unittests/image/test_sam.py @@ -22,8 +22,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_scc.py b/tests/unittests/image/test_scc.py index fa1854adc60..ecf6c355677 100644 --- a/tests/unittests/image/test_scc.py +++ b/tests/unittests/image/test_scc.py @@ -21,8 +21,8 @@ from torchmetrics.image import SpatialCorrelationCoefficient from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index a84f2d83468..6b464f2a97b 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -23,8 +23,8 @@ from torchmetrics.image import StructuralSimilarityIndexMeasure from unittests import NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 1842b693046..add144897ae 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -21,8 +21,8 @@ from torchmetrics.image.tv import TotalVariation from unittests import _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_uqi.py b/tests/unittests/image/test_uqi.py index 22ff2b7479b..0539eb4d0d1 100644 --- a/tests/unittests/image/test_uqi.py +++ b/tests/unittests/image/test_uqi.py @@ -22,8 +22,8 @@ from torchmetrics.image.uqi import UniversalImageQualityIndex from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/image/test_vif.py b/tests/unittests/image/test_vif.py index bde00c969b5..e1ea8eb1401 100644 --- a/tests/unittests/image/test_vif.py +++ b/tests/unittests/image/test_vif.py @@ -20,8 +20,8 @@ from torchmetrics.image.vif import VisualInformationFidelity from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/multimodal/test_clip_iqa.py b/tests/unittests/multimodal/test_clip_iqa.py index 314ff0013b8..2e88840e81b 100644 --- a/tests/unittests/multimodal/test_clip_iqa.py +++ b/tests/unittests/multimodal/test_clip_iqa.py @@ -26,8 +26,8 @@ from torchmetrics.utilities.imports import _PIQ_GREATER_EQUAL_0_8, _TRANSFORMERS_GREATER_EQUAL_4_10 from torchvision.transforms import PILToTensor -from unittests.helpers import skip_on_connection_issues -from unittests.helpers.testers import MetricTester +from unittests._helpers import skip_on_connection_issues +from unittests._helpers.testers import MetricTester from unittests.image import _SAMPLE_IMAGE diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index 110266e6525..e2804ecebb9 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -25,8 +25,8 @@ from transformers import CLIPModel as _CLIPModel from transformers import CLIPProcessor as _CLIPProcessor -from unittests.helpers import seed_all, skip_on_connection_issues -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all, skip_on_connection_issues +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/nominal/test_cramers.py b/tests/unittests/nominal/test_cramers.py index 42b735ef510..78df22d21b9 100644 --- a/tests/unittests/nominal/test_cramers.py +++ b/tests/unittests/nominal/test_cramers.py @@ -20,7 +20,7 @@ from torchmetrics.nominal.cramers import CramersV from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_CLASSES = 4 diff --git a/tests/unittests/nominal/test_fleiss_kappa.py b/tests/unittests/nominal/test_fleiss_kappa.py index 1538b116ab2..911f2814ad5 100644 --- a/tests/unittests/nominal/test_fleiss_kappa.py +++ b/tests/unittests/nominal/test_fleiss_kappa.py @@ -21,7 +21,7 @@ from torchmetrics.nominal.fleiss_kappa import FleissKappa from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_RATERS = 20 NUM_CATEGORIES = NUM_CLASSES diff --git a/tests/unittests/nominal/test_pearson.py b/tests/unittests/nominal/test_pearson.py index 5bec1cd8121..2f59461d276 100644 --- a/tests/unittests/nominal/test_pearson.py +++ b/tests/unittests/nominal/test_pearson.py @@ -23,7 +23,7 @@ from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_CLASSES = 4 diff --git a/tests/unittests/nominal/test_theils_u.py b/tests/unittests/nominal/test_theils_u.py index c06c6b9bcd2..a8bc2bc9952 100644 --- a/tests/unittests/nominal/test_theils_u.py +++ b/tests/unittests/nominal/test_theils_u.py @@ -20,7 +20,7 @@ from torchmetrics.nominal import TheilsU from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_CLASSES = 4 diff --git a/tests/unittests/nominal/test_tschuprows.py b/tests/unittests/nominal/test_tschuprows.py index 91798d88d82..a6a8c1d2b39 100644 --- a/tests/unittests/nominal/test_tschuprows.py +++ b/tests/unittests/nominal/test_tschuprows.py @@ -20,7 +20,7 @@ from torchmetrics.nominal.tschuprows import TschuprowsT from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester NUM_CLASSES = 4 diff --git a/tests/unittests/pairwise/test_pairwise_distance.py b/tests/unittests/pairwise/test_pairwise_distance.py index 3777ee63276..6538423c592 100644 --- a/tests/unittests/pairwise/test_pairwise_distance.py +++ b/tests/unittests/pairwise/test_pairwise_distance.py @@ -33,8 +33,8 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_concordance.py b/tests/unittests/regression/test_concordance.py index 69668772021..06493af0e0d 100644 --- a/tests/unittests/regression/test_concordance.py +++ b/tests/unittests/regression/test_concordance.py @@ -22,8 +22,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_cosine_similarity.py b/tests/unittests/regression/test_cosine_similarity.py index 184676526c7..de689cabb45 100644 --- a/tests/unittests/regression/test_cosine_similarity.py +++ b/tests/unittests/regression/test_cosine_similarity.py @@ -21,8 +21,8 @@ from torchmetrics.regression.cosine_similarity import CosineSimilarity from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_explained_variance.py b/tests/unittests/regression/test_explained_variance.py index 629e1ea7932..4838efa5df1 100644 --- a/tests/unittests/regression/test_explained_variance.py +++ b/tests/unittests/regression/test_explained_variance.py @@ -20,8 +20,8 @@ from torchmetrics.regression import ExplainedVariance from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index c7e5747a0ba..017179069e0 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -24,8 +24,8 @@ from torchmetrics.utilities.imports import _SCIPY_GREATER_EQUAL_1_8, _TORCH_LOWER_2_0 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_kl_divergence.py b/tests/unittests/regression/test_kl_divergence.py index 6af23cd2f44..53a3817fa82 100644 --- a/tests/unittests/regression/test_kl_divergence.py +++ b/tests/unittests/regression/test_kl_divergence.py @@ -24,8 +24,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_log_cosh_error.py b/tests/unittests/regression/test_log_cosh_error.py index 74b6214719b..9931ec91c22 100644 --- a/tests/unittests/regression/test_log_cosh_error.py +++ b/tests/unittests/regression/test_log_cosh_error.py @@ -20,8 +20,8 @@ from torchmetrics.regression.log_cosh import LogCoshError from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index c25882c3f37..631d9e6afe4 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -42,8 +42,8 @@ from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_minkowski_distance.py b/tests/unittests/regression/test_minkowski_distance.py index e00ccdaf7c4..9c0a27b6e54 100644 --- a/tests/unittests/regression/test_minkowski_distance.py +++ b/tests/unittests/regression/test_minkowski_distance.py @@ -8,8 +8,8 @@ from torchmetrics.utilities.exceptions import TorchMetricsUserError from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index d7ed5b27bfd..0d23507aeed 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -20,8 +20,8 @@ from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_r2.py b/tests/unittests/regression/test_r2.py index adcdc8a2dc8..32ce2554ed7 100644 --- a/tests/unittests/regression/test_r2.py +++ b/tests/unittests/regression/test_r2.py @@ -20,8 +20,8 @@ from torchmetrics.regression import R2Score from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_rse.py b/tests/unittests/regression/test_rse.py index 0f127ed796c..4ec677aa9ff 100644 --- a/tests/unittests/regression/test_rse.py +++ b/tests/unittests/regression/test_rse.py @@ -21,8 +21,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_spearman.py b/tests/unittests/regression/test_spearman.py index 4f3bd2ab790..b8d096e3d0e 100644 --- a/tests/unittests/regression/test_spearman.py +++ b/tests/unittests/regression/test_spearman.py @@ -21,8 +21,8 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/regression/test_tweedie_deviance.py b/tests/unittests/regression/test_tweedie_deviance.py index 00324dc041f..ec45b8ceb49 100644 --- a/tests/unittests/regression/test_tweedie_deviance.py +++ b/tests/unittests/regression/test_tweedie_deviance.py @@ -21,8 +21,8 @@ from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index 748c8f993b0..03d5429ba47 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -22,8 +22,8 @@ from torch import Tensor, tensor from typing_extensions import Literal -from unittests.helpers import seed_all -from unittests.helpers.testers import Metric, MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import Metric, MetricTester from unittests.retrieval._inputs import _input_retrieval_scores as _irs from unittests.retrieval._inputs import _input_retrieval_scores_all_target as _irs_all from unittests.retrieval._inputs import _input_retrieval_scores_empty as _irs_empty diff --git a/tests/unittests/retrieval/test_auroc.py b/tests/unittests/retrieval/test_auroc.py index 9a4385d1e06..26dbf62a9aa 100644 --- a/tests/unittests/retrieval/test_auroc.py +++ b/tests/unittests/retrieval/test_auroc.py @@ -21,7 +21,7 @@ from torchmetrics.retrieval.auroc import RetrievalAUROC from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_fallout.py b/tests/unittests/retrieval/test_fallout.py index 0ca5cb1db4b..9b0d8a3ebe0 100644 --- a/tests/unittests/retrieval/test_fallout.py +++ b/tests/unittests/retrieval/test_fallout.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.fall_out import RetrievalFallOut from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_hit_rate.py b/tests/unittests/retrieval/test_hit_rate.py index 1e3805aa32a..377c304ae5e 100644 --- a/tests/unittests/retrieval/test_hit_rate.py +++ b/tests/unittests/retrieval/test_hit_rate.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.hit_rate import RetrievalHitRate from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_map.py b/tests/unittests/retrieval/test_map.py index 8aec9a491cc..f3ac6b9989d 100644 --- a/tests/unittests/retrieval/test_map.py +++ b/tests/unittests/retrieval/test_map.py @@ -21,7 +21,7 @@ from torchmetrics.retrieval.average_precision import RetrievalMAP from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_mrr.py b/tests/unittests/retrieval/test_mrr.py index e5a1af2e73d..22cc946e8a8 100644 --- a/tests/unittests/retrieval/test_mrr.py +++ b/tests/unittests/retrieval/test_mrr.py @@ -21,7 +21,7 @@ from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_ndcg.py b/tests/unittests/retrieval/test_ndcg.py index 9eb29f79889..48e0c679195 100644 --- a/tests/unittests/retrieval/test_ndcg.py +++ b/tests/unittests/retrieval/test_ndcg.py @@ -22,7 +22,7 @@ from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_precision.py b/tests/unittests/retrieval/test_precision.py index 61f096aa8fb..a6e18756cab 100644 --- a/tests/unittests/retrieval/test_precision.py +++ b/tests/unittests/retrieval/test_precision.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.precision import RetrievalPrecision from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_precision_recall_curve.py b/tests/unittests/retrieval/test_precision_recall_curve.py index 7ed7f12c5ee..691334d135e 100644 --- a/tests/unittests/retrieval/test_precision_recall_curve.py +++ b/tests/unittests/retrieval/test_precision_recall_curve.py @@ -23,8 +23,8 @@ from torchmetrics.retrieval.base import _retrieval_aggregate from typing_extensions import Literal -from unittests.helpers import seed_all -from unittests.helpers.testers import Metric, MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import Metric, MetricTester from unittests.retrieval.helpers import _custom_aggregate_fn, _default_metric_class_input_arguments, get_group_indexes from unittests.retrieval.test_precision import _precision_at_k from unittests.retrieval.test_recall import _recall_at_k diff --git a/tests/unittests/retrieval/test_r_precision.py b/tests/unittests/retrieval/test_r_precision.py index 6343de74ea3..50a0384a58d 100644 --- a/tests/unittests/retrieval/test_r_precision.py +++ b/tests/unittests/retrieval/test_r_precision.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.r_precision import RetrievalRPrecision from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/retrieval/test_recall.py b/tests/unittests/retrieval/test_recall.py index c5de7e6600d..24ff4b6a756 100644 --- a/tests/unittests/retrieval/test_recall.py +++ b/tests/unittests/retrieval/test_recall.py @@ -20,7 +20,7 @@ from torchmetrics.retrieval.recall import RetrievalRecall from typing_extensions import Literal -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.retrieval.helpers import ( RetrievalMetricTester, _concat_tests, diff --git a/tests/unittests/text/helpers.py b/tests/unittests/text/_helpers.py similarity index 98% rename from tests/unittests/text/helpers.py rename to tests/unittests/text/_helpers.py index e86607518ea..678244343b9 100644 --- a/tests/unittests/text/helpers.py +++ b/tests/unittests/text/_helpers.py @@ -23,8 +23,9 @@ from torchmetrics import Metric from unittests import NUM_PROCESSES, _reference_cachier -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester, _assert_allclose, _assert_requires_grad, _assert_tensor +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, _assert_allclose, _assert_requires_grad, _assert_tensor, \ + _select_rand_best_device TEXT_METRIC_INPUT = Union[Sequence[str], Sequence[Sequence[str]], Sequence[Sequence[Sequence[str]]]] NUM_BATCHES = 2 @@ -288,7 +289,6 @@ def run_functional_metric_test( """ seed_all(42) - device = "cuda" if torch.cuda.device_count() > 0 else "cpu" _functional_test( preds=preds, @@ -297,7 +297,7 @@ def run_functional_metric_test( reference_metric=reference_metric, metric_args=metric_args, atol=self.atol, - device=device, + device=_select_rand_best_device(), fragment_kwargs=fragment_kwargs, key=key, **kwargs_update, @@ -352,7 +352,7 @@ def run_class_metric_test( "reference_metric": reference_metric, "metric_args": metric_args or {}, "atol": self.atol, - "device": "cuda" if torch.cuda.is_available() else "cpu", + "device": _select_rand_best_device(), "dist_sync_on_step": dist_sync_on_step, "check_dist_sync_on_step": check_dist_sync_on_step, "check_batch": check_batch, diff --git a/tests/unittests/text/_inputs.py b/tests/unittests/text/_inputs.py index 9d976864623..595b07cd87c 100644 --- a/tests/unittests/text/_inputs.py +++ b/tests/unittests/text/_inputs.py @@ -17,7 +17,7 @@ from torch import Tensor from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _Input -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(1) diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index b651ecfddb1..93c00a3de52 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -22,7 +22,7 @@ from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 from typing_extensions import Literal -from unittests.helpers import skip_on_connection_issues +from unittests._helpers import skip_on_connection_issues from unittests.text._inputs import _inputs_single_reference from unittests.text.helpers import TextTester diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index d8611695ff3..d4e556bdc72 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -19,7 +19,7 @@ from torchmetrics.text.infolm import InfoLM from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_4 -from unittests.helpers import skip_on_connection_issues +from unittests._helpers import skip_on_connection_issues from unittests.text._inputs import HYPOTHESIS_A, HYPOTHESIS_C, _inputs_single_reference from unittests.text.helpers import TextTester diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index e6f5222c3b1..f8bcf0aec9f 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -17,7 +17,7 @@ from torchmetrics.functional.text.mer import match_error_rate from torchmetrics.text.mer import MatchErrorRate -from unittests.helpers import seed_all +from unittests._helpers import seed_all from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from unittests.text.helpers import TextTester diff --git a/tests/unittests/text/test_perplexity.py b/tests/unittests/text/test_perplexity.py index 3673da47647..42930ec21ff 100644 --- a/tests/unittests/text/test_perplexity.py +++ b/tests/unittests/text/test_perplexity.py @@ -20,7 +20,7 @@ from torchmetrics.text.perplexity import Perplexity from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_2 -from unittests.helpers.testers import MetricTester +from unittests._helpers.testers import MetricTester from unittests.text._inputs import ( MASK_INDEX, _logits_inputs_fp32, diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index c9eec8a055a..26fdb631d40 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -24,7 +24,7 @@ from torchmetrics.utilities.imports import _NLTK_AVAILABLE, _ROUGE_SCORE_AVAILABLE from typing_extensions import Literal -from unittests.helpers import skip_on_connection_issues +from unittests._helpers import skip_on_connection_issues from unittests.text._inputs import _Input, _inputs_multiple_references, _inputs_single_sentence_single_reference from unittests.text.helpers import TextTester diff --git a/tests/unittests/text/test_squad.py b/tests/unittests/text/test_squad.py index 5bb4cc0c7fa..8a3d26af8ab 100644 --- a/tests/unittests/text/test_squad.py +++ b/tests/unittests/text/test_squad.py @@ -20,7 +20,7 @@ from torchmetrics.functional.text import squad from torchmetrics.text.squad import SQuAD -from unittests.helpers.testers import _assert_allclose, _assert_tensor +from unittests._helpers.testers import _assert_allclose, _assert_tensor from unittests.text._inputs import _inputs_squad_batch_match, _inputs_squad_exact_match, _inputs_squad_exact_mismatch diff --git a/tests/unittests/utilities/test_auc.py b/tests/unittests/utilities/test_auc.py index 37d9c1105ee..887f4c2d14b 100644 --- a/tests/unittests/utilities/test_auc.py +++ b/tests/unittests/utilities/test_auc.py @@ -20,8 +20,8 @@ from torch import Tensor, tensor from torchmetrics.utilities.compute import auc from unittests import NUM_BATCHES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/wrappers/test_bootstrapping.py b/tests/unittests/wrappers/test_bootstrapping.py index b02b5034c75..42d890cf728 100644 --- a/tests/unittests/wrappers/test_bootstrapping.py +++ b/tests/unittests/wrappers/test_bootstrapping.py @@ -25,7 +25,7 @@ from torchmetrics.regression import MeanSquaredError from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/wrappers/test_minmax.py b/tests/unittests/wrappers/test_minmax.py index fe406537baf..480fd1f8f7c 100644 --- a/tests/unittests/wrappers/test_minmax.py +++ b/tests/unittests/wrappers/test_minmax.py @@ -23,8 +23,8 @@ from torchmetrics.wrappers import MinMaxMetric from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/wrappers/test_multioutput.py b/tests/unittests/wrappers/test_multioutput.py index 7b807e7741e..3d95b20d69a 100644 --- a/tests/unittests/wrappers/test_multioutput.py +++ b/tests/unittests/wrappers/test_multioutput.py @@ -25,8 +25,8 @@ from torchmetrics.wrappers.multioutput import MultioutputWrapper from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input -from unittests.helpers import seed_all -from unittests.helpers.testers import MetricTester +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/wrappers/test_multitask.py b/tests/unittests/wrappers/test_multitask.py index 9d8f61323ef..63af6f31b35 100644 --- a/tests/unittests/wrappers/test_multitask.py +++ b/tests/unittests/wrappers/test_multitask.py @@ -22,7 +22,7 @@ from torchmetrics.wrappers import MultitaskWrapper from unittests import BATCH_SIZE, NUM_BATCHES -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42) diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py index a6b62db2453..93cdbc452ca 100644 --- a/tests/unittests/wrappers/test_tracker.py +++ b/tests/unittests/wrappers/test_tracker.py @@ -24,7 +24,7 @@ from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError from torchmetrics.wrappers import MetricTracker, MultioutputWrapper -from unittests.helpers import seed_all +from unittests._helpers import seed_all seed_all(42)