From 407d54b63e141d6261754986453bb1ffd1c8afb7 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Wed, 1 May 2024 12:00:24 -0700
Subject: [PATCH] ensure we cast to the right dtype

---
 src/levanter/models/attention.py | 92 +++++++++++---------------------
 src/levanter/models/gpt2.py      |  5 +-
 src/levanter/models/llama.py     |  4 +-
 src/levanter/models/mpt.py       |  2 +
 src/levanter/models/whisper.py   |  5 +-
 5 files changed, 37 insertions(+), 71 deletions(-)

diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index 3251b47e3..c610c5bf9 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -305,24 +305,27 @@ def _te_flash_attention(
     attn_output = haliax.named(attn_output, ("B", "S", "H", "D"))
     # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
     # we can reshape it to match our expected output
-    attn_output = _restore_named_axes(
-        attn_output,
+    # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
+    # we can reshape it to match our expected output
+    attn_output = _unflatten_bshd(attn_output, q_class, v_class)
+
+    reference_out_shape = eqx.filter_eval_shape(
+        simple_attention_with_dropout,
+        QPos,
+        KPos,
         Key,
-        QPos,
-        attention_dtype,
-        bias,
-        dropout,
-        inference,
+        query,
         key,
+        value,
         mask,
+        bias,
+        inference,
+        dropout,
+        attention_dtype,
         precision,
-        prng,
-        q_class,
-        query,
-        v_class,
-        value,
+        prng=prng,
     )
+    attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)

     return attn_output

@@ -462,6 +465,14 @@ def _maybe_flatten(q, axes, name):
     return q


+def _unflatten_bshd(attn_output, q_class, v_class):
+    attn_output = attn_output.unflatten_axis("B", q_class["B"])
+    attn_output = attn_output.unflatten_axis("S", q_class["S"])
+    attn_output = attn_output.unflatten_axis("H", q_class["H"])
+    attn_output = attn_output.unflatten_axis("D", v_class["D"])
+    return attn_output
+
+
 class AttentionMask(eqx.Module):
     """
@@ -673,9 +684,9 @@ def _tpu_splash_attention(
     k_ = _reshape_axes_for_bshd_bins(key, k_class, output_order=list("BHSD")).array
     v_ = _reshape_axes_for_bshd_bins(value, v_class, output_order=list("BHSD")).array

-    jax.debug.inspect_array_sharding(q_, callback=lambda sharding: print(f"q_: {sharding}"))
-    jax.debug.inspect_array_sharding(k_, callback=lambda sharding: print(f"k_: {sharding}"))
-    jax.debug.inspect_array_sharding(v_, callback=lambda sharding: print(f"v_: {sharding}"))
+    # jax.debug.inspect_array_sharding(q_, callback=lambda sharding: print(f"q_: {sharding}"))
+    # jax.debug.inspect_array_sharding(k_, callback=lambda sharding: print(f"k_: {sharding}"))
+    # jax.debug.inspect_array_sharding(v_, callback=lambda sharding: print(f"v_: {sharding}"))

     B, Hq, Sq, D = q_.shape
     Bk, Hk, Sk, Dk = k_.shape
@@ -759,60 +770,14 @@ def wrap_flash_attention(q, k, v):
             q = q.astype(jnp.float32)
             k = k.astype(jnp.float32)
             v = v.astype(jnp.float32)
-            print(q.dtype, k.dtype, v.dtype)
         return jax.vmap(splash_kernel)(q, k, v, segment_ids=None)

     attn_output = wrap_flash_attention(q_, k_, v_)

     attn_output = haliax.named(attn_output, ("B", "H", "S", "D"))
-    attn_output = _restore_named_axes(
-        attn_output,
-        KPos,
-        Key,
-        QPos,
-        attention_dtype,
-        bias,
-        dropout,
-        inference,
-        key,
-        mask,
-        precision,
-        prng,
-        q_class,
-        query,
-        v_class,
-        value,
-    )
-
-    attn_output = haliax.shard(attn_output)
-
-    return attn_output
-
-
-def _restore_named_axes(
-    attn_output,
-    KPos,
-    Key,
-    QPos,
-    attention_dtype,
-    bias,
-    dropout,
-    inference,
-    key,
-    mask,
-    precision,
-    prng,
-    q_class,
-    query,
-    v_class,
-    value,
-):
     # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
     # we can reshape it to match our expected output
-    attn_output = attn_output.unflatten_axis("B", q_class["B"])
-    attn_output = attn_output.unflatten_axis("S", q_class["S"])
-    attn_output = attn_output.unflatten_axis("H", q_class["H"])
-    attn_output = attn_output.unflatten_axis("D", v_class["D"])
+    attn_output = _unflatten_bshd(attn_output, q_class, v_class)
     reference_out_shape = eqx.filter_eval_shape(
         simple_attention_with_dropout,
         QPos,
@@ -830,4 +795,7 @@ def _restore_named_axes(
         prng=prng,
     )
     attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)
+
+    attn_output = haliax.shard(attn_output)
+
     return attn_output
diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py
index 17abd68ae..191ac689d 100644
--- a/src/levanter/models/gpt2.py
+++ b/src/levanter/models/gpt2.py
@@ -195,10 +195,9 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la
             prng=k_drop,
             attention_dtype=jnp.float32 if self.config.upcast_attn else None,
         )
-        attn_output = self.c_proj(attn_output, key=k_out)

-        if self.config.upcast_attn:
-            attn_output = attn_output.astype(x.dtype)
+        attn_output = attn_output.astype(x.dtype)
+        attn_output = self.c_proj(attn_output, key=k_out)

         return attn_output
diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index ff76039d9..6d74241c5 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -274,9 +274,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
         )
         attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads")
-
-        if self.config.upcast_attn:
-            attn_output = attn_output.astype(x.dtype)
+        attn_output = attn_output.astype(x.dtype)

         attn_output = self.o_proj(attn_output, key=key_o)
         return attn_output
diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py
index 9c31e63b6..84c9bd5d8 100644
--- a/src/levanter/models/mpt.py
+++ b/src/levanter/models/mpt.py
@@ -283,6 +283,8 @@ def __call__(
         )

         attn_output = self.out_proj(attn_output, key=k_out)
+        attn_output = attn_output.astype(hidden_states.dtype)
+
         return attn_output

     def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
diff --git a/src/levanter/models/whisper.py b/src/levanter/models/whisper.py
index b0ac6941b..ff725151b 100644
--- a/src/levanter/models/whisper.py
+++ b/src/levanter/models/whisper.py
@@ -224,10 +224,9 @@ def __call__(self, x: NamedArray, xa: Optional[NamedArray] = None, mask: Optiona
             prng=k_drop,
             attention_dtype=jnp.float32 if self.config.upcast_attn else None,
         )
-        attn_output = self.out_proj(attn_output, key=k_out)

-        if self.config.upcast_attn:
-            attn_output = attn_output.astype(x.dtype)
+        attn_output = attn_output.astype(x.dtype)
+        attn_output = self.out_proj(attn_output, key=k_out)

         return attn_output
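
Below is a minimal, illustrative sketch (plain jax.numpy, not Levanter's haliax/equinox NamedArray code) of the dtype rule this patch enforces: attention math may be upcast to float32 internally, but the result is cast back to the caller's dtype before the output projection. The names attention_block and wq/wk/wv/wo are hypothetical and are not Levanter APIs.

import jax
import jax.numpy as jnp


def attention_block(x, wq, wk, wv, wo, *, upcast_attn: bool = True):
    # x: [seq, d_model] activations, typically bfloat16 during mixed-precision training
    q, k, v = x @ wq, x @ wk, x @ wv

    if upcast_attn:
        # do the score/softmax math in float32 for numerical stability
        q, k, v = (t.astype(jnp.float32) for t in (q, k, v))

    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    attn = jax.nn.softmax(scores, axis=-1) @ v

    # cast back unconditionally, so the output projection and residual stream
    # see the input dtype whether or not the attention math was upcast
    attn = attn.astype(x.dtype)
    return attn @ wo

The attention.py hunks apply the same idea to the fused kernels: they evaluate the reference simple_attention_with_dropout abstractly via eqx.filter_eval_shape and then rearrange/astype the kernel output to the reference's axes and dtype, so the TE and splash-attention paths cannot leak a widened dtype to callers.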