Skip to content

Commit

Permalink
Fix unit test nit
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyu-work committed Jan 7, 2025
1 parent 5eb8f0b commit 215eb63
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions test/unit_test/model/test_hf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,22 @@ def test_device_map(self, inputs, inner):
assert args.device_map == {"": inner}

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available")
class TestHfLoadKwargsGPU:
@pytest.mark.parametrize(
("inputs", "inner"),
[
("cuda:0", "cuda:0"),
("0", "cuda:0"),
],
)
def test_device_map_cpu(self, inputs, inner):
if inputs == "0":
inputs = torch.device(0)
@pytest.mark.parametrize(
("inputs", "inner"),
[
("cuda:0", "cuda:0"),
("0", "cuda:0"),
],
)
def test_device_map_cpu(self, inputs, inner):
if inputs == "0":
inputs = torch.device(0)

args = HfLoadKwargs(device_map=inputs)
assert args.device_map == inner
args = HfLoadKwargs(device_map=inputs)
assert args.device_map == inner

args = HfLoadKwargs(device_map={"": inputs})
assert args.device_map == {"": inner}
args = HfLoadKwargs(device_map={"": inputs})
assert args.device_map == {"": inner}

@pytest.mark.parametrize(
("quantization_method", "quantization_config", "valid"),
Expand Down

0 comments on commit 215eb63

Please sign in to comment.