From b898cbe18fb1d3414544bf8d11de5f83cadfb5db Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 8 Jan 2025 08:31:00 +0800 Subject: [PATCH] [JAX] Add THD + SWA unit tests (#1390) * Fix SWA mask for THD and forcing seqlen_kv >= seqlen_q for SWA Signed-off-by: Reese Wang * Generalize sliding window mask Signed-off-by: Reese Wang * Fix pylint Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/test_fused_attn.py | 60 ++++++++-------- tests/jax/utils.py | 10 +-- transformer_engine/jax/attention.py | 79 +++++++++------------- transformer_engine/jax/flax/transformer.py | 17 +++-- 4 files changed, 80 insertions(+), 86 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 01fc2b3e21..5cbbec7b04 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -148,30 +148,30 @@ def make_mask( segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5] segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] """ + # segment masks inv_mask = make_attention_mask( segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) ) + + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + + # causal mask if attn_mask_type.is_causal(): - if segment_pos_q is None: - segment_pos_q = jnp.broadcast_to( - jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape - ) - if segment_pos_kv is None: - segment_pos_kv = jnp.broadcast_to( - jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape - ) inv_causal_mask = make_attention_mask( segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) ) inv_mask = combine_masks(inv_causal_mask, inv_mask) - if window_size is not None: - max_seqlen_q = inv_mask.shape[-2] - max_seqlen_kv = inv_mask.shape[-1] - inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type) - inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape) - inv_mask = combine_masks(inv_mask, inv_swa_mask) - + # sliding window mask + inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_) + inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = jnp.logical_not(inv_mask) return mask @@ -314,13 +314,6 @@ def _get_max_segments_per_sequence(self): return self.num_segments_per_seq + 1 def _check_configs(self): - # TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference - if ( - self.qkv_layout.is_thd() - and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK - and self.window_size is not None - ): - pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.") # TODO(rewang): probably adds this in is_fused_attn_available if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") @@ -432,7 +425,12 @@ def gen_valid(bs, max_seqlen, pad_ratio): return tokens, jnp.logical_not(tokens) def generate_random_segment_ids( - batch_size, sequence_length, num_segments, seed, with_segment_pad=True + batch_size, + sequence_length, + num_segments, + seed, + with_segment_pad=True, + min_segment_len=None, ): rng = np.random.default_rng(seed=seed) # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad @@ -448,15 +446,20 @@ def generate_random_segment_ids( current_pos = 0 segment_id = 1 - for _ in range(num_segments): - 
segment_size = rng.integers(1, max_segment_size + 1)
+        for seg_id in range(num_segments):
+            # min_segment_len forces kv_len >= q_len because cuDNN kernels fail otherwise
+            # TODO(rewang): Remove this constraint once cuDNN supports kv_len < q_len
+            min_segment_size = 1
+            if min_segment_len is not None:
+                min_segment_size = min_segment_len[i][seg_id]
+            segment_size = rng.integers(min_segment_size, max_segment_size + 1)
             if current_pos + segment_size > sequence_length:
                 break
             segment_end = current_pos + segment_size
             segment_ids[i, current_pos:segment_end] = segment_id
             segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
             if with_segment_pad:
-                num_valid = rng.integers(1, segment_size + 1)
+                num_valid = rng.integers(min_segment_size, segment_size + 1)
                 segment_pad[i, current_pos + num_valid : segment_end] = 1
             current_pos = segment_end
             segment_id += 1
@@ -473,18 +476,21 @@ def generate_random_segment_ids(
         self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
             self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
         )
+        self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
         if self.qkv_layout == QKVLayout.T3HD:
             self.segment_ids_kv = self.segment_ids_q
             self.segment_pos_kv = self.segment_pos_q
             self.pad_kv = self.pad_q
         else:
+            # Force kv_len >= q_len for SWA; otherwise, cuDNN kernels don't support it
+            min_segment_len = None if self.window_size is None else self.seqlens_q
             self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
                 self.batch_size,
                 self.max_seqlen_kv,
                 self.num_segments_per_seq,
                 seed=2024,
+                min_segment_len=min_segment_len,
             )
-            self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
             self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
         else:
             self.num_segments_per_seq = 1
