[JAX] Correct fused attention output after each step of ring attention (#1393)

Correct fused attention output after each step to reduce intermediate memory use.

Signed-off-by: Michael Goldfarb <[email protected]>
mgoldfarb-nvidia authored Jan 8, 2025
1 parent 61cf102 commit a4cb1d1
Showing 2 changed files with 43 additions and 26 deletions.
8 changes: 4 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
@@ -401,7 +401,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout):
raise ValueError(f"Unsupported {qkv_layout=}")
return qkv_args

def impl_test_contex_parallel_attn(
def impl_test_context_parallel_attn(
self,
device_count,
mesh_shape,
@@ -583,7 +583,7 @@ def grad_func(func, *args, **kwargs):

assert_allclose(target_grads[i], ref_grads[i], dtype=dtype)

def test_contex_parallel_allgather_attn(
def test_context_parallel_allgather_attn(
self,
device_count,
mesh_shape,
@@ -596,7 +596,7 @@ def test_contex_parallel_allgather_attn(
qkv_layout,
load_balanced,
):
return self.impl_test_contex_parallel_attn(
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
@@ -623,7 +623,7 @@ def test_context_parallel_ring_attn(
qkv_layout,
load_balanced,
):
return self.impl_test_contex_parallel_attn(
return self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
61 changes: 39 additions & 22 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -1549,12 +1549,19 @@ def permute_kv(self, kv, cp_perm):
"""Permutes kv around the ring as described by cp_perm."""
return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)

def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step):
"""Apply soft max correction after an attention step."""
max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step)
min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step)
new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale))
return new_softmax_aux
@staticmethod
def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
"""
Corrects the output and softmax_aux tensor after each iteration of ring attention.
See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
derivation of this equation.
"""
new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
0, 2, 1, 3
) * (output - partial_output)
new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
return new_out, new_aux

def adjust_seqlen(self, seqlen, max_seqlen, idx):
"""Adjust the sequence length per step."""
@@ -1615,10 +1622,7 @@ def ring_attn_fwd_impl(
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]

output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=q.dtype)
softmax_aux_per_steps = jnp.zeros(
(cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32
)
output = jnp.zeros(q.shape).astype(jnp.float32)
softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)

# RNG shape should be the shared shape. This is unused for ring attention as we do not
@@ -1627,7 +1631,7 @@
rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype)

def scan_kv_block(idx, carry):
kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry
kv, output, softmax_aux = carry

# Send KV block to next step so we can overlap compute.
kv_next = helper.permute_kv(kv, cp_perm)
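
A side note on the `cp_perm` schedule used by `permute_kv` above: it sends each rank's current KV shard to its next neighbour, so after `cp_size` steps every rank has attended to every KV block exactly once. A host-side simulation of that rotation (illustrative only; the real code performs it on device with `lax.ppermute`):

```python
# Hedged sketch: simulate the KV ring rotation described by cp_perm on the host.
cp_size = 4
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]  # (src, dst) pairs

kv_owner = list(range(cp_size))          # kv_owner[r]: KV block currently held by rank r
seen = [set() for _ in range(cp_size)]   # KV blocks each rank has attended to

for _ in range(cp_size):
    for r in range(cp_size):
        seen[r].add(kv_owner[r])         # compute attention against the block we hold
    rotated = [None] * cp_size
    for src, dst in cp_perm:             # then pass the held block to the next rank
        rotated[dst] = kv_owner[src]
    kv_owner = rotated

assert all(s == set(range(cp_size)) for s in seen)  # every rank saw every block
```
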
@@ -1718,25 +1722,38 @@ def jax_cond_wrap():
else:
output_per_step, softmax_aux_per_step = no_mask_compute()

softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step)
output_per_steps = output_per_steps.at[idx].set(output_per_step)
softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step)
def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
# No correction done here but we cast outputs to float32 and perform reduction
# in full precision.
# pylint: disable=unused-argument
return output_per_step.astype(jnp.float32), softmax_aux_per_step

return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps)
def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
return helper.correct_output_and_softmax_aux(
output, softmax_aux, output_per_step, softmax_aux_per_step
)

carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps)
# On the first step there is no correction; we just take the initial output and stats.
output, softmax_aux = lax.cond(
(idx == 0),
skip_correction,
correction,
output,
softmax_aux,
output_per_step,
softmax_aux_per_step,
)

return (kv_next, output, softmax_aux)

carry = (kv, output, softmax_aux)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for i in range(0, cp_size):
carry = scan_kv_block(i, carry)
(kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry
(kv, output, softmax_aux) = carry

output = jnp.zeros(q.shape).astype(jnp.float32)
for idx in range(cp_size):
output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp(
softmax_aux_per_steps[idx] - softmax_aux
).transpose(0, 2, 1, 3)
output = output.astype(q.dtype)
return output, softmax_aux, rng_state
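
For completeness, a small standalone sketch (hypothetical names and shapes) of what this hunk changes: the previous code buffered `output_per_steps` and `softmax_aux_per_steps` for all `cp_size` steps, i.e. a `(cp_size, *q.shape)` intermediate, and reduced them after the loop, while the new code folds each step into a single running `(output, softmax_aux)` pair, so the intermediate no longer scales with `cp_size`. Both orderings give the same result:

```python
# Hedged sketch: compare the old post-hoc reduction with the new streaming
# correction on toy per-step outputs. Shapes and names are illustrative only.
import jax
import jax.numpy as jnp

cp_size, rows, dim = 4, 6, 8
k_out, k_lse = jax.random.split(jax.random.PRNGKey(1))
outs = jax.random.normal(k_out, (cp_size, rows, dim))   # per-step partial outputs
lses = jax.random.normal(k_lse, (cp_size, rows, 1))     # per-step log-sum-exp stats

# Old style: keep every step, combine once at the end.
final_lse = jax.scipy.special.logsumexp(lses, axis=0)
old_output = jnp.sum(outs * jnp.exp(lses - final_lse), axis=0)

# New style: correct the running output after each step; extra memory is O(1) in cp_size.
output, lse = outs[0], lses[0]
for step in range(1, cp_size):
    output = output - jax.nn.sigmoid(lses[step] - lse) * (output - outs[step])
    lse = lse - jax.nn.log_sigmoid(lse - lses[step])

print(jnp.allclose(old_output, output, atol=1e-5), jnp.allclose(final_lse, lse, atol=1e-5))
```
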
