Skip to content

Commit

Permalink
fix jax tests
Browse files: browse the repository at this point in the history
Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa committed Dec 19, 2024
1 parent 6a4e9e1 commit 16a9d04
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ def make_mask(
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
# In inv_swa_mask and inv_mask 0 is masked out
inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)
inv_mask = combine_masks(inv_mask, inv_swa_mask)

mask = jnp.logical_not(inv_mask)
return mask
Expand Down Expand Up @@ -315,6 +314,13 @@ def _get_max_segments_per_sequence(self):
return self.num_segments_per_seq + 1

def _check_configs(self):
# TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference
if (
self.qkv_layout.is_thd()
and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
and self.window_size is not None
):
pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.")
# TODO(rewang): probably add this check to is_fused_attn_available
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
Expand Down Expand Up @@ -504,7 +510,13 @@ def generate_random_segment_ids(
if self.qkv_layout.is_thd():
self.mask_for_customcall = None # THD format doesn't support mask
else:
self.mask_for_customcall = self.mask
self.mask_for_customcall = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)

self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
Expand Down

0 comments on commit 16a9d04

Please sign in to comment.