Skip to content

Commit

Permalink
_select_rand_best_device()
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Mar 13, 2024
1 parent 44f9395 commit bc76193
Show file tree
Hide file tree
Showing 123 changed files with 223 additions and 214 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy
import torch

from unittests.helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory
from unittests._helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory


def seed_all(seed):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
from copy import deepcopy
from functools import partial
from random import randrange
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -337,6 +338,16 @@ def _assert_dtype_support(
_assert_tensor(metric_functional(y_hat, y, **kwargs_update))


def _select_rand_best_device() -> str:
"""Select the best device to run tests on."""
nb_gpus = torch.cuda.device_count()
if nb_gpus > 1:
return f"cuda:{randrange(nb_gpus)}"
if nb_gpus:
return "cuda"
return "cpu"


class MetricTester:
"""Test class for all metrics.
Expand Down Expand Up @@ -371,16 +382,14 @@ def run_functional_metric_test(
target when running update on the metric.
"""
device = "cuda" if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else "cpu"

_functional_test(
preds=preds,
target=target,
metric_functional=metric_functional,
reference_metric=reference_metric,
metric_args=metric_args,
atol=self.atol,
device=device,
device=_select_rand_best_device(),
fragment_kwargs=fragment_kwargs,
**kwargs_update,
)
Expand Down Expand Up @@ -431,7 +440,7 @@ def run_class_metric_test(
"reference_metric": reference_metric,
"metric_args": metric_args or {},
"atol": atol or self.atol,
"device": "cuda" if torch.cuda.is_available() else "cpu",
"device": _select_rand_best_device(),
"dist_sync_on_step": dist_sync_on_step,
"check_dist_sync_on_step": check_dist_sync_on_step,
"check_batch": check_batch,
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_c_si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from unittests import _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.audio import _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_sa_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
)

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

from unittests import _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _SAMPLE_NUMPY_ISSUE_895
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_si_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.audio import _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio

from unittests import BATCH_SIZE, NUM_BATCHES, _Input
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

from unittests import _Input
from unittests.audio import _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio
from torchmetrics.utilities.imports import _TORCHAUDIO_GREATER_EQUAL_0_10

from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/audio/test_stoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from unittests import _Input
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/bases/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchmetrics.collections import MetricCollection

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.helpers.testers import MetricTester
from unittests._helpers.testers import MetricTester


def compare_mean(values, weights):
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
)
from torchmetrics.utilities.checks import _allclose_recursive

from unittests.helpers import seed_all
from unittests.helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum
from unittests._helpers import seed_all
from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from torchmetrics.utilities.exceptions import TorchMetricsUserError

from unittests import NUM_PROCESSES
from unittests.helpers import seed_all
from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from unittests._helpers import seed_all
from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum

seed_all(42)

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/bases/test_hashing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from unittests.helpers.testers import DummyListMetric, DummyMetric
from unittests._helpers.testers import DummyListMetric, DummyMetric


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.regression import PearsonCorrCoef

from unittests.helpers import seed_all
from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
from unittests._helpers import seed_all
from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum

seed_all(42)

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/classification/_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch import Tensor

from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _GroupInput, _Input
from unittests.helpers import seed_all
from unittests._helpers import seed_all

seed_all(1)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification._inputs import _binary_cases, _input_binary, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

from unittests import NUM_CLASSES
from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@

from unittests import NUM_CLASSES
from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_calibration_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@

from unittests import NUM_CLASSES
from unittests.classification._inputs import _binary_cases, _multiclass_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification._inputs import _binary_cases, _multiclass_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@

from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from unittests.classification._inputs import _input_multilabel_multidim as _input_mlmd
from unittests.classification._inputs import _input_multilabel_multidim_prob as _input_mlmd_prob
from unittests.classification._inputs import _input_multilabel_prob as _input_mlb_prob
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_exact_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification._inputs import _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@

from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
10 changes: 5 additions & 5 deletions tests/unittests/classification/test_group_fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@

from unittests import THRESHOLD
from unittests.classification._inputs import _group_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import (
from unittests._helpers import seed_all
from unittests._helpers.testers import (
MetricTester,
_assert_dtype_support,
inject_ignore_index,
remove_ignore_index_groups,
)
from unittests.helpers.testers import _assert_allclose as _core_assert_allclose
from unittests.helpers.testers import _assert_requires_grad as _core_assert_requires_grad
from unittests.helpers.testers import _assert_tensor as _core_assert_tensor
from unittests._helpers.testers import _assert_allclose as _core_assert_allclose
from unittests._helpers.testers import _assert_requires_grad as _core_assert_requires_grad
from unittests._helpers.testers import _assert_tensor as _core_assert_tensor

seed_all(42)

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@

from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/classification/test_hinge.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from unittests import NUM_CLASSES
from unittests.classification._inputs import _binary_cases, _multiclass_cases
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

torch.manual_seed(42)

Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/classification/test_jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index


def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None):
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/classification/test_matthews_corrcoef.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@

from unittests import NUM_CLASSES, THRESHOLD
from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
from unittests._helpers import seed_all
from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index

seed_all(42)

Expand Down
Loading

0 comments on commit bc76193

Please sign in to comment.