diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 759ea893ef..10da7486cf 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -170,8 +170,7 @@ def make_mask( max_seqlen_kv = inv_mask.shape[-1] inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type) inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape) - # In inv_swa_mask and inv_mask 0 is masked out - inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask) + inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = jnp.logical_not(inv_mask) return mask @@ -315,6 +314,13 @@ def _get_max_segments_per_sequence(self): return self.num_segments_per_seq + 1 def _check_configs(self): + # TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference + if ( + self.qkv_layout.is_thd() + and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK + and self.window_size is not None + ): + pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.") # TODO(rewang): probably adds this in is_fused_attn_available if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") @@ -504,7 +510,13 @@ def generate_random_segment_ids( if self.qkv_layout.is_thd(): self.mask_for_customcall = None # THD format doesn't support mask else: - self.mask_for_customcall = self.mask + self.mask_for_customcall = make_mask( + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, + self.attn_mask_type, + ) self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.scaling_factor = 1.0 / sqrt(self.head_dim)