diff --git a/tests/jax/utils.py b/tests/jax/utils.py
index 3ff879e68c..9cb02bc555 100644
--- a/tests/jax/utils.py
+++ b/tests/jax/utils.py
@@ -919,14 +919,14 @@ def apply_swa_mask(
     """Apply the sliding window mask to a given mask"""
     _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
     assert _attn_mask_type is not None
+    batch = original_mask.shape[0]
     max_seqlen_q = original_mask.shape[-2]
     max_seqlen_kv = original_mask.shape[-1]
-    swa_mask = make_swa_mask(
-        max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype
-    )
+    pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
+    pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
+    swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
     # In swa_mask and original_mask 0 is masked out
-    swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
-    new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask)
+    new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
     return new_mask
 
 
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 997d4657df..7b6c605236 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -147,59 +147,44 @@ class CPStrategy(Enum):
 
 
 def make_swa_mask(
-    max_seqlen_q: int,
-    max_seqlen_kv: int,
+    segment_pos_q: jnp.ndarray,
+    segment_pos_kv: jnp.ndarray,
     window_size: Optional[Tuple[int, int]] = None,
-    attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
     dtype: jax.typing.DTypeLike = jnp.float32,
 ):
     """
-    Generate sliding window mask. `True` or `1` means keep the element. 
- - For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type, - the sliding window diagonal is aligned to the bottom right corner, and for other - mask types, the top left corner. - - Parameters - ---------- - max_seqlen_q: int - Maximum sequence length for queries. - max_seqlen_kv: int - Maximum sequence length for keys and values. - window_size: Optional[Tuple[int, int]] = None - Sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Negative number in window size means infinity window. - `None` means no sliding window. - attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK - dtype: jax.typing.DTypeLike, default=jnp.float32 - The mask data type. - Returns - ---------- - swa_mask: jax.numpy.tensor - Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions - that will get attention, value 0 are the masked out positions. + Generate a sliding window mask (1 = attend, 0 = masked). + + Args: + segment_pos_q (jnp.ndarray): + Query positions within each segment. For example, a batch with segment_ids = + [[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos = + [[0, 1, 2, 0, 1, 2, 3, 4]]. + segment_pos_kv (jnp.ndarray): + Key/value positions within each segment. + window_size (Optional[Tuple[int, int]], optional): + Sliding window size for local attention, where query at position i attends to keys + in [i - window_size[0], i + window_size[1]] inclusive. A negative number means an + infinite window; None means no sliding window. + Defaults to None. + dtype (jax.typing.DTypeLike, optional): + Mask data type. Defaults to jnp.float32. + + Returns: + jnp.ndarray: + The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv]. 
""" - swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype) - if window_size is None: - return swa_mask - left_window, right_window = window_size - if attn_mask_type.is_bottom_right(): - if left_window < 0: - left_window = max_seqlen_kv - if right_window < 0: - right_window = max_seqlen_kv - bottom_right_shift = max_seqlen_kv - max_seqlen_q - swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift) - swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift) + if window_size is not None: + left_window, right_window = window_size else: - if left_window < 0: - left_window = max_seqlen_q - if right_window < 0: - right_window = max_seqlen_q - swa_mask = jnp.triu(swa_mask, k=-left_window) - swa_mask = jnp.tril(swa_mask, k=right_window) - return swa_mask + left_window = right_window = jnp.inf + left_window = jnp.inf if left_window < 0 else left_window + right_window = jnp.inf if right_window < 0 else right_window + pos_q = jnp.expand_dims(segment_pos_q, axis=-1) + pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2) + inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window) + inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3) + return inv_swa_mask.astype(dtype) def canonicalize_attn_mask_type(attn_mask_type: str): diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index e343e9d823..cf2b13d074 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -194,15 +194,18 @@ def __call__( if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: attn_weights += bias - def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array: + def apply_swa_mask(original_mask: Array) -> Array: """Apply the sliding window mask to a given mask""" + batch = original_mask.shape[0] max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] - swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, self.window_size, attn_mask_type) - # In swa_mask 0 is masked out, in original_mask 1 is masked out - swa_mask = 1 - swa_mask.astype(original_mask.dtype) - swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape) - new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask) + # TODO(rewang): Support THD format pos + pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q)) + pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv)) + # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out + inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype) + swa_mask = 1 - inv_swa_mask + new_mask = jnp.where(original_mask == 0, swa_mask, original_mask) return new_mask def convert_to_softmax_type(attn_mask_type, mask): @@ -213,7 +216,7 @@ def convert_to_softmax_type(attn_mask_type, mask): if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None: mask = None if mask is not None: - mask = apply_swa_mask(attn_mask_type, mask) + mask = apply_swa_mask(mask) # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask