skip failing MX tests on cuda capability 10.0 #1624
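This PR gates the MX/fp4 tests that depend on Triton so they are skipped on devices with CUDA capability 10.0, where Triton does not work yet, using a new is_sm_at_least_100() helper added to torchao/utils.py. A minimal sketch of the skip pattern the diffs below apply (the test name here is hypothetical, shown only to illustrate the decorator):

import pytest

from torchao.utils import is_sm_at_least_100  # helper added by this PR

# Hypothetical test body; the real tests gated this way are in the diffs below.
@pytest.mark.skipif(
    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_uses_triton_kernels():
    ...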

Merged · 1 commit · Jan 30, 2025
test/prototype/mx_formats/test_custom_cast.py (7 additions, 1 deletion)
@@ -40,7 +40,7 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

 torch.manual_seed(0)

@@ -310,6 +310,9 @@ def test_fp4_pack_unpack():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 def test_fp4_triton_unscaled_cast():
     packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
     f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -320,6 +323,9 @@ def test_fp4_triton_unscaled_cast():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
test/prototype/mx_formats/test_mx_linear.py (11 additions, 1 deletion)
@@ -18,7 +18,11 @@
     swap_linear_with_mx_linear,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)

 torch.manual_seed(2)

@@ -99,6 +103,9 @@ def test_activation_checkpointing():


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
@@ -184,6 +191,9 @@ def test_inference_linear(elem_dtype, bias, input_shape):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_inference_compile_simple(elem_dtype):
     """
test/prototype/mx_formats/test_mx_tensor.py (10 additions, 1 deletion)
@@ -21,7 +21,11 @@
     to_dtype,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)

 torch.manual_seed(2)

@@ -166,6 +170,8 @@ def test_transpose(elem_dtype, fp4_triton):
     """
     if elem_dtype != DTYPE_FP4 and fp4_triton:
         pytest.skip("unsupported configuration")
+    elif fp4_triton and is_sm_at_least_100():
+        pytest.skip("triton does not work yet on CUDA capability 10.0")

     M, K = 128, 256
     block_size = 32
@@ -205,6 +211,9 @@ def test_view(elem_dtype):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("all_zeros", [False, True])
torchao/utils.py (9 additions, 0 deletions)
@@ -630,6 +630,15 @@ def is_sm_at_least_90():
     )


+# TODO(future PR): rename to 8_9, 9_0, 10_0 instead of 89, 90, 100
+def is_sm_at_least_100():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (10, 0)
+    )
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
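For context, a minimal sketch (illustrative only, not part of the diff) of how the new helper evaluates: torch.cuda.get_device_capability() returns a (major, minor) tuple and the comparison is lexicographic, so the skip only triggers on devices reporting compute capability 10.0 or higher.

import torch

from torchao.utils import is_sm_at_least_100

# Illustrative capability values: A100 reports (8, 0), H100 reports (9, 0),
# and Blackwell-generation parts report (10, 0); only the last satisfies >= (10, 0).
if torch.cuda.is_available():
    print(torch.cuda.get_device_capability())  # e.g. (9, 0) on an H100
    print(is_sm_at_least_100())  # True only for capability >= 10.0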