
Commit

Merge branch 'main' into xren/cp_optim
xrennvidia authored Jan 8, 2025
2 parents 2722165 + a4cb1d1 commit 86109a6
Showing 7 changed files with 124 additions and 113 deletions.
8 changes: 4 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
@@ -401,7 +401,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args

def impl_test_contex_parallel_attn(
def impl_test_context_parallel_attn(
self,
device_count,
mesh_shape,
@@ -583,7 +583,7 @@ def grad_func(func, *args, **kwargs):

assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)

def test_contex_parallel_allgather_attn(
def test_context_parallel_allgather_attn(
self,
device_count,
mesh_shape,
@@ -596,7 +596,7 @@ def test_contex_parallel_allgather_attn(
qkv_layout,
load_balanced,
):
return self.impl_test_contex_parallel_attn(
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
@@ -623,7 +623,7 @@ def test_context_parallel_ring_attn(
qkv_layout,
load_balanced,
):
return self.impl_test_contex_parallel_attn(
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
60 changes: 33 additions & 27 deletions tests/jax/test_fused_attn.py
@@ -148,30 +148,30 @@ def make_mask(
segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
"""
# segment masks
inv_mask = make_attention_mask(
segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
)

if segment_pos_q is None:
segment_pos_q = jnp.broadcast_to(
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
)
if segment_pos_kv is None:
segment_pos_kv = jnp.broadcast_to(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)

# causal mask
if attn_mask_type.is_causal():
if segment_pos_q is None:
segment_pos_q = jnp.broadcast_to(
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
)
if segment_pos_kv is None:
segment_pos_kv = jnp.broadcast_to(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)
inv_causal_mask = make_attention_mask(
segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
)
inv_mask = combine_masks(inv_causal_mask, inv_mask)

if window_size is not None:
max_seqlen_q = inv_mask.shape[-2]
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
inv_mask = combine_masks(inv_mask, inv_swa_mask)

# sliding window mask
inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
inv_mask = combine_masks(inv_mask, inv_swa_mask)
mask = jnp.logical_not(inv_mask)
return mask

@@ -314,13 +314,6 @@ def _get_max_segments_per_sequence(self):
return self.num_segments_per_seq + 1

def _check_configs(self):
# TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference
if (
self.qkv_layout.is_thd()
and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
and self.window_size is not None
):
pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.")
# TODO(rewang): probably adds this in is_fused_attn_available
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
@@ -432,7 +425,12 @@ def gen_valid(bs, max_seqlen, pad_ratio):
return tokens, jnp.logical_not(tokens)

def generate_random_segment_ids(
batch_size, sequence_length, num_segments, seed, with_segment_pad=True
batch_size,
sequence_length,
num_segments,
seed,
with_segment_pad=True,
min_segment_len=None,
):
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
@@ -448,15 +446,20 @@ def generate_random_segment_ids(
current_pos = 0
segment_id = 1

for _ in range(num_segments):
segment_size = rng.integers(1, max_segment_size + 1)
for seg_id in range(num_segments):
# min_segment_len forces kv_len >= q_len because the cuDNN kernels fail otherwise
# TODO(rewang): Remove this constraint once cuDNN adds support
min_segment_size = 1
if min_segment_len is not None:
min_segment_size = min_segment_len[i][seg_id]
segment_size = rng.integers(min_segment_size, max_segment_size + 1)
if current_pos + segment_size > sequence_length:
break
segment_end = current_pos + segment_size
segment_ids[i, current_pos:segment_end] = segment_id
segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
if with_segment_pad:
num_valid = rng.integers(1, segment_size + 1)
num_valid = rng.integers(min_segment_size, segment_size + 1)
segment_pad[i, current_pos + num_valid : segment_end] = 1
current_pos = segment_end
segment_id += 1
@@ -473,18 +476,21 @@ def generate_random_segment_ids(
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
if self.qkv_layout == QKVLayout.T3HD:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for SWA; otherwise, the cuDNN kernels do not support it
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
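
Note: the reference mask in this test is now composed entirely from per-segment positions. Below is a compact, jnp-only sketch of that composition; helper names and values are illustrative, not the test's actual imports.

import jax.numpy as jnp

segment_ids = jnp.array([[1, 1, 1, 2, 2, 0]], dtype=jnp.int32)  # 0 = padding
# Default positions when segment_pos is not provided: absolute 0..seqlen-1.
segment_pos = jnp.broadcast_to(
    jnp.arange(segment_ids.shape[-1], dtype=jnp.int32), segment_ids.shape
)

ids_q, ids_kv = segment_ids[..., :, None], segment_ids[..., None, :]
pos_q, pos_kv = segment_pos[..., :, None], segment_pos[..., None, :]

inv_segment = (ids_q == ids_kv) & (ids_q != 0)  # same, non-padding segment
inv_causal = pos_q >= pos_kv                    # causal, driven by positions
inv_mask = inv_segment & inv_causal
# With a sliding window, make_swa_mask(segment_pos_q, segment_pos_kv, ...) is
# AND-ed in the same way before the final inversion.
mask = jnp.logical_not(inv_mask)                # True = masked out
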
10 changes: 5 additions & 5 deletions tests/jax/utils.py
@@ -919,14 +919,14 @@ def apply_swa_mask(
"""Apply the sliding window mask to a given mask"""
_attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
assert _attn_mask_type is not None
batch = original_mask.shape[0]
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(
max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype
)
pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
# In swa_mask and original_mask 0 is masked out
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask)
new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
return new_mask


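
The merge above keeps the utils.py convention that 0 means "masked out" in both masks; a tiny illustrative sketch (values made up):

import jax.numpy as jnp

original_mask = jnp.array([[1, 1, 0],
                           [1, 1, 1]], dtype=jnp.uint8)  # e.g. padding/causal mask
swa_mask = jnp.array([[1, 0, 1],
                      [0, 1, 1]], dtype=jnp.uint8)       # sliding-window mask
# Keep a position only if both the original mask and the window allow it,
# i.e. a logical AND expressed with jnp.where.
new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
# new_mask == [[1, 0, 0],
#              [0, 1, 1]]
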
79 changes: 32 additions & 47 deletions transformer_engine/jax/attention.py
@@ -147,59 +147,44 @@ class CPStrategy(Enum):


def make_swa_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
segment_pos_q: jnp.ndarray,
segment_pos_kv: jnp.ndarray,
window_size: Optional[Tuple[int, int]] = None,
attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
dtype: jax.typing.DTypeLike = jnp.float32,
):
"""
Generate sliding window mask. `True` or `1` means keep the element.
For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type,
the sliding window diagonal is aligned to the bottom right corner, and for other
mask types, the top left corner.
Parameters
----------
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
window_size: Optional[Tuple[int, int]] = None
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Negative number in window size means infinity window.
`None` means no sliding window.
attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK
dtype: jax.typing.DTypeLike, default=jnp.float32
The mask data type.
Returns
----------
swa_mask: jax.numpy.tensor
Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions
that will get attention, value 0 are the masked out positions.
Generate a sliding window mask (1 = attend, 0 = masked).
Args:
segment_pos_q (jnp.ndarray):
Query positions within each segment. For example, a batch with segment_ids =
[[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos =
[[0, 1, 2, 0, 1, 2, 3, 4]].
segment_pos_kv (jnp.ndarray):
Key/value positions within each segment.
window_size (Optional[Tuple[int, int]], optional):
Sliding window size for local attention, where query at position i attends to keys
in [i - window_size[0], i + window_size[1]] inclusive. A negative number means an
infinite window; None means no sliding window.
Defaults to None.
dtype (jax.typing.DTypeLike, optional):
Mask data type. Defaults to jnp.float32.
Returns:
jnp.ndarray:
The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv].
"""
swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
if window_size is None:
return swa_mask
left_window, right_window = window_size
if attn_mask_type.is_bottom_right():
if left_window < 0:
left_window = max_seqlen_kv
if right_window < 0:
right_window = max_seqlen_kv
bottom_right_shift = max_seqlen_kv - max_seqlen_q
swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift)
swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift)
if window_size is not None:
left_window, right_window = window_size
else:
if left_window < 0:
left_window = max_seqlen_q
if right_window < 0:
right_window = max_seqlen_q
swa_mask = jnp.triu(swa_mask, k=-left_window)
swa_mask = jnp.tril(swa_mask, k=right_window)
return swa_mask
left_window = right_window = jnp.inf
left_window = jnp.inf if left_window < 0 else left_window
right_window = jnp.inf if right_window < 0 else right_window
pos_q = jnp.expand_dims(segment_pos_q, axis=-1)
pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2)
inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window)
inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3)
return inv_swa_mask.astype(dtype)


def canonicalize_attn_mask_type(attn_mask_type: str):
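
A minimal usage sketch of the new position-based signature (assumes make_swa_mask remains importable from transformer_engine.jax.attention; values are illustrative):

import jax.numpy as jnp
from transformer_engine.jax.attention import make_swa_mask

# One batch row with two segments: ids [1, 1, 1, 2, 2] give per-segment
# positions [0, 1, 2, 0, 1].
segment_pos = jnp.array([[0, 1, 2, 0, 1]], dtype=jnp.int32)

# window_size=(1, 0): each query attends to the previous position and itself
# (a negative entry would mean an infinite window on that side).
swa_mask = make_swa_mask(segment_pos, segment_pos, window_size=(1, 0))
print(swa_mask.shape)  # (1, 1, 5, 5)
# Row i keeps columns j with pos[i] - 1 <= pos[j] <= pos[i]; cross-segment
# pairs that slip through are removed by the separate segment mask when the
# masks are combined (see make_mask in tests/jax/test_fused_attn.py).
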
61 changes: 39 additions & 22 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -1549,12 +1549,19 @@ def permute_kv(self, kv, cp_perm):
"""Permutes kv around the ring as described by cp_perm."""
return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)

def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step):
"""Apply soft max correction after an attention step."""
max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step)
min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step)
new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale))
return new_softmax_aux
@staticmethod
def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
"""
Corrects the output and softmax_aux tensor after each iteration of ring attention.
See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
derivation of this equation.
"""
new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
0, 2, 1, 3
) * (output - partial_output)
new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
return new_out, new_aux

def adjust_seqlen(self, seqlen, max_seqlen, idx):
"""Adjust the sequence length per step."""
@@ -1615,10 +1622,7 @@ def ring_attn_fwd_impl(
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=q.dtype)
softmax_aux_per_steps = jnp.zeros(
(cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32
)
output = jnp.zeros(q.shape).astype(jnp.float32)
softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)

# RNG shape should be the shared shape. This is unused for ring attention as we do not
Expand All @@ -1627,7 +1631,7 @@ def ring_attn_fwd_impl(
rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

def scan_kv_block(idx, carry):
kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry
kv, output, softmax_aux = carry

# Send KV block to next step so we can overlap compute.
kv_next = helper.permute_kv(kv, cp_perm)
@@ -1718,25 +1722,38 @@ def jax_cond_wrap():
else:
output_per_step, softmax_aux_per_step = no_mask_compute()

softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step)
output_per_steps = output_per_steps.at[idx].set(output_per_step)
softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step)
def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
# No correction done here but we cast outputs to float32 and perform reduction
# in full precision.
# pylint: disable=unused-argument
return output_per_step.astype(jnp.float32), softmax_aux_per_step

return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps)
def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
return helper.correct_output_and_softmax_aux(
output, softmax_aux, output_per_step, softmax_aux_per_step
)

carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps)
# On the first step there is no correction; we just take the initial output and stats
output, softmax_aux = lax.cond(
(idx == 0),
skip_correction,
correction,
output,
softmax_aux,
output_per_step,
softmax_aux_per_step,
)

return (kv_next, output, softmax_aux)

carry = (kv, output, softmax_aux)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for i in range(0, cp_size):
carry = scan_kv_block(i, carry)
(kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry
(kv, output, softmax_aux) = carry

output = jnp.zeros(q.shape).astype(jnp.float32)
for idx in range(cp_size):
output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp(
softmax_aux_per_steps[idx] - softmax_aux
).transpose(0, 2, 1, 3)
output = output.astype(q.dtype)
return output, softmax_aux, rng_state

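
The per-step correction is the usual online-softmax (log-sum-exp) update; here is a self-contained sketch checking the identity on a single query — not the kernel code, which carries batch/head dimensions and the transpose seen above:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

key_s, key_v = jax.random.split(jax.random.PRNGKey(0))
s1, s2 = jax.random.normal(key_s, (2, 8))      # scores from two KV blocks
v1, v2 = jax.random.normal(key_v, (2, 8, 4))   # values for the two blocks

def block_attn(scores, values):
    # Partial result for one block: softmax over the block plus its log-sum-exp.
    return jax.nn.softmax(scores) @ values, logsumexp(scores)

output, softmax_aux = block_attn(s1, v1)            # first step, no correction
partial_output, partial_aux = block_attn(s2, v2)    # next ring step

# Correction exactly as in correct_output_and_softmax_aux (scalar stats here).
new_out = output - jax.nn.sigmoid(partial_aux - softmax_aux) * (output - partial_output)
new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_aux)

# Reference: attention over both blocks at once.
ref_out = jax.nn.softmax(jnp.concatenate([s1, s2])) @ jnp.concatenate([v1, v2])
ref_aux = logsumexp(jnp.concatenate([s1, s2]))
print(jnp.allclose(new_out, ref_out, atol=1e-6))  # True
print(jnp.allclose(new_aux, ref_aux, atol=1e-6))  # True
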
17 changes: 10 additions & 7 deletions transformer_engine/jax/flax/transformer.py
@@ -194,15 +194,18 @@ def __call__(
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias

def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array:
def apply_swa_mask(original_mask: Array) -> Array:
"""Apply the sliding window mask to a given mask"""
batch = original_mask.shape[0]
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, self.window_size, attn_mask_type)
# In swa_mask 0 is masked out, in original_mask 1 is masked out
swa_mask = 1 - swa_mask.astype(original_mask.dtype)
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask)
# TODO(rewang): Support THD format pos
pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
# In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
swa_mask = 1 - inv_swa_mask
new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
return new_mask

def convert_to_softmax_type(attn_mask_type, mask):
Expand All @@ -213,7 +216,7 @@ def convert_to_softmax_type(attn_mask_type, mask):
if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
mask = None
if mask is not None:
mask = apply_swa_mask(attn_mask_type, mask)
mask = apply_swa_mask(mask)
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
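
The flip above is needed because make_swa_mask now returns an "inverse" mask (1 = attend) while the mask fed to the softmax uses 1 = masked out; a tiny illustrative sketch (values made up):

import jax.numpy as jnp

original_mask = jnp.array([[0, 0, 1]])  # 1 = masked out (e.g. padding)
inv_swa_mask = jnp.array([[1, 0, 1]])   # 1 = inside the sliding window
swa_mask = 1 - inv_swa_mask             # flip so that 1 = masked out
# Already-masked positions stay masked; the rest get masked when they fall
# outside the window.
new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
# new_mask == [[0, 1, 1]]
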
