Support device mapping for Paged Attention #1011
@@ -70,6 +70,34 @@ impl PagedAttention {
             input_metadata.slot_mappings.clone()
         };
 
+        // When device mapping, these tensors are fixed on the first device and must be moved to the same device as q, k, v:
+        // - slot_mapping
+        // - input_metadata.block_tables
+        // - input_metadata.context_lens
+        // - self.alibi_slopes
+        // - attention_mask
+        let slot_mapping = slot_mapping.to_device(query.device())?;
+        let block_tables = input_metadata
+            .block_tables
+            .as_ref()
+            .unwrap()
+            .to_device(query.device())?;
+        let context_lens = input_metadata
+            .context_lens
+            .as_ref()
+            .unwrap()
+            .to_device(query.device())?;
+        let alibi_slopes = if let Some(alibi_slopes) = self.alibi_slopes.as_ref() {
+            Some(alibi_slopes.to_device(query.device())?)
+        } else {
+            None
+        };
+        let attention_mask = if let Some(mask) = attention_mask {
+            Some(mask.to_device(query.device())?)
+        } else {
+            None
+        };
+
         let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?;
         let (_, key_value_heads, _, _) = key.shape().dims4()?;

Review comment (on the added device-mapping block above): I see you mentioned a performance penalty. To avoid this, can you please update [...]? This takes up more memory on each GPU but requires only one copy (in the inputs processor) and enables us to remove this section. I'm thinking something similar to this, where we create multiple RoPE instantiations on different devices. Additionally, I just merged #1014; can you please merge with master to get these new changes, otherwise a conflict will occur with the addition of the changes I requested above.
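The reviewer's suggestion trades a little extra memory on each GPU for doing the device copies once per step rather than once per attention call. Below is a minimal sketch of that idea, assuming candle's Tensor::to_device and Device::same_device; the PerDeviceTensor helper is hypothetical and not part of mistral.rs.

```rust
use candle_core::{Device, Result, Tensor};

/// Hypothetical helper (not part of mistral.rs): keep one copy of a metadata
/// tensor per mapped device so that device-mapped layers can look up a local
/// copy instead of calling `to_device` in every forward pass.
pub struct PerDeviceTensor {
    base: Tensor,
    copies: Vec<(Device, Tensor)>,
}

impl PerDeviceTensor {
    /// Copy `base` onto every device used by the pipeline, e.g. once in the
    /// inputs processor. This uses more memory on each GPU, but the copy
    /// happens once per step instead of once per attention call.
    pub fn new(base: Tensor, devices: &[Device]) -> Result<Self> {
        let mut copies = Vec::with_capacity(devices.len());
        for dev in devices {
            copies.push((dev.clone(), base.to_device(dev)?));
        }
        Ok(Self { base, copies })
    }

    /// Return the copy that already lives on `dev`; fall back to an on-demand
    /// copy (the behaviour of the block above) if the device was not prepared.
    pub fn get(&self, dev: &Device) -> Result<Tensor> {
        for (d, t) in &self.copies {
            if d.same_device(dev) {
                return Ok(t.clone());
            }
        }
        self.base.to_device(dev)
    }
}
```

With something like this populated in the inputs processor, the block added in the hunk above could shrink to a few get(query.device()) lookups.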
@@ -80,7 +108,7 @@ impl PagedAttention {
                 query,
                 key,
                 value,
-                Some(mask),
+                Some(&mask),
                 None,
                 &SdpaParams {
                     n_kv_groups: self.n_kv_groups,
@@ -92,7 +120,7 @@ impl PagedAttention {
             )?),
         };
 
-        // // paged-attn expects [batch_size, num_tokens, num_heads, head_size]
+        // paged-attn expects [batch_size, num_tokens, num_heads, head_size]
         let (query, key, value) = if seq_len > 1 {
             let q = query
                 .transpose(1, 2)?
@@ -105,7 +133,7 @@ impl PagedAttention {
                 .reshape(((), key_value_heads, head_size))?;
             (q, k, v)
         } else {
-            //avoid unnecessary transpose for decoding
+            // avoid unnecessary transpose for decoding
             let q = query.reshape(((), attention_heads, head_size))?;
             let k = key.reshape(((), key_value_heads, head_size))?;
             let v = value.reshape(((), key_value_heads, head_size))?;
@@ -131,7 +159,6 @@ impl PagedAttention {
             // Return result in prefill
             return Ok(att);
         }
-
         // Args:
         // output: shape = [num_generation_tokens, num_heads, head_size]
         //
@@ -147,18 +174,16 @@ impl PagedAttention {
         //
         // alibi_slopes: shape = [num_heads]
         #[allow(clippy::cast_possible_truncation)]
-        let res = paged_attention(
+        paged_attention(
             &query,
             key_cache.as_ref().unwrap(),
             value_cache.as_ref().unwrap(),
-            input_metadata.block_tables.as_ref().unwrap(),
-            input_metadata.context_lens.as_ref().unwrap(),
-            self.alibi_slopes.as_ref(),
+            &block_tables,
+            &context_lens,
+            alibi_slopes.as_ref(),
             input_metadata.max_context_len.unwrap(),
             self.scale,
             softcapping.unwrap_or(1.0f64) as f32,
-        )?;
-
-        Ok(res)
+        )
     }
 }
Review comment: Can we merge these into RotaryEmbedding in layers.rs?
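A sketch of that per-device pattern follows; the RotaryEmbedding here is a stand-in with made-up fields (cos/sin tables), not the real type from layers.rs, and DeviceMappedRope is likewise hypothetical.

```rust
use std::sync::Arc;

use candle_core::{Device, Result, Tensor};

// Stand-in for the rotary embedding layer in layers.rs; only the per-device
// bookkeeping is the point here, not the real fields or constructor.
pub struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
}

impl RotaryEmbedding {
    fn to_device(&self, dev: &Device) -> Result<Self> {
        Ok(Self {
            cos: self.cos.to_device(dev)?,
            sin: self.sin.to_device(dev)?,
        })
    }
}

/// One rotary-embedding instance per mapped device, created once at load time.
pub struct DeviceMappedRope {
    ropes: Vec<(Device, Arc<RotaryEmbedding>)>,
}

impl DeviceMappedRope {
    pub fn new(base: &RotaryEmbedding, devices: &[Device]) -> Result<Self> {
        let mut ropes = Vec::with_capacity(devices.len());
        for dev in devices {
            ropes.push((dev.clone(), Arc::new(base.to_device(dev)?)));
        }
        Ok(Self { ropes })
    }

    /// Pick the instance already living on `xs.device()`, so the per-layer
    /// forward pass never has to move the cos/sin tables across devices.
    pub fn for_tensor(&self, xs: &Tensor) -> Option<Arc<RotaryEmbedding>> {
        self.ropes
            .iter()
            .find(|(d, _)| d.same_device(xs.device()))
            .map(|(_, r)| r.clone())
    }
}
```

Each layer would then look up the instance matching its activations' device instead of moving tensors across devices at run time.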