From 9f881774d8eaeebef5c523c69638e4f73ea25027 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Mar 2024 15:51:53 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/audio/test_c_si_snr.py | 2 +- tests/unittests/audio/test_pesq.py | 2 +- tests/unittests/audio/test_pit.py | 2 +- tests/unittests/audio/test_sdr.py | 2 +- tests/unittests/audio/test_si_sdr.py | 2 +- tests/unittests/audio/test_snr.py | 2 +- tests/unittests/audio/test_stoi.py | 2 +- tests/unittests/classification/test_accuracy.py | 2 +- tests/unittests/classification/test_auroc.py | 2 +- tests/unittests/classification/test_average_precision.py | 2 +- tests/unittests/classification/test_calibration_error.py | 2 +- tests/unittests/classification/test_cohen_kappa.py | 2 +- tests/unittests/classification/test_confusion_matrix.py | 2 +- tests/unittests/classification/test_dice.py | 4 ++-- tests/unittests/classification/test_exact_match.py | 2 +- tests/unittests/classification/test_f_beta.py | 2 +- tests/unittests/classification/test_group_fairness.py | 2 +- tests/unittests/classification/test_hamming_distance.py | 2 +- tests/unittests/classification/test_hinge.py | 2 +- tests/unittests/classification/test_jaccard.py | 2 +- tests/unittests/classification/test_matthews_corrcoef.py | 2 +- .../classification/test_precision_fixed_recall.py | 2 +- tests/unittests/classification/test_precision_recall.py | 2 +- .../classification/test_precision_recall_curve.py | 2 +- tests/unittests/classification/test_ranking.py | 2 +- .../classification/test_recall_fixed_precision.py | 2 +- tests/unittests/classification/test_roc.py | 2 +- .../classification/test_sensitivity_specificity.py | 2 +- tests/unittests/classification/test_specificity.py | 2 +- .../classification/test_specificity_sensitivity.py | 2 +- tests/unittests/classification/test_stat_scores.py | 2 +- .../clustering/test_adjusted_mutual_info_score.py | 2 +- tests/unittests/clustering/test_adjusted_rand_score.py | 2 +- .../unittests/clustering/test_calinski_harabasz_score.py | 2 +- tests/unittests/clustering/test_davies_bouldin_score.py | 2 +- tests/unittests/clustering/test_dunn_index.py | 4 ++-- tests/unittests/clustering/test_fowlkes_mallows_index.py | 2 +- .../test_homogeneity_completeness_v_measure.py | 2 +- tests/unittests/clustering/test_mutual_info_score.py | 2 +- .../clustering/test_normalized_mutual_info_score.py | 2 +- tests/unittests/clustering/test_rand_score.py | 2 +- tests/unittests/detection/test_map.py | 2 +- tests/unittests/text/_helpers.py | 9 +++++++-- 43 files changed, 51 insertions(+), 46 deletions(-) diff --git a/tests/unittests/audio/test_c_si_snr.py b/tests/unittests/audio/test_c_si_snr.py index eb160410ca4..2ed148aef65 100644 --- a/tests/unittests/audio/test_c_si_snr.py +++ b/tests/unittests/audio/test_c_si_snr.py @@ -19,9 +19,9 @@ from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB seed_all(42) diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index d5124b79fbe..dd0e3caba9c 100644 --- a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -22,9 +22,9 @@ from torchmetrics.functional.audio import perceptual_evaluation_speech_quality from unittests import _Input -from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper seed_all(42) diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index 920d6082065..85baab5e045 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -31,9 +31,9 @@ ) from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.audio import _average_metric_wrapper from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.audio import _average_metric_wrapper seed_all(42) diff --git a/tests/unittests/audio/test_sdr.py b/tests/unittests/audio/test_sdr.py index 537249da1eb..61257588606 100644 --- a/tests/unittests/audio/test_sdr.py +++ b/tests/unittests/audio/test_sdr.py @@ -24,9 +24,9 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from unittests import _Input -from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _SAMPLE_NUMPY_ISSUE_895 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _SAMPLE_NUMPY_ISSUE_895 seed_all(42) diff --git a/tests/unittests/audio/test_si_sdr.py b/tests/unittests/audio/test_si_sdr.py index ee9c6ff6d4b..d8b8f78c5cc 100644 --- a/tests/unittests/audio/test_si_sdr.py +++ b/tests/unittests/audio/test_si_sdr.py @@ -21,9 +21,9 @@ from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio from unittests import BATCH_SIZE, NUM_BATCHES, _Input -from unittests.audio import _average_metric_wrapper from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.audio import _average_metric_wrapper seed_all(42) diff --git a/tests/unittests/audio/test_snr.py b/tests/unittests/audio/test_snr.py index 917010d5dd4..332707028ff 100644 --- a/tests/unittests/audio/test_snr.py +++ b/tests/unittests/audio/test_snr.py @@ -21,9 +21,9 @@ from torchmetrics.functional.audio import signal_noise_ratio from unittests import _Input -from unittests.audio import _average_metric_wrapper from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.audio import _average_metric_wrapper seed_all(42) diff --git a/tests/unittests/audio/test_stoi.py b/tests/unittests/audio/test_stoi.py index af2c828d838..54374098779 100644 --- a/tests/unittests/audio/test_stoi.py +++ b/tests/unittests/audio/test_stoi.py @@ -22,9 +22,9 @@ from torchmetrics.functional.audio import short_time_objective_intelligibility from unittests import _Input -from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper seed_all(42) diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index cf751ae9a57..db497cdb197 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -29,9 +29,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index a2a60d3ad62..3691c7305b7 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -26,9 +26,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index eedcd2794df..2ff0274c93c 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -35,9 +35,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 0d176ca7fbc..a2000cc984e 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -32,9 +32,9 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_13 from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases seed_all(42) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index 724e74ebcd9..40ebbc028bd 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -23,9 +23,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases seed_all(42) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index bcf02d679d4..666e9f0fc05 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -32,9 +32,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_dice.py b/tests/unittests/classification/test_dice.py index 211c210bb86..6854265d3d9 100644 --- a/tests/unittests/classification/test_dice.py +++ b/tests/unittests/classification/test_dice.py @@ -23,6 +23,8 @@ from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.classification._inputs import _input_binary, _input_binary_logits, _input_binary_prob from unittests.classification._inputs import _input_multiclass as _input_mcls from unittests.classification._inputs import _input_multiclass_logits as _input_mcls_logits @@ -33,8 +35,6 @@ from unittests.classification._inputs import _input_multilabel_multidim as _input_mlmd from unittests.classification._inputs import _input_multilabel_multidim_prob as _input_mlmd_prob from unittests.classification._inputs import _input_multilabel_prob as _input_mlb_prob -from unittests._helpers import seed_all -from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/classification/test_exact_match.py b/tests/unittests/classification/test_exact_match.py index 44e2dbe22e1..5afd4c00e40 100644 --- a/tests/unittests/classification/test_exact_match.py +++ b/tests/unittests/classification/test_exact_match.py @@ -22,9 +22,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index +from unittests.classification._inputs import _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 0ab54ac341f..3a334708485 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -42,9 +42,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_group_fairness.py b/tests/unittests/classification/test_group_fairness.py index 686839166e5..2f37de97d4f 100644 --- a/tests/unittests/classification/test_group_fairness.py +++ b/tests/unittests/classification/test_group_fairness.py @@ -28,7 +28,6 @@ from torchmetrics.functional.classification.group_fairness import binary_fairness from unittests import THRESHOLD -from unittests.classification._inputs import _group_cases from unittests._helpers import seed_all from unittests._helpers.testers import ( MetricTester, @@ -39,6 +38,7 @@ from unittests._helpers.testers import _assert_allclose as _core_assert_allclose from unittests._helpers.testers import _assert_requires_grad as _core_assert_requires_grad from unittests._helpers.testers import _assert_tensor as _core_assert_tensor +from unittests.classification._inputs import _group_cases seed_all(42) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 6b20a45c847..f7f3686c73b 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -33,9 +33,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index 96841f3231c..8f285794d15 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -25,8 +25,8 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases torch.manual_seed(42) diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index a5ed8945456..8fa17ca1d32 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -33,8 +33,8 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None): diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 425dcabd23b..03f649bc0ac 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -32,9 +32,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index ddd96500da1..f320d2cf1e9 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py @@ -34,9 +34,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 98d615618e7..86fbe262aea 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -42,9 +42,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index a990ada5144..6f78438007e 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -34,9 +34,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index 9bbcb4f3928..e85d38cde05 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -32,9 +32,9 @@ ) from unittests import NUM_CLASSES -from unittests.classification._inputs import _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index +from unittests.classification._inputs import _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index 7f28f10594d..9bdca8950bd 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -34,9 +34,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index 4b7c55379e8..167ad4876f0 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -25,9 +25,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index 13dd2ea2763..d629c86583a 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -36,9 +36,9 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 7855f95f3cd..fe5fd8977a8 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -33,9 +33,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index f5ade54d041..0bafdfe55ea 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -35,9 +35,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index e820c31a0bf..86c793f8c83 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -32,9 +32,9 @@ from torchmetrics.metric import Metric from unittests import NUM_CLASSES, THRESHOLD -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) diff --git a/tests/unittests/clustering/test_adjusted_mutual_info_score.py b/tests/unittests/clustering/test_adjusted_mutual_info_score.py index 40a617ad53e..474e221d6a5 100644 --- a/tests/unittests/clustering/test_adjusted_mutual_info_score.py +++ b/tests/unittests/clustering/test_adjusted_mutual_info_score.py @@ -20,9 +20,9 @@ from torchmetrics.functional.clustering.adjusted_mutual_info_score import adjusted_mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES -from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 seed_all(42) diff --git a/tests/unittests/clustering/test_adjusted_rand_score.py b/tests/unittests/clustering/test_adjusted_rand_score.py index 5232314db2e..b98536aad15 100644 --- a/tests/unittests/clustering/test_adjusted_rand_score.py +++ b/tests/unittests/clustering/test_adjusted_rand_score.py @@ -17,8 +17,8 @@ from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score -from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 @pytest.mark.parametrize( diff --git a/tests/unittests/clustering/test_calinski_harabasz_score.py b/tests/unittests/clustering/test_calinski_harabasz_score.py index 9bd6639d178..f81da592389 100644 --- a/tests/unittests/clustering/test_calinski_harabasz_score.py +++ b/tests/unittests/clustering/test_calinski_harabasz_score.py @@ -16,9 +16,9 @@ from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score -from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 seed_all(42) diff --git a/tests/unittests/clustering/test_davies_bouldin_score.py b/tests/unittests/clustering/test_davies_bouldin_score.py index 2614789c1d4..bea2018c2cc 100644 --- a/tests/unittests/clustering/test_davies_bouldin_score.py +++ b/tests/unittests/clustering/test_davies_bouldin_score.py @@ -16,9 +16,9 @@ from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score -from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _single_target_intrinsic1, _single_target_intrinsic2 seed_all(42) diff --git a/tests/unittests/clustering/test_dunn_index.py b/tests/unittests/clustering/test_dunn_index.py index 84a80b0f913..c2e6adcd2cb 100644 --- a/tests/unittests/clustering/test_dunn_index.py +++ b/tests/unittests/clustering/test_dunn_index.py @@ -19,12 +19,12 @@ from torchmetrics.clustering.dunn_index import DunnIndex from torchmetrics.functional.clustering.dunn_index import dunn_index +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester from unittests.clustering._inputs import ( _single_target_intrinsic1, _single_target_intrinsic2, ) -from unittests._helpers import seed_all -from unittests._helpers.testers import MetricTester seed_all(42) diff --git a/tests/unittests/clustering/test_fowlkes_mallows_index.py b/tests/unittests/clustering/test_fowlkes_mallows_index.py index f4949c74448..6e5674ae337 100644 --- a/tests/unittests/clustering/test_fowlkes_mallows_index.py +++ b/tests/unittests/clustering/test_fowlkes_mallows_index.py @@ -16,9 +16,9 @@ from torchmetrics.clustering import FowlkesMallowsIndex from torchmetrics.functional.clustering import fowlkes_mallows_index -from unittests.clustering._inputs import _single_target_extrinsic1, _single_target_extrinsic2 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _single_target_extrinsic1, _single_target_extrinsic2 seed_all(42) diff --git a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py index 3c3e1c4dace..dd716182b4b 100644 --- a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py +++ b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py @@ -28,9 +28,9 @@ v_measure_score, ) -from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 seed_all(42) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index 09b82ee37f3..ab9222b8082 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -18,9 +18,9 @@ from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES -from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 seed_all(42) diff --git a/tests/unittests/clustering/test_normalized_mutual_info_score.py b/tests/unittests/clustering/test_normalized_mutual_info_score.py index 47696f4b13c..07109771b0f 100644 --- a/tests/unittests/clustering/test_normalized_mutual_info_score.py +++ b/tests/unittests/clustering/test_normalized_mutual_info_score.py @@ -20,9 +20,9 @@ from torchmetrics.functional.clustering import normalized_mutual_info_score from unittests import BATCH_SIZE, NUM_CLASSES -from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 seed_all(42) diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index 3c1189e5702..824f5e90fbd 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -17,9 +17,9 @@ from torchmetrics.clustering.rand_score import RandScore from torchmetrics.functional.clustering.rand_score import rand_score -from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from unittests.clustering._inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 seed_all(42) diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index 9cb3fe22dc1..f0dcdff52f2 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -32,8 +32,8 @@ _TORCHVISION_GREATER_EQUAL_0_8, ) -from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL from unittests._helpers.testers import MetricTester +from unittests.detection import _DETECTION_BBOX, _DETECTION_SEGM, _DETECTION_VAL _pytest_condition = not (_PYCOCOTOOLS_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) diff --git a/tests/unittests/text/_helpers.py b/tests/unittests/text/_helpers.py index 678244343b9..1f237d3e96f 100644 --- a/tests/unittests/text/_helpers.py +++ b/tests/unittests/text/_helpers.py @@ -24,8 +24,13 @@ from unittests import NUM_PROCESSES, _reference_cachier from unittests._helpers import seed_all -from unittests._helpers.testers import MetricTester, _assert_allclose, _assert_requires_grad, _assert_tensor, \ - _select_rand_best_device +from unittests._helpers.testers import ( + MetricTester, + _assert_allclose, + _assert_requires_grad, + _assert_tensor, + _select_rand_best_device, +) TEXT_METRIC_INPUT = Union[Sequence[str], Sequence[Sequence[str]], Sequence[Sequence[Sequence[str]]]] NUM_BATCHES = 2