diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index 2b1fa92db..56812ef97 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -689,10 +689,6 @@ def _tpu_splash_attention(
     k_ = _reshape_axes_for_bshd_bins(key, k_class, output_order=list("BHSD")).array
     v_ = _reshape_axes_for_bshd_bins(value, v_class, output_order=list("BHSD")).array
 
-    # jax.debug.inspect_array_sharding(q_, callback=lambda sharding: print(f"q_: {sharding}"))
-    # jax.debug.inspect_array_sharding(k_, callback=lambda sharding: print(f"k_: {sharding}"))
-    # jax.debug.inspect_array_sharding(v_, callback=lambda sharding: print(f"v_: {sharding}"))
-
     B, Hq, Sq, D = q_.shape
     Bk, Hk, Sk, Dk = k_.shape
 
@@ -798,7 +794,6 @@ def wrap_flash_attention(q, k, v):
         precision,
         prng=prng,
     )
-    print(reference_out_shape.dtype, attn_output.dtype)
     attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)
 
     attn_output = haliax.shard(attn_output)
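
Aside on the deleted debug lines (a sketch, not part of the change): the first hunk drops commented-out calls to jax.debug.inspect_array_sharding, which invokes a callback with the sharding of an array as resolved inside jit. If that sharding check is ever needed again, it can be reproduced standalone roughly as below; the mesh axis name, function name, and array shape are made up for illustration and are not taken from levanter.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-axis mesh over all available devices; "data" is an
# illustrative axis name, not one from levanter's config.
mesh = Mesh(np.array(jax.devices()), ("data",))


@jax.jit
def inspect_q_sharding(q):
    # Mirrors the deleted debug line: print the sharding the compiler
    # chose for q inside the jitted computation.
    jax.debug.inspect_array_sharding(q, callback=lambda sharding: print(f"q_: {sharding}"))
    return q


# Shard the batch dimension of a dummy (B, H, S, D) array and run the check.
q = jax.device_put(jnp.ones((8, 4, 128, 64)), NamedSharding(mesh, P("data")))
inspect_q_sharding(q)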