From 48981083a6c694cd941648b435566bd163d6d462 Mon Sep 17 00:00:00 2001 From: Iman Gohari Date: Fri, 10 Jan 2025 18:53:10 +0000 Subject: [PATCH 1/2] fea(): reworked the 8x hpu skipping strategy --- optimum/habana/utils.py | 11 +++++++++++ tests/test_diffusers.py | 17 ++++++++++++++--- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py index 2225cb8c89..c23c8c840c 100755 --- a/optimum/habana/utils.py +++ b/optimum/habana/utils.py @@ -403,3 +403,14 @@ def get_device_name(): return "gaudi3" else: raise ValueError(f"Unsupported device: the device type is {device_type}.") + +def get_device_count(): + """ + Returns the number of the current gaudi devices + """ + import habana_frameworks.torch.utils.experimental as htexp + + if htexp.hpu.is_available(): + return htexp.hpu.device_count() + else: + raise ValueError(f"No hpu is found avail on this system") diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index f514318988..0bd15bd109 100644 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -31,7 +31,6 @@ from unittest import TestCase, skipUnless import diffusers -import habana_frameworks.torch.hpu as hthpu import numpy as np import PIL import pytest @@ -189,6 +188,18 @@ def check_gated_model_access(model): return pytest.mark.skipif(gated, reason=f"{model} is gated, please log in with approved HF access token") +def check_8xhpu(test_case): + """ + Decorator marking a test as it requires 8xHPU on system + """ + from optimum.habana.utils import get_device_count + + if (get_device_count() != 8): + skip = True + else: + skip = False + + return pytest.mark.skipif(skip, reason=f"test requires 8xHPU multi-card system")(test_case) class GaudiPipelineUtilsTester(TestCase): """ @@ -781,7 +792,7 @@ def test_no_generation_regression_upscale(self): self.assertLess(np.abs(expected_slice - upscaled_image[-3:, -3:, -1].flatten()).max(), 5e-3) @slow - @pytest.mark.skipif(hthpu.is_available() and hthpu.device_count() != 8, reason="system does not have 8 cards") + @check_8xhpu def test_sd_textual_inversion(self): path_to_script = ( Path(os.path.dirname(__file__)).parent @@ -2450,7 +2461,7 @@ def test_script_train_controlnet(self): self.assertEqual(return_code, 0) @slow - @pytest.mark.skipif(hthpu.is_available() and hthpu.device_count() != 8, reason="system does not have 8 cards") + @check_8xhpu def test_train_controlnet(self): with tempfile.TemporaryDirectory() as tmpdir: path_to_script = ( From bb5b12341901ca262067515bab6b09242e82a78d Mon Sep 17 00:00:00 2001 From: Iman Gohari Date: Thu, 30 Jan 2025 17:01:56 +0000 Subject: [PATCH 2/2] make style --- optimum/habana/utils.py | 3 ++- tests/test_diffusers.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py index c23c8c840c..244b52e203 100755 --- a/optimum/habana/utils.py +++ b/optimum/habana/utils.py @@ -404,6 +404,7 @@ def get_device_name(): else: raise ValueError(f"Unsupported device: the device type is {device_type}.") + def get_device_count(): """ Returns the number of the current gaudi devices @@ -413,4 +414,4 @@ def get_device_count(): if htexp.hpu.is_available(): return htexp.hpu.device_count() else: - raise ValueError(f"No hpu is found avail on this system") + raise ValueError("No hpu is found avail on this system") diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index 0bd15bd109..4bf6505c46 100644 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -188,18 +188,20 @@ def check_gated_model_access(model): return pytest.mark.skipif(gated, reason=f"{model} is gated, please log in with approved HF access token") + def check_8xhpu(test_case): """ Decorator marking a test as it requires 8xHPU on system """ from optimum.habana.utils import get_device_count - if (get_device_count() != 8): + if get_device_count() != 8: skip = True else: skip = False - return pytest.mark.skipif(skip, reason=f"test requires 8xHPU multi-card system")(test_case) + return pytest.mark.skipif(skip, reason="test requires 8xHPU multi-card system")(test_case) + class GaudiPipelineUtilsTester(TestCase): """