From 776864f62405170585f198bfffe471afc65c83ef Mon Sep 17 00:00:00 2001
From: David Hall
Date: Thu, 25 Apr 2024 16:48:47 -0700
Subject: [PATCH] wip adding splash attention

---
 src/levanter/models/attention.py | 437 +++++++++++++++++++++++--------
 tests/test_attention.py          |  18 +-
 tests/test_flash_attention.py    |  25 ++
 3 files changed, 364 insertions(+), 116 deletions(-)

diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index e7ac02f8b..c47a58830 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -1,3 +1,4 @@
+import functools
 import math
 import warnings
 from typing import Optional, Union, overload
@@ -5,12 +6,16 @@
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel, splash_attention_mask
+from jax.experimental.shard_map import shard_map
+from jax.sharding import PartitionSpec
 from jaxtyping import PRNGKeyArray

 import haliax
 from haliax import Axis, AxisSelection, AxisSelector, NamedArray
 from haliax.jax_utils import named_call
 from haliax.nn.attention import causal_mask, combine_masks_and, combine_masks_or
+from haliax.partitioning import pspec_for_axis
 from haliax.types import PrecisionLike


@@ -68,48 +73,41 @@ def dot_product_attention(
             QPos, KPos, Key, query, key, value, mask, bias, inference, dropout, attention_dtype, precision, prng=prng
         )
     elif accelerator_type == "gpu":
-        try:
-            return _te_flash_attention(
-                QPos,
-                KPos,
-                Key,
-                query,
-                key,
-                value,
-                block_size=flash_block_size,
-                mask=mask,
-                bias=bias,
-                dropout=dropout,
-                inference=inference,
-                precision=precision,
-                prng=prng,
-                attention_dtype=attention_dtype,
-            )
-        except ImportError as e:
-            if "transformer_engine" not in str(e):
-                raise
-
-            warnings.warn(
-                "transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention. "
-                "Falling back to the reference implementation."
-            )
-        except NotImplementedError as e:
-            message = str(e)
-            warnings.warn(
-                f"Could not use transformer_engine for flash attention: {message}. Falling back to the reference"
-            )
-        except ValueError as e:
-            message = str(e)
-            if message.startswith("Unsupported backend="):
-                _dtype = attention_dtype or query.dtype
-                msg = "TE doesn't work with these arguments. Falling back to the reference implementation.\n"
-                "Check nvte_get_fused_attn_backend for supported configurations:\n"
-                "https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/fused_attn/fused_attn.cpp#L71"
-                if _dtype not in (jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn):
-                    msg += f"In particular, TE doesn't support {_dtype} yet."
-                warnings.warn(msg)
-            else:
-                raise
+        attention_out = _try_te_attention(
+            QPos,
+            KPos,
+            Key,
+            query,
+            key,
+            value,
+            mask,
+            bias,
+            dropout,
+            inference,
+            prng=prng,
+            attention_dtype=attention_dtype,
+            precision=precision,
+            flash_block_size=flash_block_size,
+        )
+        if attention_out is not None:
+            return attention_out
+    elif accelerator_type == "tpu":
+        return _tpu_splash_attention(
+            QPos,
+            KPos,
+            Key,
+            query,
+            key,
+            value,
+            mask,
+            bias,
+            dropout,
+            inference,
+            prng=prng,
+            attention_dtype=attention_dtype,
+            precision=precision,
+            block_size=flash_block_size,
+        )

     from levanter.models.flash_attention import flash_attention

