-
Notifications
You must be signed in to change notification settings - Fork 333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support device mapping for Paged Attention #1011
Support device mapping for Paged Attention #1011
Conversation
Code Metrics Report=============================================================================== Language Files Lines Code Comments Blanks =============================================================================== C Header 2 35 28 0 7 Dockerfile 1 41 22 10 9 JSON 12 105 104 0 1 Python 63 2706 2338 71 297 Shell 1 57 22 18 17 Plain Text 3 3723 0 2413 1310 TOML 18 605 539 2 64 YAML 2 21 19 2 0 ------------------------------------------------------------------------------- Jupyter Notebooks 4 0 0 0 0 |- Markdown 2 77 32 31 14 |- Python 2 205 178 1 26 (Total) 282 210 32 40 ------------------------------------------------------------------------------- Markdown 43 3333 0 2526 807 |- BASH 6 103 100 0 3 |- JSON 1 12 12 0 0 |- Python 7 121 109 0 12 |- Rust 12 406 344 0 62 |- TOML 2 75 63 0 12 (Total) 4050 628 2526 896 ------------------------------------------------------------------------------- Rust 296 89600 80403 1861 7336 |- Markdown 143 1593 25 1448 120 (Total) 91193 80428 3309 7456 =============================================================================== Total 445 100226 83475 6903 9848 =============================================================================== |
Currently, device-mapped paged attention is approximately 10% slower compared to single device paged attention. I found the slowdown is at least partially due to the overhead of moving tensors to the device on every layer forward pass. I implemented a temporary workaround in the Ideally, we would avoid making copies of the tensors in the model's forward and instead perform this operation in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @cdoko! Thanks for the PR.
Currently, device-mapped paged attention is approximately 10% slower compared to single device paged attention. I found the slowdown is at least partially due to the overhead of moving tensors to the device on every layer forward pass.
With the comments I made, hopefully this should be adressed.
However, I have not yet found a way to pass the layer_devices information there.
Please feel free to make whatever changes you find necessary to get this to work!
I think a nice way to do this would be to add an API to the mapper (device_map.rs
) to extract all the devices which will be mapped to (inlcuding normal_loading_metadata.real_device
). One of my comments suggests a way to utilize this information, and I think this method would work nicely with that. What do you think?
@@ -70,6 +70,34 @@ impl PagedAttention { | |||
input_metadata.slot_mappings.clone() | |||
}; | |||
|
|||
// When device mapping, these Tensors are fixed on the first device, and must be moved to the same device as q,k,v |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see you mentioned a performance penalty:
Currently, device-mapped paged attention is approximately 10% slower compared to single device paged attention. I found the slowdown is at least partially due to the overhead of moving tensors to the device on every layer forward pass.
To avoid this, can you please update PagedAttentionInputMetadata
to store all tensors as hashmaps of device location to the actual tensor? Do you think this is a good solution?
This takes up more memory on each GPU but requires only one copy (in the inputs processor) and enables us to remove this section. I'm thinking something similar to this where we create multiple RoPE instantiations on different devices.
Additionally, I just merged #1014. Can you please merge with master to get these new changes, otherwise a conflict will occur with the addition of the changes I requested above.
mistralrs-core/src/models/gemma.rs
Outdated
@@ -284,6 +284,7 @@ impl Attention { | |||
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))? | |||
}; | |||
|
|||
let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we merge these into RotaryEmbedding
in layers.rs
?
I updated PagedAttentionInputMetadata to store tensors as hashmaps of device location to the actual tensor. The performance penalty has decreased, but there's still a remaining small penalty of a few percent; I even tested a similar optimization to the attention mask in the model forward pass, since it's also copied every layer, but it didn't give any further performance improvement. I suspect this is due to the unavoidable mapping of To implement the updated As a side note, I noticed that Qwen2 and Quantized Llama's RotaryEmbedding uses |
Thanks for the updates. I think this is close to merge.
Sounds great! I agree, TP is something we should look into. I think the hard part is integrating it nicely with the existing codebase.
Sounds good.
Yes, can you please update it to use the ones in mistralrs? |
I already did, just wanted to confirm.
Personally I'm interested in speculative decoding for the much higher T/s. I took a look at the If the PR is ok, I'll probably be working on the VRAM calculations for mistralrs-server next, because currently the flags like |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cdoko thanks for the PR!
Is resolving these cache issues the main blocker for getting speculative decoding working?
Yes, it's just that using the Normal cache isn't supported yet.
And what about with PA?
I think the main problem is that we need some extensive management of the KV cache (in particular, rolling back the cache) for PA, which I haven't implemented yet.
If the PR is ok, I'll probably be working on the VRAM calculations for mistralrs-server next, because currently the flags like --pa-gpu-mem assume single device and don't account for multi-device setups.
Sounds great!
Added support for device mapping in Paged Attention by passing the device list from the mapper to the cache engine. Manual moving was required for certain unmapped tensors in the paged attention forward pass. I have tested the device mapping support on several text models and it appears to be functional.
Memory allocation is currently calculated as if all memory is available on a single device, resulting in that memory being split across devices; ideally, we should calculate available memory per GPU. If this PR is fine, I will work on this next.
Additionally, this feature currently only supports GPU devices due to an error when attempting to use reshape_and_cache() with non-cuda tensors.
Please let me know if you'd like me to revise anything!