Skip to content

Commit

Permalink
[JAX] Add the missing 1HSS tests (NVIDIA#1052)
Browse files Browse the repository at this point in the history
Add the missing 1HSS tests

Signed-off-by: Reese Wang <[email protected]>
  • Loading branch information
zlsh80826 authored Aug 6, 2024
1 parent d74e65f commit 5bb3a41
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,10 @@ def _check_configs(self):
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")

if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
if (
self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS
and self.bias_shape != BiasShape.BIAS_1HSS
):
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip(
"B1SS, BHSS and 11SS bias shapes are only supported for "
Expand Down

0 comments on commit 5bb3a41

Please sign in to comment.