ensure we cast to the right dtype
dlwh committed May 1, 2024
1 parent 2516d06 commit 407d54b
Showing 5 changed files with 37 additions and 71 deletions.
92 changes: 30 additions & 62 deletions src/levanter/models/attention.py
@@ -305,24 +305,27 @@ def _te_flash_attention(
     attn_output = haliax.named(attn_output, ("B", "S", "H", "D"))
     # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
     # we can reshape it to match our expected output
-    attn_output = _restore_named_axes(
-        attn_output,
-        KPos,
-        Key,
-        QPos,
-        attention_dtype,
-        bias,
-        dropout,
-        inference,
-        key,
-        mask,
-        precision,
-        prng,
-        q_class,
-        query,
-        v_class,
-        value,
-    )
+    attn_output = _unflatten_bshd(attn_output, q_class, v_class)
+
+    reference_out_shape = eqx.filter_eval_shape(
+        simple_attention_with_dropout,
+        QPos,
+        KPos,
+        Key,
+        query,
+        key,
+        value,
+        mask,
+        bias,
+        inference,
+        dropout,
+        attention_dtype,
+        precision,
+        prng=prng,
+    )
+    attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)

     return attn_output

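The hunk above is the core of the fix: instead of a helper that both reshaped and cast, the kernel output is unflattened and then cast to whatever dtype the reference attention path would produce, discovered via eqx.filter_eval_shape without actually running the reference. A minimal standalone sketch of the pattern, using plain jnp arrays rather than haliax NamedArrays (the toy function, shapes, and dtypes here are illustrative only):

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    def toy_attention(q, k, v):
        # Scores are scaled by a float32 constant, so under jnp promotion
        # rules the output dtype can differ from the bfloat16 inputs.
        scale = jnp.sqrt(jnp.asarray(q.shape[-1], dtype=jnp.float32))
        scores = (q @ k.T) / scale
        return jax.nn.softmax(scores, axis=-1) @ v

    q = k = v = jnp.ones((8, 16), dtype=jnp.bfloat16)

    # Abstract trace only: returns a jax.ShapeDtypeStruct, no FLOPs spent.
    ref = eqx.filter_eval_shape(toy_attention, q, k, v)

    # A fused kernel's output can now be cast to match the reference path.
    kernel_out = jnp.zeros((8, 16), dtype=jnp.float32)  # stand-in for a kernel result
    kernel_out = kernel_out.astype(ref.dtype)
    assert kernel_out.shape == ref.shape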
@@ -462,6 +465,14 @@ def _maybe_flatten(q, axes, name):
     return q


+def _unflatten_bshd(attn_output, q_class, v_class):
+    attn_output = attn_output.unflatten_axis("B", q_class["B"])
+    attn_output = attn_output.unflatten_axis("S", q_class["S"])
+    attn_output = attn_output.unflatten_axis("H", q_class["H"])
+    attn_output = attn_output.unflatten_axis("D", v_class["D"])
+    return attn_output

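For context: unflatten_axis, used by the new helper, is haliax's inverse of flatten_axes: it splits one named axis back into the tuple of axes it was flattened from, which is how the kernel's flat B/S/H/D output recovers the original named shape. A small sketch with made-up axis names:

    import haliax as hax

    Batch = hax.Axis("batch", 2)
    Len = hax.Axis("len", 3)
    x = hax.zeros((Batch, Len))

    # Collapse ("batch", "len") into a single "B" axis of size 6 ...
    flat = x.flatten_axes(("batch", "len"), "B")

    # ... then split it back into the original axes.
    restored = flat.unflatten_axis("B", (Batch, Len))
    assert restored.axes == x.axes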

class AttentionMask(eqx.Module):
"""
@@ -673,9 +684,9 @@ def _tpu_splash_attention(
     k_ = _reshape_axes_for_bshd_bins(key, k_class, output_order=list("BHSD")).array
     v_ = _reshape_axes_for_bshd_bins(value, v_class, output_order=list("BHSD")).array

-    jax.debug.inspect_array_sharding(q_, callback=lambda sharding: print(f"q_: {sharding}"))
-    jax.debug.inspect_array_sharding(k_, callback=lambda sharding: print(f"k_: {sharding}"))
-    jax.debug.inspect_array_sharding(v_, callback=lambda sharding: print(f"v_: {sharding}"))
+    # jax.debug.inspect_array_sharding(q_, callback=lambda sharding: print(f"q_: {sharding}"))
+    # jax.debug.inspect_array_sharding(k_, callback=lambda sharding: print(f"k_: {sharding}"))
+    # jax.debug.inspect_array_sharding(v_, callback=lambda sharding: print(f"v_: {sharding}"))

B, Hq, Sq, D = q_.shape
Bk, Hk, Sk, Dk = k_.shape
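The calls being commented out above use jax.debug.inspect_array_sharding, which reports the sharding JAX resolves for an array inside a jitted computation. A minimal sketch of the usage, assuming a single device (the callback then just prints a trivial sharding):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        # The callback fires with the sharding chosen for x during compilation.
        jax.debug.inspect_array_sharding(x, callback=lambda s: print(f"x: {s}"))
        return x * 2

    f(jnp.ones((4, 4)))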
@@ -759,60 +770,14 @@ def wrap_flash_attention(q, k, v):
         q = q.astype(jnp.float32)
         k = k.astype(jnp.float32)
         v = v.astype(jnp.float32)
-        print(q.dtype, k.dtype, v.dtype)
         return jax.vmap(splash_kernel)(q, k, v, segment_ids=None)

     attn_output = wrap_flash_attention(q_, k_, v_)

     attn_output = haliax.named(attn_output, ("B", "H", "S", "D"))
-    attn_output = _restore_named_axes(
-        attn_output,
-        KPos,
-        Key,
-        QPos,
-        attention_dtype,
-        bias,
-        dropout,
-        inference,
-        key,
-        mask,
-        precision,
-        prng,
-        q_class,
-        query,
-        v_class,
-        value,
-    )
+    attn_output = _unflatten_bshd(attn_output, q_class, v_class)
+    reference_out_shape = eqx.filter_eval_shape(
+        simple_attention_with_dropout,
+        QPos,
+        KPos,
+        Key,
+        query,
+        key,
+        value,
+        mask,
+        bias,
+        inference,
+        dropout,
+        attention_dtype,
+        precision,
+        prng=prng,
+    )
+    attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)

     attn_output = haliax.shard(attn_output)

     return attn_output
-
-
-def _restore_named_axes(
-    attn_output,
-    KPos,
-    Key,
-    QPos,
-    attention_dtype,
-    bias,
-    dropout,
-    inference,
-    key,
-    mask,
-    precision,
-    prng,
-    q_class,
-    query,
-    v_class,
-    value,
-):
-    # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
-    # we can reshape it to match our expected output
-    attn_output = attn_output.unflatten_axis("B", q_class["B"])
-    attn_output = attn_output.unflatten_axis("S", q_class["S"])
-    attn_output = attn_output.unflatten_axis("H", q_class["H"])
-    attn_output = attn_output.unflatten_axis("D", v_class["D"])
-    reference_out_shape = eqx.filter_eval_shape(
-        simple_attention_with_dropout,
-        QPos,
-        KPos,
-        Key,
-        query,
-        key,
-        value,
-        mask,
-        bias,
-        inference,
-        dropout,
-        attention_dtype,
-        precision,
-        prng=prng,
-    )
-    attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)
-
-    return attn_output
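Two things interact in this hunk: wrap_flash_attention upcasts q/k/v to float32 for the TPU splash kernel (dropping the stray debug print), and the eval-shape cast afterwards is what returns the output to the dtype the rest of the model expects. A generic sketch of that upcast-compute-downcast pattern, with a toy kernel standing in for the real one:

    import jax.numpy as jnp

    def run_kernel_in_f32(q, k, v, kernel):
        # Some fused kernels require (or are only accurate in) float32 inputs.
        out = kernel(q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32))
        # Cast back so downstream layers see the dtype they expect.
        return out.astype(q.dtype)

    toy_kernel = lambda q, k, v: jnp.einsum("sd,td->st", q, k) @ v
    q = k = v = jnp.ones((4, 8), dtype=jnp.bfloat16)
    out = run_kernel_in_f32(q, k, v, toy_kernel)
    assert out.dtype == jnp.bfloat16

Levanter derives the target dtype from the eval-shaped reference rather than from q directly, which also respects an explicit attention_dtype upcast.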
5 changes: 2 additions & 3 deletions src/levanter/models/gpt2.py
@@ -195,10 +195,9 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la
             prng=k_drop,
             attention_dtype=jnp.float32 if self.config.upcast_attn else None,
         )
-        attn_output = self.c_proj(attn_output, key=k_out)
-
-        if self.config.upcast_attn:
-            attn_output = attn_output.astype(x.dtype)
+        attn_output = attn_output.astype(x.dtype)
+        attn_output = self.c_proj(attn_output, key=k_out)

         return attn_output

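This gpt2.py change (mirrored in whisper.py, and before o_proj in llama.py) makes the cast back to x.dtype unconditional and moves it before the output projection, so the projection matmul runs in the model's dtype rather than a possibly float32-upcast one. A hedged sketch of the shape of the fix; attend and project are stand-ins, not Levanter APIs:

    import jax.numpy as jnp

    def attention_block(x, attend, project):
        attn_output = attend(x)  # may come back float32 (upcast_attn, f32 kernels, ...)
        # Cast before projecting: the matmul below then runs in x.dtype, and the
        # block's output dtype no longer depends on which attention backend ran.
        attn_output = attn_output.astype(x.dtype)
        return project(attn_output)

mpt.py below casts after its out_proj instead, targeting hidden_states.dtype; either way each block now returns the residual-stream dtype unconditionally rather than only when upcast_attn is set.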
4 changes: 1 addition & 3 deletions src/levanter/models/llama.py
@@ -274,9 +274,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
         )

         attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads")
-
-        if self.config.upcast_attn:
-            attn_output = attn_output.astype(x.dtype)
+        attn_output = attn_output.astype(x.dtype)

         attn_output = self.o_proj(attn_output, key=key_o)
         return attn_output
2 changes: 2 additions & 0 deletions src/levanter/models/mpt.py
@@ -283,6 +283,8 @@ def __call__(
         )

         attn_output = self.out_proj(attn_output, key=k_out)
+        attn_output = attn_output.astype(hidden_states.dtype)
+
         return attn_output

def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
5 changes: 2 additions & 3 deletions src/levanter/models/whisper.py
@@ -224,10 +224,9 @@ def __call__(self, x: NamedArray, xa: Optional[NamedArray] = None, mask: Optiona
             prng=k_drop,
             attention_dtype=jnp.float32 if self.config.upcast_attn else None,
         )
-        attn_output = self.out_proj(attn_output, key=k_out)
-
-        if self.config.upcast_attn:
-            attn_output = attn_output.astype(x.dtype)
+        attn_output = attn_output.astype(x.dtype)
+        attn_output = self.out_proj(attn_output, key=k_out)

         return attn_output
