Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fea(): reworked the 8x hpu skipping strategy #1694

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions optimum/habana/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,15 @@ def get_device_name():
return "gaudi3"
else:
raise ValueError(f"Unsupported device: the device type is {device_type}.")


def get_device_count():
"""
Returns the number of the current gaudi devices
"""
import habana_frameworks.torch.utils.experimental as htexp

if htexp.hpu.is_available():
return htexp.hpu.device_count()
else:
raise ValueError("No hpu is found avail on this system")
19 changes: 16 additions & 3 deletions tests/test_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from unittest import TestCase, skipUnless

import diffusers
import habana_frameworks.torch.hpu as hthpu
import numpy as np
import PIL
import pytest
Expand Down Expand Up @@ -190,6 +189,20 @@ def check_gated_model_access(model):
return pytest.mark.skipif(gated, reason=f"{model} is gated, please log in with approved HF access token")


def check_8xhpu(test_case):
"""
Decorator marking a test as it requires 8xHPU on system
"""
from optimum.habana.utils import get_device_count

if get_device_count() != 8:
skip = True
else:
skip = False

return pytest.mark.skipif(skip, reason="test requires 8xHPU multi-card system")(test_case)


class GaudiPipelineUtilsTester(TestCase):
"""
Tests the features added on top of diffusers/pipeline_utils.py.
Expand Down Expand Up @@ -781,7 +794,7 @@ def test_no_generation_regression_upscale(self):
self.assertLess(np.abs(expected_slice - upscaled_image[-3:, -3:, -1].flatten()).max(), 5e-3)

@slow
@pytest.mark.skipif(hthpu.is_available() and hthpu.device_count() != 8, reason="system does not have 8 cards")
@check_8xhpu
def test_sd_textual_inversion(self):
path_to_script = (
Path(os.path.dirname(__file__)).parent
Expand Down Expand Up @@ -2450,7 +2463,7 @@ def test_script_train_controlnet(self):
self.assertEqual(return_code, 0)

@slow
@pytest.mark.skipif(hthpu.is_available() and hthpu.device_count() != 8, reason="system does not have 8 cards")
@check_8xhpu
def test_train_controlnet(self):
with tempfile.TemporaryDirectory() as tmpdir:
path_to_script = (
Expand Down