Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure we cast to the right dtype in attention #571

Merged
merged 3 commits into from
May 10, 2024
Merged

ensure we cast to the right dtype in attention #571

merged 3 commits into from
May 10, 2024

Conversation

dlwh
Copy link
Member

@dlwh dlwh commented May 1, 2024

I think fixes #569

@versae: can you give this branch (fix_attn_dtype) a try? I'm traveling and don't have time right now

cc @Helw150 since it tweaks your code.

Copy link
Collaborator

@Helw150 Helw150 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes to Whisper code look good to me.

Overall PR also looks good to me, though I haven't been hitting this issue in my training runs so don't have a reproduction case.

@dlwh
Copy link
Member Author

dlwh commented May 1, 2024 via email

@versae
Copy link
Contributor

versae commented May 2, 2024

Sure, I'll test it as soon as I can (OOO atm).

@Ivan-Zhou
Copy link
Contributor

Ivan-Zhou commented May 5, 2024

A few data points with Llama2-7B run on V5lite:

I tested with this branch, fix_attn_dtype, then the dtype error was avoided, but it kept running into OOM error, even if I reduce the batch size to 256:

I tried adding back the line to not cast to x.dtype when config.upcast_attn = False (here), but it doesn't help. It still goes OOM: https://wandb.ai/stanford-mercury/markweb/runs/eo41llama40o42560505.

@versae
Copy link
Contributor

versae commented May 7, 2024

Training a Mistral 7B (seq len 2048, batch size 2048, parallelism 16), I can confirm fix_attn_type seems to fix the dtype error. I haven't encountered any OOM in my limited experiments yet.

@dlwh
Copy link
Member Author

dlwh commented May 7, 2024

@Ivan-Zhou set model.flash_attention_block_size to null/None and see how it goes?

@dlwh
Copy link
Member Author

dlwh commented May 7, 2024

then try 512

@Ivan-Zhou
Copy link
Contributor

@dlwh You are right. With reduced FA block size, I am able to resolve this OOM error.

I am able to train Llama 7B with using 1024 FA block size and up to 2048 batch size and 4k seq length: https://wandb.ai/stanford-mercury/markweb/runs/eo44039204924kd030d20480507

@dlwh
Copy link
Member Author

dlwh commented May 8, 2024

awesome. can you try with it just set to null/unset? this is a new low level kernel built into JAX and the default block size is 512

@Ivan-Zhou
Copy link
Contributor

@dlwh I re-launched training jobs with FA block at null (pink) and 512 (blue). They have identical throughput, both less than 1024 (green):
image

@dlwh dlwh merged commit 2888a35 into main May 10, 2024
4 checks passed
@dlwh dlwh deleted the fix_attn_dtype branch May 10, 2024 03:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Scanning layers seems not to allow the mixing of precisions in mp anymore
4 participants