
Commit aa330f3: mark tests
ankitgola005 committed Nov 4, 2024
1 parent 8cd21e5 commit aa330f3
Showing 5 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .azure/hpu-tests.yml

@@ -147,8 +147,8 @@ jobs:
       -k test_autocast_operators_override --runxfail \
       -W ignore::FutureWarning --junitxml=hpu_precision_test_override-results.xml
     env:
-      LOWER_LIST : tests/test_pytorch/ops_fp32.txt
-      FP32_LIST : tests/test_pytorch/ops_bf16.txt
+      LOWER_LIST: tests/test_pytorch/ops_fp32.txt
+      FP32_LIST: tests/test_pytorch/ops_bf16.txt
     displayName: 'HPU precision test'
   - bash: |
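For context on the whitespace fix: a minimal sketch (using PyYAML, which is not part of this pipeline) suggesting that the spaced and unspaced forms parse to the same key, so the change is a style cleanup of the kind YAML linters flag rather than a behavior change.

    import yaml  # PyYAML, assumed installed for this illustration

    spaced = yaml.safe_load("LOWER_LIST : tests/test_pytorch/ops_fp32.txt")
    clean = yaml.safe_load("LOWER_LIST: tests/test_pytorch/ops_fp32.txt")

    # Plain-scalar keys have trailing whitespace stripped, so both parse alike.
    assert spaced == clean == {"LOWER_LIST": "tests/test_pytorch/ops_fp32.txt"}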
2 changes: 1 addition & 1 deletion src/lightning_habana/utils/imports.py

@@ -21,7 +21,7 @@
 from lightning_habana.utils.resources import _HABANA_FRAMEWORK_AVAILABLE, get_hpu_synapse_version  # noqa: F401

 _HPU_SYNAPSE_GREATER_EQUAL_1_17_0 = Version(get_hpu_synapse_version()) >= Version("1.17.0")
-_HPU_SYNAPSE_GREATER_EQUAL_1_18_0 = Version(get_hpu_synapse_version()) >= Version("1.18.0")
+_HPU_SYNAPSE_GREATER_1_18_0 = Version(get_hpu_synapse_version()) > Version("1.18.0")
 _TORCH_LESSER_EQUAL_1_13_1 = compare_version("torch", operator.le, "1.13.1")
 _TORCH_GREATER_EQUAL_2_0_0 = compare_version("torch", operator.ge, "2.0.0")
 _TORCH_LESSER_2_3_0 = Version(Version(torch.__version__).base_version) < Version("2.3.0")
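For reference, a minimal sketch of the semantic shift here, using packaging.version directly (the sample version strings are illustrative; in the library the value comes from get_hpu_synapse_version()):

    from packaging.version import Version

    for reported in ("1.17.1", "1.18.0", "1.19.0"):
        v = Version(reported)
        # Old flag: inclusive, True for 1.18.0 itself.
        greater_equal = v >= Version("1.18.0")
        # New flag: strict, False for 1.18.0 itself.
        strictly_greater = v > Version("1.18.0")
        print(reported, greater_equal, strictly_greater)

The renamed flag therefore keeps the marked tests enabled on Synapse 1.18.0 itself and only skips them on later releases.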
2 changes: 2 additions & 0 deletions tests/test_pytorch/test_compile.py

@@ -35,6 +35,7 @@
 from lightning_habana.pytorch.accelerator import HPUAccelerator
 from lightning_habana.pytorch.plugins import HPUPrecisionPlugin
 from lightning_habana.pytorch.strategies import HPUDDPStrategy, SingleHPUStrategy
+from lightning_habana.utils.imports import _HPU_SYNAPSE_GREATER_1_18_0
 from lightning_habana.utils.resources import get_device_name_from_hlsmi

@@ -221,6 +222,7 @@ def test_ddp_strategy_with_compile(tmp_path, arg_hpus):
     assert _strategy._ddp_kwargs["find_unused_parameters"] is True


+@pytest.mark.skipif(_HPU_SYNAPSE_GREATER_1_18_0, reason="Test valid for Synapse version <= 1.18.0")
 @pytest.mark.usefixtures("_is_compile_allowed")
 @pytest.mark.parametrize(
     ("record_module_names", "expectation"),
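A minimal sketch of how this marker gates a test (the flag is hard-coded here as a stand-in for the imported constant): the condition is evaluated once at import time, and pytest skips the test at collection when it is true.

    import pytest

    _HPU_SYNAPSE_GREATER_1_18_0 = False  # stand-in; really imported from lightning_habana.utils.imports

    @pytest.mark.skipif(_HPU_SYNAPSE_GREATER_1_18_0, reason="Test valid for Synapse version <= 1.18.0")
    def test_gated_compile_feature():
        # Collected and run only on Synapse 1.18.0 or older.
        assert True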
4 changes: 2 additions & 2 deletions tests/test_pytorch/test_precision.py

@@ -39,7 +39,7 @@
 from lightning_habana.pytorch.plugins import HPUPrecisionPlugin
 from lightning_habana.pytorch.plugins.precision import _PRECISION_INPUT
 from lightning_habana.pytorch.strategies import HPUDDPStrategy, SingleHPUStrategy
-from lightning_habana.utils.imports import _HPU_SYNAPSE_GREATER_EQUAL_1_18_0
+from lightning_habana.utils.imports import _HPU_SYNAPSE_GREATER_1_18_0
 from lightning_habana.utils.resources import get_device_name_from_hlsmi

 supported_precision = get_args(_PRECISION_INPUT)

@@ -165,7 +165,7 @@ def forward(self, x):


 @pytest.mark.xfail(strict=False, reason="Env needs to be set")
-@pytest.mark.skipif(_HPU_SYNAPSE_GREATER_EQUAL_1_18_0, reason="Will be fixed in a future synapse version.")
+@pytest.mark.skipif(_HPU_SYNAPSE_GREATER_1_18_0, reason="Test valid for Synapse version <= 1.18.0")
 def test_autocast_operators_override(tmpdir):
     """Tests operator dtype overriding with torch autocast."""
     # The override lists are set in cmdline
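Note how the two markers compose on test_autocast_operators_override: skip takes precedence over xfail in pytest, so on Synapse releases newer than 1.18.0 the test never runs; otherwise xfail(strict=False) reports a failure as an expected failure when the override environment is not set. A sketch with a stand-in flag:

    import pytest

    _HPU_SYNAPSE_GREATER_1_18_0 = False  # stand-in for the imported flag

    @pytest.mark.xfail(strict=False, reason="Env needs to be set")
    @pytest.mark.skipif(_HPU_SYNAPSE_GREATER_1_18_0, reason="Test valid for Synapse version <= 1.18.0")
    def test_marker_stacking():
        # Skipped outright on Synapse > 1.18.0; otherwise a failure is
        # reported as xfail rather than an error.
        assert True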
4 changes: 4 additions & 0 deletions tests/test_pytorch/test_profiler.py

@@ -47,6 +47,10 @@
 if _KINETO_AVAILABLE:
     from lightning_habana.pytorch.profiler.profiler import HPUProfiler

+from lightning_habana.utils.imports import _HPU_SYNAPSE_GREATER_1_18_0
+
+pytestmark = pytest.mark.skipif(_HPU_SYNAPSE_GREATER_1_18_0, reason="Tests valid for Synapse version <= 1.18.0")
+

 @pytest.fixture()
 def _check_distributed(device_count):
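Unlike the per-test markers in the other files, pytestmark at module scope applies the skip to every test collected from the file. A minimal sketch with a stand-in flag:

    import pytest

    _HPU_SYNAPSE_GREATER_1_18_0 = False  # stand-in for the imported flag

    # Module-level marker: pytest applies it to every test in this module.
    pytestmark = pytest.mark.skipif(_HPU_SYNAPSE_GREATER_1_18_0, reason="Tests valid for Synapse version <= 1.18.0")

    def test_profiler_anything():
        assert True  # skipped wholesale on Synapse releases newer than 1.18.0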
