Commit
Merge branch 'main' into mgoldfarb-nvidia/online_softmax_aux_correction
mgoldfarb-nvidia authored Jan 8, 2025
2 parents bce6c2e + 61cf102 commit 268d140
Showing 5 changed files with 81 additions and 87 deletions.
60 changes: 33 additions & 27 deletions tests/jax/test_fused_attn.py
@@ -148,30 +148,30 @@ def make_mask(
segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
"""
# segment masks
inv_mask = make_attention_mask(
segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
)

if segment_pos_q is None:
segment_pos_q = jnp.broadcast_to(
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
)
if segment_pos_kv is None:
segment_pos_kv = jnp.broadcast_to(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)

# causal mask
if attn_mask_type.is_causal():
if segment_pos_q is None:
segment_pos_q = jnp.broadcast_to(
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
)
if segment_pos_kv is None:
segment_pos_kv = jnp.broadcast_to(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)
inv_causal_mask = make_attention_mask(
segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
)
inv_mask = combine_masks(inv_causal_mask, inv_mask)

if window_size is not None:
max_seqlen_q = inv_mask.shape[-2]
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
inv_mask = combine_masks(inv_mask, inv_swa_mask)

# sliding window mask
inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
inv_mask = combine_masks(inv_mask, inv_swa_mask)
mask = jnp.logical_not(inv_mask)
return mask
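
A minimal sketch of the composition above (illustrative only, not part of this commit): the segment mask, the causal mask built from segment positions, and the position-based sliding-window mask are all boolean "keep" masks that are AND-ed together before the final inversion. The snippet uses plain jnp operations on made-up data instead of the test helpers.

import jax.numpy as jnp

segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 0, 0]])   # 0 means padding
segment_pos = jnp.array([[0, 1, 2, 0, 1, 2, 0, 1]])   # position within each segment
window_size = (1, 0)                                   # attend to self and one previous key

ids_q, ids_kv = segment_ids[..., :, None], segment_ids[..., None, :]
pos_q, pos_kv = segment_pos[..., :, None], segment_pos[..., None, :]

inv_segment = (ids_q == ids_kv) & (ids_q != 0)                                     # same non-pad segment
inv_causal = pos_q >= pos_kv                                                       # causal within segment
inv_swa = (pos_kv >= pos_q - window_size[0]) & (pos_kv <= pos_q + window_size[1])  # sliding window

mask = jnp.logical_not(inv_segment & inv_causal & inv_swa)  # True = masked out, as make_mask returns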

@@ -314,13 +314,6 @@ def _get_max_segments_per_sequence(self):
return self.num_segments_per_seq + 1

def _check_configs(self):
# TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference
if (
self.qkv_layout.is_thd()
and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
and self.window_size is not None
):
pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.")
# TODO(rewang): probably add this to is_fused_attn_available
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
@@ -432,7 +425,12 @@ def gen_valid(bs, max_seqlen, pad_ratio):
return tokens, jnp.logical_not(tokens)

def generate_random_segment_ids(
batch_size, sequence_length, num_segments, seed, with_segment_pad=True
batch_size,
sequence_length,
num_segments,
seed,
with_segment_pad=True,
min_segment_len=None,
):
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
@@ -448,15 +446,20 @@ def generate_random_segment_ids(
current_pos = 0
segment_id = 1

for _ in range(num_segments):
segment_size = rng.integers(1, max_segment_size + 1)
for seg_id in range(num_segments):
# min_segment_len forces kv_len >= q_len because the cuDNN kernels fail otherwise
# TODO(rewang): Remove this constraint once cuDNN supports kv_len < q_len
min_segment_size = 1
if min_segment_len is not None:
min_segment_size = min_segment_len[i][seg_id]
segment_size = rng.integers(min_segment_size, max_segment_size + 1)
if current_pos + segment_size > sequence_length:
break
segment_end = current_pos + segment_size
segment_ids[i, current_pos:segment_end] = segment_id
segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
if with_segment_pad:
num_valid = rng.integers(1, segment_size + 1)
num_valid = rng.integers(min_segment_size, segment_size + 1)
segment_pad[i, current_pos + num_valid : segment_end] = 1
current_pos = segment_end
segment_id += 1
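
A short sketch of what the new min_segment_len argument does (hypothetical numbers, not from the test): each KV segment size is drawn with a per-segment lower bound, here the corresponding Q segment length, so that kv_len >= q_len holds for every segment.

import numpy as np

rng = np.random.default_rng(seed=2024)
max_segment_size = 6
q_seglens = np.array([[3, 2, 4]])   # assumed Q segment lengths for one batch row

kv_seglens = [
    rng.integers(q_seglens[0][seg_id], max_segment_size + 1)   # lower bound = Q segment length
    for seg_id in range(q_seglens.shape[1])
]
assert all(kv >= q for kv, q in zip(kv_seglens, q_seglens[0]))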
@@ -473,18 +476,21 @@ def generate_random_segment_ids(
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
if self.qkv_layout == QKVLayout.T3HD:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for SWA; otherwise the cuDNN kernels don't support it
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
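
get_seqlens_and_offsets itself is not shown in this diff; the sketch below is an assumption about its contract (per-segment token counts plus their start offsets, per batch row), included because seqlens_q is what gets fed back in as min_segment_len above.

import numpy as np

def seqlens_and_offsets_sketch(segment_ids, max_segments):
    # Count tokens per segment id (1..max_segments); id 0 is padding.
    batch = segment_ids.shape[0]
    seqlens = np.zeros((batch, max_segments), dtype=np.int32)
    for b in range(batch):
        for seg in range(1, max_segments + 1):
            seqlens[b, seg - 1] = np.sum(segment_ids[b] == seg)
    # Start offset of each segment in the packed sequence.
    offsets = np.concatenate(
        [np.zeros((batch, 1), dtype=np.int32), np.cumsum(seqlens, axis=-1)[:, :-1]], axis=-1
    )
    return seqlens, offsets

ids = np.array([[1, 1, 1, 2, 2, 0, 0, 0]])
print(seqlens_and_offsets_sketch(ids, max_segments=3))   # ([[3, 2, 0]], [[0, 3, 5]])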
10 changes: 5 additions & 5 deletions tests/jax/utils.py
@@ -919,14 +919,14 @@ def apply_swa_mask(
"""Apply the sliding window mask to a given mask"""
_attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
assert _attn_mask_type is not None
batch = original_mask.shape[0]
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(
max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype
)
pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
# In swa_mask and original_mask 0 is masked out
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask)
new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
return new_mask
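
A toy illustration of the merge just above (a sketch with made-up shapes): both masks use 0 for "masked out", so positions the original mask already rejects stay rejected, while positions it allows are further narrowed by the window.

import jax.numpy as jnp

original_mask = jnp.ones((1, 1, 4, 4)).at[..., -1].set(0)   # last key position is padding
pos = jnp.broadcast_to(jnp.arange(4), (1, 4))               # [batch, seqlen]
left, right = 1, 0
pos_q, pos_kv = pos[..., :, None], pos[..., None, :]
swa_mask = ((pos_kv >= pos_q - left) & (pos_kv <= pos_q + right))[:, None].astype(original_mask.dtype)
new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)   # 0 stays 0; 1 may become 0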


79 changes: 32 additions & 47 deletions transformer_engine/jax/attention.py
@@ -147,59 +147,44 @@ class CPStrategy(Enum):


def make_swa_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
segment_pos_q: jnp.ndarray,
segment_pos_kv: jnp.ndarray,
window_size: Optional[Tuple[int, int]] = None,
attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
dtype: jax.typing.DTypeLike = jnp.float32,
):
"""
Generate sliding window mask. `True` or `1` means keep the element.
For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type,
the sliding window diagonal is aligned to the bottom right corner, and for other
mask types, the top left corner.
Parameters
----------
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
window_size: Optional[Tuple[int, int]] = None
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Negative number in window size means infinity window.
`None` means no sliding window.
attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK
dtype: jax.typing.DTypeLike, default=jnp.float32
The mask data type.
Returns
----------
swa_mask: jax.numpy.tensor
Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions
that will get attention, value 0 are the masked out positions.
Generate a sliding window mask (1 = attend, 0 = masked).
Args:
segment_pos_q (jnp.ndarray):
Query positions within each segment. For example, a batch with segment_ids =
[[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos =
[[0, 1, 2, 0, 1, 2, 3, 4]].
segment_pos_kv (jnp.ndarray):
Key/value positions within each segment.
window_size (Optional[Tuple[int, int]], optional):
Sliding window size for local attention, where query at position i attends to keys
in [i - window_size[0], i + window_size[1]] inclusive. A negative number means an
infinite window; None means no sliding window.
Defaults to None.
dtype (jax.typing.DTypeLike, optional):
Mask data type. Defaults to jnp.float32.
Returns:
jnp.ndarray:
The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv].
"""
swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
if window_size is None:
return swa_mask
left_window, right_window = window_size
if attn_mask_type.is_bottom_right():
if left_window < 0:
left_window = max_seqlen_kv
if right_window < 0:
right_window = max_seqlen_kv
bottom_right_shift = max_seqlen_kv - max_seqlen_q
swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift)
swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift)
if window_size is not None:
left_window, right_window = window_size
else:
if left_window < 0:
left_window = max_seqlen_q
if right_window < 0:
right_window = max_seqlen_q
swa_mask = jnp.triu(swa_mask, k=-left_window)
swa_mask = jnp.tril(swa_mask, k=right_window)
return swa_mask
left_window = right_window = jnp.inf
left_window = jnp.inf if left_window < 0 else left_window
right_window = jnp.inf if right_window < 0 else right_window
pos_q = jnp.expand_dims(segment_pos_q, axis=-1)
pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2)
inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window)
inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3)
return inv_swa_mask.astype(dtype)
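
A usage sketch for the new signature (illustrative values, not part of this commit): positions restart at 0 in each packed segment, which is why the window is now applied in position space instead of with triu/tril on a dense [sq, skv] matrix.

import jax.numpy as jnp
from transformer_engine.jax.attention import make_swa_mask

# One packed row with two segments of lengths 3 and 5, as in the docstring example.
segment_pos = jnp.array([[0, 1, 2, 0, 1, 2, 3, 4]])

# Causal sliding window of width 2: each query attends to itself and the two previous keys.
swa_mask = make_swa_mask(segment_pos, segment_pos, window_size=(2, 0), dtype=jnp.float32)
print(swa_mask.shape)   # (1, 1, 8, 8)

# Because the window compares per-segment positions, keys from a different segment can still
# fall inside the window; callers combine this mask with a segment mask, as make_mask does in
# the tests above.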


def canonicalize_attn_mask_type(attn_mask_type: str):
17 changes: 10 additions & 7 deletions transformer_engine/jax/flax/transformer.py
@@ -194,15 +194,18 @@ def __call__(
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias

def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array:
def apply_swa_mask(original_mask: Array) -> Array:
"""Apply the sliding window mask to a given mask"""
batch = original_mask.shape[0]
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, self.window_size, attn_mask_type)
# In swa_mask 0 is masked out, in original_mask 1 is masked out
swa_mask = 1 - swa_mask.astype(original_mask.dtype)
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask)
# TODO(rewang): Support THD format pos
pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
# In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
swa_mask = 1 - inv_swa_mask
new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
return new_mask

def convert_to_softmax_type(attn_mask_type, mask):
@@ -213,7 +216,7 @@ def convert_to_softmax_type(attn_mask_type, mask):
if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
mask = None
if mask is not None:
mask = apply_swa_mask(attn_mask_type, mask)
mask = apply_swa_mask(mask)
# Currently the cuDNN backend only supports SWA for causal/padding_causal masks, so follow that here
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -373,7 +373,7 @@ def forward(
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
if not is_grad_enabled:
if not is_grad_enabled and not return_layernorm_output:
clear_tensor_data(ln_out_total)

if bias_gelu_nvfusion:
