diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index ccb6690a87..5bb86c6081 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -20,7 +20,7 @@ def clear_live_arrays(): @pytest.fixture(autouse=True, scope="module") -def enable_fused_attn(): +def enable_fused_attn_after_hopper(): """ Enable fused attn for hopper+ arch. Fused attn kernels on pre-hopper arch are not deterministic.