@@ -157,6 +155,71 @@ def simple_attention_with_dropout(
     return haliax.dot(KPos, weights, value)


+def _try_te_attention(
+    QPos: AxisSelector,
+    KPos: AxisSelection,
+    Key: AxisSelector,
+    query: NamedArray,
+    key: NamedArray,
+    value: NamedArray,
+    mask: Optional[Union[NamedArray, "AttentionMask"]] = None,
+    bias: Optional[NamedArray] = None,
+    dropout: float = 0.0,
+    inference: bool = False,
+    *,
+    prng: Optional[PRNGKeyArray] = None,
+    attention_dtype: Optional[jnp.dtype] = None,
+    precision: PrecisionLike = None,
+    flash_block_size: Optional[int] = None,
+):
+    try:
+        return _te_flash_attention(
+            QPos,
+            KPos,
+            Key,
+            query,
+            key,
+            value,
+            block_size=flash_block_size,
+            mask=mask,
+            bias=bias,
+            dropout=dropout,
+            inference=inference,
+            precision=precision,
+            prng=prng,
+            attention_dtype=attention_dtype,
+        )
+    except ImportError as e:
+        if "transformer_engine" not in str(e):
+            raise
+
+        warnings.warn(
+            "transformer_engine is not installed. Please install it to use NVIDIA's optimized fused attention. "
+            "Falling back to the reference implementation."
+        )
+
+        return None
+    except NotImplementedError as e:
+        message = str(e)
+        warnings.warn(
+            f"Could not use transformer_engine for flash attention: {message}. Falling back to the reference"
+        )
+        return None
+    except ValueError as e:
+        message = str(e)
+        if message.startswith("Unsupported backend="):
+            _dtype = attention_dtype or query.dtype
+            msg = ("TE doesn't work with these arguments. Falling back to the reference implementation.\n"
+                   "Check nvte_get_fused_attn_backend for supported configurations:\n"
+                   "https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/fused_attn/fused_attn.cpp#L71")
+            if _dtype not in (jnp.float16, jnp.bfloat16, jnp.float8_e5m2, jnp.float8_e4m3fn):
+                msg += f"In particular, TE doesn't support {_dtype} yet."
+                warnings.warn(msg)
+            else:
+                raise
+        return None
+
+
 def _te_flash_attention(
     QPos: AxisSelector,
     KPos: AxisSelection,
     Key: AxisSelector,
     query: NamedArray,
     key: NamedArray,
     value: NamedArray,
     mask: Optional[Union[NamedArray, "AttentionMask"]] = None,
     bias: Optional[NamedArray] = None,
     dropout: float = 0.0,
     inference: bool = False,
     *,
     prng: Optional[PRNGKeyArray] = None,
     attention_dtype: Optional[jnp.dtype] = None,
     precision: PrecisionLike = None,
     block_size: Optional[int] = None,
 ):
-    # try:
-    #     from transformer_engine.jax.fused_attn import fused_attn_kvpacked
-    # except ImportError:
-    #     from transformer_engine.jax.fused_attn import cross_fused_attn as fused_attn_kvpacked
-    #
-    # try:
-    #     from transformer_engine.jax.fused_attn import fused_attn_qkvpacked
-    # except ImportError:
-    #     from transformer_engine.jax.fused_attn import self_fused_attn as fused_attn_qkvpacked
-    #
     from transformer_engine.jax.fused_attn import fused_attn  # noqa: F401
     from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType  # noqa: F401

@@ -198,10 +251,10 @@ def _te_flash_attention(
     # references: https://github.com/NVIDIA/TransformerEngine/blob/8255f87f3ee8076db21777795ce15b6ddf8754c0/transformer_engine/jax/fused_attn.py#L31
     # https://github.com/NVIDIA/TransformerEngine/blob/8255f87f3ee8076db21777795ce15b6ddf8754c0/transformer_engine/jax/flax/transformer.py#L269
-    q_class, k_class, v_class = _te_bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)
-    q_: jax.Array = _reshape_axes_for_te_bins(query, q_class).array
-    k_ = _reshape_axes_for_te_bins(key, k_class).array
-    v_ = _reshape_axes_for_te_bins(value, v_class).array
+    q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)
+    q_: jax.Array = _reshape_axes_for_bshd_bins(query, q_class).array
+    k_ = _reshape_axes_for_bshd_bins(key, k_class).array
+    v_ = _reshape_axes_for_bshd_bins(value, v_class).array

     B, Sq, Hq, D = q_.shape
     Bk, Sk, Hk, Dk = k_.shape

