-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathrunif.py
50 lines (42 loc) · 1.36 KB
/
runif.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pytest
from lightning.pytorch.accelerators import find_usable_cuda_devices
"""
Simplified from:
https://github.com/ashleve/lightning-hydra-template/blob/main/tests/helpers/runif.py
which adapted it from
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py
"""
class RunIf:
    """RunIf wrapper for conditional skipping of tests.

    Fully compatible with `@pytest.mark`.

    Calling ``RunIf(...)`` does not produce a ``RunIf`` instance: ``__new__``
    evaluates every requested requirement immediately and returns a
    ``pytest.mark.skipif`` marker, which can be stacked with other marks.

    Example:

        @RunIf(min_gpus=1)
        @pytest.mark.parametrize("arg1", [1.0, 2.0])
        def test_wrapper(arg1):
            assert arg1 > 0
    """

    def __new__(
        cls,
        min_gpus: int = 0,
        **kwargs,
    ):
        """Build a ``skipif`` marker from the requested requirements.

        Args:
            min_gpus: min number of gpus required to run test. A falsy value
                (the default ``0``) disables the GPU check entirely.
            kwargs: native pytest.mark.skipif keyword arguments. They must not
                include ``condition`` or ``reason``; those are supplied here and
                passing them again raises ``TypeError`` (duplicate keyword).

        Returns:
            A ``pytest.mark.skipif`` marker that skips the test when any
            requirement is unmet. Its reason lists only the unmet requirements.
        """
        # Parallel lists: conditions[i] is True when the requirement described
        # by reasons[i] is NOT satisfied (i.e. the test should be skipped).
        conditions = []
        reasons = []

        if min_gpus:
            # The probe runs when the decorator expression is evaluated (when
            # the test module is imported), not when the test itself executes.
            # Either exception from find_usable_cuda_devices is treated as
            # "fewer than min_gpus usable GPUs".
            try:
                find_usable_cuda_devices(min_gpus)
            except (ValueError, RuntimeError):
                conditions.append(True)
            else:
                conditions.append(False)
            reasons.append(f"GPUs>={min_gpus}")

        # Report only the requirements that are actually unmet.
        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
        return pytest.mark.skipif(
            condition=any(conditions),
            reason=f"Requires: [{' + '.join(reasons)}]",
            **kwargs,
        )