
Commit

Uses pytest markers instead of module skip.
markblee committed Jan 23, 2025
1 parent 30284c8 commit 0e07bf5
Showing 3 changed files with 10 additions and 4 deletions.
10 changes: 7 additions & 3 deletions axlearn/common/flash_attention/gpu_attention_test.py
@@ -10,6 +10,7 @@
Currently tested on A100/H100.
"""

import functools
from typing import Literal

@@ -28,9 +29,6 @@
from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference
from axlearn.common.test_utils import TestCase

if jax.default_backend() not in ("gpu", "cpu"):
    pytest.skip(reason="Incompatible hardware", allow_module_level=True)


@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
@@ -51,6 +49,7 @@
@pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"])
@pytest.mark.parametrize("use_segment_ids", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32])
@pytest.mark.gpu
def test_triton_fwd_only_against_ref(
batch_size: int,
seq_len: int,
@@ -122,6 +121,7 @@ def test_triton_fwd_only_against_ref(
chex.assert_trees_all_close(o, o_ref, atol=0.03)


@pytest.mark.gpu
class FlashDecodingTest(TestCase):
"""Tests FlashDecoding."""

@@ -234,6 +234,7 @@ def test_decode_against_ref(
@pytest.mark.parametrize("block_size", [64, 128])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32])
@pytest.mark.gpu
def test_triton_against_xla_ref(
batch_size: int,
num_heads: int,
@@ -353,6 +354,7 @@ def ref_fn(q, k, v, bias, segment_ids, k5):
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
@pytest.mark.gpu
def test_cudnn_against_triton_ref(
batch_size: int,
num_heads: int,
@@ -433,6 +435,7 @@ def ref_fn(q, k, v):
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
@pytest.mark.parametrize("dropout_rate", [0.1, 0.25])
@pytest.mark.gpu
def test_cudnn_dropout_against_xla_dropout(
batch_size: int,
num_heads: int,
@@ -515,6 +518,7 @@ def ref_fn(q, k, v):
raise ValueError(f"Unsupported dtype: {dtype}")


@pytest.mark.gpu
def test_cudnn_dropout_determinism():
"""Tests that cuDNN dropout produces identical outputs across runs."""
if jax.default_backend() == "cpu":
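For the new @pytest.mark.gpu marker (and the tpu marker used in the next file) to be recognized without a PytestUnknownMarkWarning, pytest generally needs the marker registered in the project's configuration. A minimal sketch of such a registration in a conftest.py, assuming the repository does not already declare these markers elsewhere (e.g. in setup.cfg or pyproject.toml):

# conftest.py -- illustrative sketch only, not the repository's actual configuration.


def pytest_configure(config):
    # Register the custom markers used by the hardware-specific test modules.
    config.addinivalue_line("markers", "gpu: tests that require a GPU (or CPU-interpreted) backend")
    config.addinivalue_line("markers", "tpu: tests that require a TPU backend")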
2 changes: 2 additions & 0 deletions axlearn/common/flash_attention/tpu_attention_test.py
@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

"""Tests TPU FlashAttention kernels."""

from __future__ import annotations

import unittest
@@ -46,6 +47,7 @@ def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor:
return jnp.greater_equal(query_position, key_position)


@pytest.mark.tpu
class TestFlashAttention(TestCase):
"""Tests FlashAttention layer."""

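The deleted module-level pytest.skip also protected anyone who ran the GPU test file directly on unsupported hardware. If that behavior is still wanted when no -m filter is given, a collection hook along the following lines could restore it; this is a sketch under the assumption that a shared conftest.py is an acceptable home for it, not the repository's actual hook:

# conftest.py -- hypothetical collection-time skip, mirroring the removed module-level check.
import jax
import pytest


def pytest_collection_modifyitems(config, items):
    backend = jax.default_backend()
    # The old check skipped GPU tests unless the backend was "gpu" or "cpu"
    # (CPU is allowed so the kernels can run in interpreter mode).
    skip_gpu = pytest.mark.skip(reason="Incompatible hardware.")
    for item in items:
        if item.get_closest_marker("gpu") is not None and backend not in ("gpu", "cpu"):
            item.add_marker(skip_gpu)
        # A similar branch could guard tpu-marked tests, using whatever backend
        # condition those tests actually require.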
2 changes: 1 addition & 1 deletion run_tests.sh
@@ -52,7 +52,7 @@ fi

UNQUOTED_PYTEST_FILES=$(echo $1 | tr -d "'")
pytest --durations=100 -v -n auto \
-m "not (gs_login or tpu or high_cpu or fp64)" ${UNQUOTED_PYTEST_FILES} \
-m "not (gs_login or tpu or gpu or high_cpu or fp64)" ${UNQUOTED_PYTEST_FILES} \
--dist worksteal &
TEST_PIDS[$!]=1

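With the gpu marker excluded in CI by the expression above, GPU-capable machines can still run these tests explicitly, e.g. with something like pytest -m gpu axlearn/common/flash_attention/gpu_attention_test.py (a hypothetical invocation; adjust the path and options to the local setup).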