@@ -231,35 +284,6 @@ def _te_flash_attention(
     if bias:
         raise NotImplementedError("Using bias with flash attention on GPU is not currently implemented.")

-    # if q_.shape == k_.shape == v_.shape:
-    #     # can use self_fused_attn
-    #     qkv_ = jnp.stack((q_, k_, v_), axis=2)
-    #     attn_output = (
-    #         qkv=qkv_,  # jnp.ndarray,
-    #         bias=fused_attn_bias,  # jnp.ndarray,
-    #         mask=fused_attn_mask,  # jnp.ndarray,
-    #         seed=prng,  # jnp.ndarray,
-    #         attn_bias_type=attn_bias_type,  # AttnBiasType,
-    #         attn_mask_type=attn_mask_type,  # AttnMaskType,
-    #         scaling_factor=scaling_factor,  # float,
-    #         dropout_probability=dropout,  # float,
-    #         is_training=is_training,  # bool,
-    #     )
-    # elif k_.shape == v_.shape:
-    #     kv_ = jnp.stack((k_, v_), axis=2)
-    #     attn_output = cross_fused_attn(
-    #         q=q_,
-    #         kv=kv_,
-    #         bias=fused_attn_bias,
-    #         mask=fused_attn_mask,
-    #         seed=prng,
-    #         attn_bias_type=attn_bias_type,
-    #         attn_mask_type=attn_mask_type,
-    #         scaling_factor=scaling_factor,
-    #         dropout_probability=dropout,
-    #         is_training=is_training,
-    #     )
-    # else:
     attn_output = fused_attn(
         q=q_,
         k=k_,
         v=v_,
         bias=fused_attn_bias,
         mask=fused_attn_mask,
         seed=prng,
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
         scaling_factor=scaling_factor,
         dropout_probability=dropout,
         is_training=is_training,
     )

@@ -279,27 +303,24 @@ def _te_flash_attention(
     attn_output = haliax.named(attn_output, ("B", "S", "H", "D"))
     # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
     # we can reshape it to match our expected output
-    attn_output = attn_output.unflatten_axis("B", q_class["B"])
-    attn_output = attn_output.unflatten_axis("S", q_class["S"])
-    attn_output = attn_output.unflatten_axis("H", q_class["H"])
-    attn_output = attn_output.unflatten_axis("D", v_class["D"])
-    output_axes = eqx.filter_eval_shape(
-        simple_attention_with_dropout,
-        QPos,
+    attn_output = _restore_named_axes(
+        attn_output,
         KPos,
         Key,
-        query,
-        key,
-        value,
-        mask,
+        QPos,
+        attention_dtype,
         bias,
-        inference,
         dropout,
-        attention_dtype,
+        inference,
+        key,
+        mask,
         precision,
-        prng=prng,
-    ).axes
-    attn_output = attn_output.rearrange(output_axes)
+        prng,
+        q_class,
+        query,
+        v_class,
+        value,
+    )
     return attn_output


@@ -338,12 +359,13 @@ def _te_materialize_mask(KPos, QPos, batch_size, mask):
 _DUMMY_BATCH = "__batch__"


-def _te_bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key):
+def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key):
     """
-    TransformerEngine's fused attention API is not as expressive as ours is, so we have to do some grouping
-    to make this work.
+    TE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes
+    of Q, K, and V into the right bins to match that format.

-    TE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed). the size of the axes is a bit flexible,
+    TE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD.
+    The size of the axes is a bit flexible,
     with the following conditions:
     - B must be the same for all (TODO: is this true?)
     - S must be the same for K and V. Q's S can be different
@@ -418,7 +440,7 @@ def _te_bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key):
     return q_class, k_class, v_class


-def _reshape_axes_for_te_bins(q, q_class):
+def _reshape_axes_for_bshd_bins(q, q_class, output_order=("B", "S", "H", "D")):
     """
     Reshape the axes of a qkv as BSHD to match the bins in q_class
     """
@@ -434,7 +456,7 @@ def _maybe_flatten(q, axes, name):
     q = _maybe_flatten(q, q_class["S"], "S")
     q = _maybe_flatten(q, q_class["H"], "H")
     q = _maybe_flatten(q, q_class["D"], "D")
-    q = q.rearrange(("B", "S", "H", "D"))
+    q = q.rearrange(output_order)

     return q

@@ -557,3 +579,204 @@ def materialize_mask(
 # TODO: padding mask
 # TODO: FCM mask?
 # TODO: sequence packing mask
+
+
+# CF https://github.com/google/maxtext/blob/db31dd4b0b686bca4cd7cf940917ec372faa183a/MaxText/layers/attentions.py#L179
+def _tpu_splash_attention(
+    QPos: AxisSelector,
+    KPos: AxisSelection,
+    Key: AxisSelector,
+    query: NamedArray,
+    key: NamedArray,
+    value: NamedArray,
+    mask: Optional[Union[NamedArray, "AttentionMask"]] = None,
+    bias: Optional[NamedArray] = None,
+    dropout: float = 0.0,
+    inference: bool = False,
+    *,
+    prng: Optional[PRNGKeyArray] = None,
+    attention_dtype: Optional[jnp.dtype] = None,
+    precision: PrecisionLike = None,
+    block_size: Optional[int] = None,
+) -> NamedArray:
+    # Splash attention requires BHSD format
+    # We need to reshape the input to match this format
+    if dropout != 0.0:
+        raise NotImplementedError("Splash attention does not support dropout")
+
+    if bias is not None:
+        raise NotImplementedError("Splash attention does not support bias")
+
+    q_class, k_class, v_class = _bin_and_group_axes_by_function(query, key, value, QPos, KPos, Key)
+
+    q_: jax.Array = _reshape_axes_for_bshd_bins(query, q_class, output_order=list("BHSD")).array
+    k_ = _reshape_axes_for_bshd_bins(key, k_class, output_order=list("BHSD")).array
+    v_ = _reshape_axes_for_bshd_bins(value, v_class, output_order=list("BHSD")).array
+
+    jax.debug.inspect_array_sharding(q_, callback=lambda sharding: print(f"q_: {sharding}"))
+    jax.debug.inspect_array_sharding(k_, callback=lambda sharding: print(f"k_: {sharding}"))
+    jax.debug.inspect_array_sharding(v_, callback=lambda sharding: print(f"v_: {sharding}"))
+
+    B, Hq, Sq, D = q_.shape
+    Bk, Hk, Sk, Dk = k_.shape
+
+    QPos = query.resolve_axis(QPos)
+    KPos = key.resolve_axis(KPos)
+
+    # TODO: must Dk == Dv?
+    if k_.shape != v_.shape:
+        raise ValueError("k and v must have the same axes")
+
+    # TODO: this isn't really necessary on TPU?
+    if B != Bk:
+        raise ValueError(f"Batch axes must be the same for q, k, and v: {q_class['B']} != {k_class['B']}")
+
+    if D != Dk:
+        raise ValueError(f"Embedding axes must be the same for q, k, and v: {q_class['D']} != {k_class['D']}")
+
+    def _physical_axis_for_binning(d):
+        b_out = tuple(ax for ax in pspec_for_axis(d["B"]) if ax is not None) or None
+        h_out = tuple(ax for ax in pspec_for_axis(d["H"]) if ax is not None) or None
+        s_out = tuple(ax for ax in pspec_for_axis(d["S"]) if ax is not None) or None
+        d_out = tuple(ax for ax in pspec_for_axis(d["D"]) if ax is not None) or None
+
+        return PartitionSpec(b_out, h_out, s_out, d_out)
+
+    # BHSD
+    physical_axes_q = _physical_axis_for_binning(q_class)
+    physical_axes_k = _physical_axis_for_binning(k_class)
+    physical_axes_v = _physical_axis_for_binning(v_class)
+
+    # MaxText uses a block size of 512
+    block_size = block_size or 512
+
+    # copied from MaxText
+    @functools.partial(
+        shard_map,
+        mesh=haliax.partitioning._get_mesh(),
+        in_specs=(
+            physical_axes_q,
+            physical_axes_k,
+            physical_axes_v,
+        ),
+        out_specs=physical_axes_q,
+        check_rep=False,
+    )
+    def wrap_flash_attention(q, k, v):
+        block_sizes = splash_attention_kernel.BlockSizes(
+            block_q=min(block_size, Sq),
+            block_kv_compute=min(block_size, Sk),
+            block_kv=min(block_size, Sk),
+            block_q_dkv=min(block_size, Sq),
+            block_kv_dkv=min(block_size, Sk),
+            block_kv_dkv_compute=min(block_size, Sq),
+            block_q_dq=min(block_size, Sq),
+            block_kv_dq=min(block_size, Sq),
+        )
+
+        if mask is None:
+            kernel_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
+        elif isinstance(mask, AttentionMask):
+            if mask.is_causal:
+                masks = [splash_attention_mask.CausalMask(shape=(Sq, Sq)) for i in range(Hq)]
+                kernel_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+            else:
+                kernel_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
+
+            if mask.explicit_mask is not None:
+                raise NotImplementedError("Explicit masks are not yet supported for splash attention")
+        elif isinstance(mask, NamedArray):
+            raise NotImplementedError("NamedArray masks are not yet supported for splash attention")
+        else:
+            raise ValueError(f"Unknown mask type: {mask}")
+
+        # copied from MaxText
+        splash_kernel = splash_attention_kernel.make_splash_mha(
+            mask=kernel_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
+        )
+
+        return jax.vmap(splash_kernel)(q, k, v, segment_ids=None)
+
+    attn_output = wrap_flash_attention(q_, k_, v_)
+
+    attn_output = haliax.named(attn_output, ("B", "H", "S", "D"))
+    attn_output = _restore_named_axes(
+        attn_output,
+        KPos,
+        Key,
+        QPos,
+        attention_dtype,
+        bias,
+        dropout,
+        inference,
+        key,
+        mask,
+        precision,
+        prng,
+        q_class,
+        query,
+        v_class,
+        value,
+    )
+
+    attn_output = haliax.shard(attn_output)
+
+    return attn_output
+
+
+def _restore_named_axes(
+    attn_output,
+    KPos,
+    Key,
+    QPos,
+    attention_dtype,
+    bias,
+    dropout,
+    inference,
+    key,
+    mask,
+    precision,
+    prng,
+    q_class,
+    query,
+    v_class,
+    value,
+):
+    # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
+    # we can reshape it to match our expected output
+    attn_output = attn_output.unflatten_axis("B", q_class["B"])
+    attn_output = attn_output.unflatten_axis("S", q_class["S"])
+    attn_output = attn_output.unflatten_axis("H", q_class["H"])
+    attn_output = attn_output.unflatten_axis("D", v_class["D"])
+    output_axes = eqx.filter_eval_shape(
+        simple_attention_with_dropout,
+        QPos,
+        KPos,
+        Key,
+        query,
+        key,
+        value,
+        mask,
+        bias,
+        inference,
+        dropout,
+        attention_dtype,
+        precision,
+        prng=prng,
+    ).axes
+    attn_output = attn_output.rearrange(output_axes)
+    return attn_output
+
+
+class ExplicitMask(eqx.Module):
+    """
+    Represents an explicit mask for attention. This is a mask that is applied to the attention weights directly.
+    """
+
+    mask: NamedArray
+
+    def __init__(self, mask: NamedArray):
+        self.mask = mask
+
+    def materialize(self, QPos: Axis, KPos: Axis) -> NamedArray:
+        return self.mask
diff --git a/tests/test_attention.py b/tests/test_attention.py
index 0d908595e..be664281b 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -3,7 +3,7 @@

 import haliax as hax

-from levanter.models.attention import AttentionMask, _te_bin_and_group_axes_by_function, _te_flash_attention
+from levanter.models.attention import AttentionMask, _bin_and_group_axes_by_function, _te_flash_attention
 from test_utils import skip_if_module_missing

@@ -55,7 +55,7 @@ def test_te_bin_and_group_axes_by_function():
     k = hax.zeros((B, KPos, H, D))
     v = hax.zeros((B, KPos, H, D))

-    q_c, k_c, v_c = _te_bin_and_group_axes_by_function(q, k, v, "QPos", "KPos", "D")
+    q_c, k_c, v_c = _bin_and_group_axes_by_function(q, k, v, "QPos", "KPos", "D")
     assert q_c["B"] == [B]
     assert k_c["B"] == [B]
     assert v_c["B"] == [B]
@@ -73,21 +73,21 @@ def test_te_bin_and_group_axes_by_function():
     assert v_c["D"] == [D]

     gq = hax.zeros((B, QPos, H, G, D))
-    q_c, k_c, v_c = _te_bin_and_group_axes_by_function(gq, k, v, "QPos", "KPos", "D")
+    q_c, k_c, v_c = _bin_and_group_axes_by_function(gq, k, v, "QPos", "KPos", "D")
     assert q_c["H"] == [H, G]
     assert k_c["H"] == [H]
     assert v_c["H"] == [H]

     gk = hax.zeros((B, KPos, G, H, D))
     with pytest.raises(ValueError):
-        _te_bin_and_group_axes_by_function(q, gk, v, "QPos", "KPos", "D")
+        _bin_and_group_axes_by_function(q, gk, v, "QPos", "KPos", "D")

     with pytest.raises(ValueError):
-        _te_bin_and_group_axes_by_function(gq, gk, v, "QPos", "KPos", "D")
+        _bin_and_group_axes_by_function(gq, gk, v, "QPos", "KPos", "D")

     for gk_axes in [(B, KPos, G, H, D), (B, KPos, G, H, D), (G, B, KPos, H, D)]:
         gk = hax.zeros(gk_axes)
-        q_c, k_c, v_c = _te_bin_and_group_axes_by_function(gq, gk, gk, "QPos", "KPos", "D")
+        q_c, k_c, v_c = _bin_and_group_axes_by_function(gq, gk, gk, "QPos", "KPos", "D")
         assert q_c["H"] == [H, G]
         assert k_c["H"] == [H, G]
         assert v_c["H"] == [H, G]
@@ -96,7 +96,7 @@ def test_te_bin_and_group_axes_by_function():
     gq = hax.zeros((G, B, QPos, H, D))
     for gk_axes in [(B, KPos, H, G, D), (B, KPos, G, H, D), (G, B, KPos, H, D)]:
         gk = hax.zeros(gk_axes)
-        q_c, k_c, v_c = _te_bin_and_group_axes_by_function(gq, gk, gk, "QPos", "KPos", "D")
+        q_c, k_c, v_c = _bin_and_group_axes_by_function(gq, gk, gk, "QPos", "KPos", "D")
         assert q_c["H"] == [H]
         assert k_c["H"] == [H]
         assert v_c["H"] == [H]
@@ -117,7 +117,7 @@ def test_mqa_te_bin_and_group_axes_by_function():
     k = hax.zeros((B, KPos, D))
     v = hax.zeros((B, KPos, D))

-    q_c, k_c, v_c = _te_bin_and_group_axes_by_function(gq, k, v, "QPos", "KPos", "D")
+    q_c, k_c, v_c = _bin_and_group_axes_by_function(gq, k, v, "QPos", "KPos", "D")
"KPos", "D") assert q_c["H"] == [G] assert k_c["H"] == [] assert v_c["H"] == [] @@ -125,7 +125,7 @@ def test_mqa_te_bin_and_group_axes_by_function(): gk = hax.zeros((B, KPos, G, D)) with pytest.raises(ValueError): # don't currently handle dim in Q and K but not V - _te_bin_and_group_axes_by_function(gq, gk, v, "QPos", "KPos", "D") + _bin_and_group_axes_by_function(gq, gk, v, "QPos", "KPos", "D") @skip_if_module_missing("transformer_engine") diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 2ee5a93ac..a79aa36fa 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -3,11 +3,13 @@ import equinox import jax.numpy as jnp import jax.random as jrandom +import jax.sharding import pytest import haliax as hax import haliax.nn as hnn +import levanter.models.attention from levanter.models.attention import AttentionMask, simple_attention_with_dropout from levanter.models.flash_attention import flash_attention @@ -141,3 +143,26 @@ def test_fa_dropout_does_something(): assert with_o.axes == without_o.axes mean = jnp.mean(jnp.isclose(with_o.array, without_o.array, atol=1e-3, rtol=1e-3)) assert mean < 1e-2 + + +def test_tpu_flash_attention(): + if jax.devices()[0].device_kind != "tpu": + pytest.skip("TPU-only test") + + Key = hax.Axis("Key", 128) + QPos = hax.Axis("QPos", BLOCK_SIZE * 4) + KPos = hax.Axis("KPos", BLOCK_SIZE * 4) + with jax.sharding.Mesh(jax.devices(), ("dp",)): + mask = AttentionMask.causal() + + q = hax.random.normal(jrandom.PRNGKey(0), (QPos, Key)) + k = hax.random.normal(jrandom.PRNGKey(1), (KPos, Key)) + v = hax.random.normal(jrandom.PRNGKey(2), (KPos, Key)) + + flash_out = levanter.models.attention._tpu_splash_attention( + QPos, KPos, Key, q, k, v, inference=True, mask=mask, block_size=BLOCK_SIZE + ) + hax_out = hnn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos)) + + assert hax_out.axes == flash_out.axes + assert jnp.allclose(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3)