ensure we cast to the right dtype in attention #571
Conversation
Changes to Whisper code look good to me.
Overall PR also looks good to me, though I haven't been hitting this issue in my training runs so don't have a reproduction case.
It’s a new regression from the TPU splash attention kernel. Still need to benchmark it at full scale.
Sure, I'll test it as soon as I can (OOO atm).
A few data points with a Llama2-7B run on V5lite:
- I tested with this branch,
- I tried adding back the line that skips the cast to `x.dtype` when `config.upcast_attn = False` (here), but it doesn't help. It still goes OOM: https://wandb.ai/stanford-mercury/markweb/runs/eo41llama40o42560505
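
For context, a rough JAX sketch of the cast being discussed. This is an illustrative stand-in, not the repo's actual attention code; `upcast_attn` mirrors the config flag named above, and `attention_scores` is a hypothetical helper.

```python
import math

import jax
import jax.numpy as jnp


def attention_scores(q, k, v, upcast_attn: bool = False):
    """Illustrative scaled dot-product attention with explicit dtype handling."""
    orig_dtype = q.dtype  # e.g. bfloat16 under mixed-precision training

    if upcast_attn:
        # optionally compute the scores / softmax in float32 for numerical stability
        q, k = q.astype(jnp.float32), k.astype(jnp.float32)

    scores = jnp.einsum("...qd,...kd->...qk", q, k) / math.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)

    # cast the attention weights back to the input dtype before the value
    # matmul, so downstream code sees the dtype it expects
    weights = weights.astype(orig_dtype)
    return jnp.einsum("...qk,...kd->...qd", weights, v)
```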
Training a Mistral 7B (seq len 2048, batch size 2048, parallelism 16), I can confirm
@Ivan-Zhou set `model.flash_attention_block_size` to null/None and see how it goes?
then try 512
@dlwh You are right. With a reduced FA block size, the OOM error is resolved. I can train Llama 7B with a 1024 FA block size, batch size up to 2048, and 4k seq length: https://wandb.ai/stanford-mercury/markweb/runs/eo44039204924kd030d20480507
awesome. can you try with it just set to null/unset? this is a new low-level kernel built into JAX and the default block size is 512
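
For anyone following along, the override under discussion would look roughly like this in a levanter-style YAML training config (the field name is taken from the thread; the exact file layout is an assumption):

```yaml
model:
  # null / unset lets the kernel fall back to its default block size
  flash_attention_block_size: null
  # or pin it explicitly, e.g.:
  # flash_attention_block_size: 512
```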
@dlwh I re-launched training jobs with the FA block size at null (pink) and 512 (blue). They have identical throughput, both lower than the 1024 run (green).
I think this fixes #569
@versae: can you give this branch (fix_attn_dtype) a try? I'm traveling and don't have time right now
cc @Helw150 since it tweaks your code.