Support device mapping for Paged Attention #1011

Merged — 55 commits merged into master from device-mapping-paged-attn on Jan 1, 2025

Changes from 19 commits

Commits
acea1fd
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
6032955
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
946dfd9
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
8a134b9
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
6f02469
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
86d026a
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
d54e767
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
819278c
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
007f2db
Move start_offsets_kernel to correct device
cdoko Dec 28, 2024
e7d2d80
Update starcoder2.rs
cdoko Dec 28, 2024
db0cdc5
Support device mapping
cdoko Dec 28, 2024
05ed5fe
Support device mapping
cdoko Dec 28, 2024
937319b
Support device mapping
cdoko Dec 28, 2024
047ca07
Support device mapping
cdoko Dec 28, 2024
882f4e7
Support device mapping
cdoko Dec 28, 2024
a51bca6
format
cdoko Dec 28, 2024
db78205
Support device mapping
cdoko Dec 28, 2024
e6324b4
remove mut
cdoko Dec 28, 2024
8fadbbc
remove mut
cdoko Dec 28, 2024
9d9918d
Merge branch 'master' into device-mapping-paged-attn
cdoko Dec 31, 2024
895d0a9
Add get_unique_devices method
cdoko Dec 31, 2024
7d46900
Move tensor for device mapping
cdoko Dec 31, 2024
aa90ef2
Add DeviceMapper
cdoko Dec 31, 2024
8fc40fc
Fix wrong RotaryEmbedding import
cdoko Dec 31, 2024
ad66e29
Fix wrong RotaryEmbedding import
cdoko Dec 31, 2024
e0719f9
Remove unecessary tensor copies
cdoko Dec 31, 2024
cffeaaa
Add DeviceMapper
cdoko Dec 31, 2024
f0f3ac1
Add DeviceMapper
cdoko Dec 31, 2024
e935067
Add DeviceMapper
cdoko Dec 31, 2024
a833acf
Add device mapping
cdoko Dec 31, 2024
efbd6f4
Create tensor copies for each device for pa
cdoko Dec 31, 2024
8a0177a
Add device mapper
cdoko Dec 31, 2024
b614be9
Add device mapper
cdoko Dec 31, 2024
30618da
Add device mapper
cdoko Dec 31, 2024
0215e86
Add device mapper
cdoko Dec 31, 2024
44e0559
Add device mapper
cdoko Dec 31, 2024
095e28a
Add device mapper
cdoko Dec 31, 2024
80eb294
Add device mapper
cdoko Dec 31, 2024
ef7ee66
Add device mapper
cdoko Dec 31, 2024
f269c55
Add device mapper
cdoko Dec 31, 2024
587b4f7
Add device mapper
cdoko Dec 31, 2024
36d89c9
add device mapper
cdoko Dec 31, 2024
17f8065
Remove unecessary tensor move
cdoko Dec 31, 2024
3ca105a
Remove unecessary tensor move
cdoko Dec 31, 2024
40706f2
Remove unecessary tensor move
cdoko Dec 31, 2024
62d2126
Remove unecessary tensor move
cdoko Dec 31, 2024
78189c9
Remove unecessary tensor move
cdoko Dec 31, 2024
6724ee1
Remove unecessary tensor move
cdoko Dec 31, 2024
6ca0625
Remove unecessary tensor move
cdoko Dec 31, 2024
d3b4dae
Remove unecessary tensor move
cdoko Dec 31, 2024
ae3f53e
format
cdoko Dec 31, 2024
3bf680d
format
cdoko Dec 31, 2024
7560df9
format
cdoko Dec 31, 2024
83cf77d
clippy
cdoko Dec 31, 2024
45aad07
format
cdoko Dec 31, 2024
1 change: 1 addition & 0 deletions mistralrs-core/src/dummy_paged_attention/cache_engine.rs
@@ -26,6 +26,7 @@ impl CacheEngine {
_cache_config: &CacheConfig,
_dtype: DType,
_device: &Device,
_layer_devices: Vec<Option<Device>>,
) -> Result<Self> {
Ok(Self {
dummy_cache: Arc::new(Mutex::new(Vec::new())),
1 change: 1 addition & 0 deletions mistralrs-core/src/models/gemma.rs
@@ -284,6 +284,7 @@ impl Attention {
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
EricLBuehler (Owner) commented:
Can we merge these into RotaryEmbedding in layers.rs?

self.rotary_emb
.forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;
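A minimal sketch of what that merge could look like, folding the device move into RotaryEmbedding::forward in layers.rs so each model no longer needs its own to_device call. The signature mirrors the call sites in this diff; the parameter types and the apply_rotary helper are assumptions, not the crate's actual code:

impl RotaryEmbedding {
    pub fn forward(
        &self,
        seqlen_offsets: &[usize],
        start_offsets_kernel: &Tensor,
        q: &mut Tensor,
        k: &mut Tensor,
        b_sz: usize,
    ) -> candle_core::Result<()> {
        // Under device mapping, the offsets tensor may live on another device;
        // move it once here instead of at every model's call site.
        let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
        // Stand-in for the existing rotary application body.
        self.apply_rotary(seqlen_offsets, &start_offsets_kernel, q, k, b_sz)
    }
}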

1 change: 1 addition & 0 deletions mistralrs-core/src/models/gemma2.rs
@@ -292,6 +292,7 @@ impl Attention {
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary_emb
.forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

1 change: 1 addition & 0 deletions mistralrs-core/src/models/mistral.rs
@@ -261,6 +261,7 @@ impl Attention {
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary_emb
.forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

1 change: 1 addition & 0 deletions mistralrs-core/src/models/mixtral.rs
@@ -156,6 +156,7 @@ impl Attention {
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary_emb
.forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

1 change: 1 addition & 0 deletions mistralrs-core/src/models/phi2.rs
@@ -275,6 +275,7 @@ impl Attention {
v.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary_emb.forward(
seqlen_offsets,
&start_offsets_kernel,
1 change: 1 addition & 0 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -163,6 +163,7 @@ impl LayerWeights {
v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary
.forward(start_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

1 change: 1 addition & 0 deletions mistralrs-core/src/models/quantized_qwen2.rs
@@ -79,6 +79,7 @@ impl LayerWeights {
v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary
.forward(start_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

1 change: 1 addition & 0 deletions mistralrs-core/src/models/quantized_starcoder2.rs
@@ -87,6 +87,7 @@ impl LayerWeights {
v.reshape((b_sz, self.n_kv_head, q_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary_emb
.forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

1 change: 1 addition & 0 deletions mistralrs-core/src/models/qwen2.rs
@@ -253,6 +253,7 @@ impl Attention {
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary_emb
.forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

1 change: 1 addition & 0 deletions mistralrs-core/src/models/starcoder2.rs
@@ -246,6 +246,7 @@ impl Attention {
v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?
};

let start_offsets_kernel = start_offsets_kernel.to_device(q.device())?;
self.rotary_emb
.forward(seqlen_offsets, &start_offsets_kernel, &mut q, &mut k, b_sz)?;

7 changes: 6 additions & 1 deletion mistralrs-core/src/paged_attention/cache_engine.rs
@@ -29,13 +29,15 @@ impl CacheEngine {
cache_config: &CacheConfig,
dtype: DType,
device: &Device,
layer_devices: Vec<Option<Device>>,
) -> Result<Self> {
Ok(Self {
gpu_cache: Arc::new(Mutex::new(Self::allocate_gpu_cache(
model_config,
cache_config,
dtype,
device,
layer_devices,
)?)),
cpu_cache: Self::allocate_cpu_cache(model_config, cache_config, dtype, device)?,
num_layers: model_config.num_layers(),
@@ -55,13 +57,16 @@ impl CacheEngine {
cache_config: &CacheConfig,
dtype: DType,
device: &Device,
layer_devices: Vec<Option<Device>>,
) -> Result<Vec<KVCache>> {
let key_block_shape =
Self::calculate_key_block_shape(model_config, dtype, cache_config.block_size);
let value_block_shape =
Self::calculate_value_block_shape(model_config, cache_config.block_size);
let mut gpu_cache = Vec::new();
for _ in 0..model_config.num_layers() {

for i in 0..model_config.num_layers() {
let device = layer_devices[i].as_ref().unwrap_or(device);
let key_blocks = Tensor::zeros(
(
cache_config.num_gpu_blocks,
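For reference, callers of the updated CacheEngine::new build the per-layer device list from the device mapper and thread it through. A condensed sketch, mirroring the pipeline changes later in this diff (the names come from that code):

let mut layer_devices = Vec::new();
for layer in 0..num_hidden_layers {
    // None falls back to the engine's default device via the unwrap_or above.
    layer_devices.push(mapper.device_for(layer, false).cloned());
}
let cache_engine = CacheEngine::new(model_config, &cache_config, dtype, device, layer_devices)?;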
47 changes: 36 additions & 11 deletions mistralrs-core/src/paged_attention/layers/paged_attention.rs
@@ -70,6 +70,34 @@ impl PagedAttention {
input_metadata.slot_mappings.clone()
};

// When device mapping is used, these tensors are fixed on the first device and must be moved to the same device as q, k, and v:
EricLBuehler (Owner) commented:
I see you mentioned a performance penalty:

Currently, device-mapped paged attention is approximately 10% slower compared to single device paged attention. I found the slowdown is at least partially due to the overhead of moving tensors to the device on every layer forward pass.

To avoid this, can you please update PagedAttentionInputMetadata to store all tensors as hashmaps of device location to the actual tensor? Do you think this is a good solution?

This takes up more memory on each GPU but requires only one copy (in the inputs processor) and enables us to remove this section. I'm thinking something similar to this where we create multiple RoPE instantiations on different devices.

Additionally, I just merged #1014. Can you please merge with master to get these new changes, otherwise a conflict will occur with the addition of the changes I requested above.

// - slot_mapping
// - input_metadata.block_tables
// - input_metadata.context_lens
// - self.alibi_slopes
// - attention_mask
let slot_mapping = slot_mapping.to_device(query.device())?;
let block_tables = input_metadata
.block_tables
.as_ref()
.unwrap()
.to_device(query.device())?;
let context_lens = input_metadata
.context_lens
.as_ref()
.unwrap()
.to_device(query.device())?;
let alibi_slopes = if let Some(alibi_slopes) = self.alibi_slopes.as_ref() {
Some(alibi_slopes.to_device(query.device())?)
} else {
None
};
let attention_mask = if let Some(mask) = attention_mask {
Some(mask.to_device(query.device())?)
} else {
None
};

let (batch_size, attention_heads, seq_len, head_size) = query.shape().dims4()?;
let (_, key_value_heads, _, _) = key.shape().dims4()?;

@@ -80,7 +108,7 @@
query,
key,
value,
Some(mask),
Some(&mask),
None,
&SdpaParams {
n_kv_groups: self.n_kv_groups,
@@ -92,7 +120,7 @@
)?),
};

// // paged-attn expects [batch_size, num_tokens, num_heads, head_size]
// paged-attn expects [batch_size, num_tokens, num_heads, head_size]
let (query, key, value) = if seq_len > 1 {
let q = query
.transpose(1, 2)?
@@ -105,7 +133,7 @@
.reshape(((), key_value_heads, head_size))?;
(q, k, v)
} else {
//avoid unnecessary transpose for decoding
// avoid unnecessary transpose for decoding
let q = query.reshape(((), attention_heads, head_size))?;
let k = key.reshape(((), key_value_heads, head_size))?;
let v = value.reshape(((), key_value_heads, head_size))?;
@@ -131,7 +159,6 @@
// Return result in prefill
return Ok(att);
}

// Args:
// output: shape = [num_generation_tokens, num_heads, head_size]
//
@@ -147,18 +174,16 @@
//
// alibi_slopes: shape = [num_heads]
#[allow(clippy::cast_possible_truncation)]
let res = paged_attention(
paged_attention(
&query,
key_cache.as_ref().unwrap(),
value_cache.as_ref().unwrap(),
input_metadata.block_tables.as_ref().unwrap(),
input_metadata.context_lens.as_ref().unwrap(),
self.alibi_slopes.as_ref(),
&block_tables,
&context_lens,
alibi_slopes.as_ref(),
input_metadata.max_context_len.unwrap(),
self.scale,
softcapping.unwrap_or(1.0f64) as f32,
)?;

Ok(res)
)
}
}
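A minimal sketch of the restructuring proposed in the review comment above: store one resident copy of each metadata tensor per device, built once in the inputs processor, so the per-layer forward does a HashMap lookup instead of a device copy. This assumes candle's Device::location() and its hashable DeviceLocation; the struct fields and the replicate helper are illustrative, not the crate's actual API:

use std::collections::HashMap;
use candle_core::{Device, DeviceLocation, Result, Tensor};

// One copy of each metadata tensor per device location.
pub struct PagedAttentionInputMetadata {
    pub slot_mappings: HashMap<DeviceLocation, Tensor>,
    pub block_tables: HashMap<DeviceLocation, Tensor>,
    pub context_lens: HashMap<DeviceLocation, Tensor>,
}

// Built once, in the inputs processor, for every unique mapped device.
fn replicate(t: &Tensor, devices: &[Device]) -> Result<HashMap<DeviceLocation, Tensor>> {
    let mut copies = HashMap::new();
    for dev in devices {
        copies.insert(dev.location(), t.to_device(dev)?);
    }
    Ok(copies)
}

// The per-layer forward pass then becomes a lookup:
// let block_tables = &input_metadata.block_tables[&query.device().location()];

This trades some extra memory on each GPU for eliminating the per-layer to_device calls blamed for the roughly 10% slowdown.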
46 changes: 31 additions & 15 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -332,7 +332,7 @@ impl Loader for GGUFLoader {
silent: bool,
mapper: DeviceMapMetadata,
in_situ_quant: Option<IsqType>,
mut paged_attn_config: Option<PagedAttentionConfig>,
paged_attn_config: Option<PagedAttentionConfig>,
) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
if in_situ_quant.is_some() {
anyhow::bail!(
@@ -353,10 +353,11 @@
self.get_id(),
device.device_pretty_repr()
);
} else if paged_attn_config.is_some() {
warn!("Device mapping or device topology and PagedAttention are incompatible, disabling PagedAttention.");
paged_attn_config = None;
}
// } else if paged_attn_config.is_some() {
//     warn!("Device mapping or device topology and PagedAttention are incompatible, disabling PagedAttention.");
//     paged_attn_config = None;
// }

let mut readers = Vec::new();
for filename in paths.get_weight_filenames() {
@@ -408,7 +409,7 @@
// Base config (quantization only):
let quant = ModelConfig::ParamsGGUF(
model,
(device, mapper, self.config.topology.as_ref()).into(),
(device, mapper.clone(), self.config.topology.as_ref()).into(),
if paged_attn_config.is_some() {
AttentionImplementation::PagedAttention
} else {
@@ -453,6 +454,24 @@
_ => unreachable!(),
};

let num_hidden_layers = match model {
Model::Llama(ref model) => model.cache.normal().0.len(),
Model::Phi2(ref model) => model.cache.normal().0.len(),
Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
Model::Phi3(ref model) => model.cache.normal().0.len(),
Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
Model::Starcoder2(ref model) => model.cache.normal().0.len(),
Model::Qwen2(ref model) => model.cache.normal().0.len(),
};

let mapper =
mapper.into_mapper(num_hidden_layers, device, self.config.topology.as_ref())?;
let mut layer_devices = Vec::new();
for layer in 0..num_hidden_layers {
let device = mapper.device_for(layer, false).cloned();
layer_devices.push(device);
}

let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
let model_config: &dyn ModelConfigLike = &model_config_metadata;
let cache_config = calculate_cache_config(
@@ -463,7 +482,13 @@
model_config,
device,
)?;
let cache_engine = CacheEngine::new(model_config, &cache_config, DType::F32, device)?;
let cache_engine = CacheEngine::new(
model_config,
&cache_config,
DType::F32,
device,
layer_devices,
)?;
(Some(cache_config), Some(cache_engine))
} else {
(None, None)
@@ -494,15 +519,6 @@
Model::Qwen2(ref p) => p.max_seq_len,
};
let tok_env = build_tok_env(tokenizer.clone());
let num_hidden_layers = match model {
Model::Llama(ref model) => model.cache.normal().0.len(),
Model::Phi2(ref model) => model.cache.normal().0.len(),
Model::XLoraLlama(ref model) => model.cache.full().lock().len(),
Model::Phi3(ref model) => model.cache.normal().0.len(),
Model::XLoraPhi3(ref model) => model.cache.full().lock().len(),
Model::Starcoder2(ref model) => model.cache.normal().0.len(),
Model::Qwen2(ref model) => model.cache.normal().0.len(),
};

if chat_template.bos_token.is_none() && bos.is_some() {
chat_template.bos_token = Some(BeginEndUnkTok(Either::Left(bos.unwrap())));
14 changes: 8 additions & 6 deletions mistralrs-core/src/pipeline/normal.rs
@@ -271,7 +271,7 @@ impl Loader for NormalLoader {
silent: bool,
mapper: DeviceMapMetadata,
in_situ_quant: Option<IsqType>,
mut paged_attn_config: Option<PagedAttentionConfig>,
paged_attn_config: Option<PagedAttentionConfig>,
) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
let config = std::fs::read_to_string(paths.get_config_filename())?;
// Otherwise, the device mapper will print it
@@ -288,16 +288,17 @@
self.get_id(),
device.device_pretty_repr()
);
} else if paged_attn_config.is_some() {
warn!("Device mapping or device topology and PagedAttention are incompatible, disabling PagedAttention.");
paged_attn_config = None;
}

let mapper = mapper.into_mapper(
self.inner.get_total_device_mapping_num_layers(&config)?,
device,
self.config.topology.as_ref(),
)?;
let mut layer_devices = Vec::new();
for layer in 0..self.inner.get_total_device_mapping_num_layers(&config)? {
let device = mapper.device_for(layer, false).cloned();
layer_devices.push(device);
}
let dtype = mapper.get_min_dtype(dtype)?;

info!(
@@ -523,7 +524,8 @@
model.config(),
device,
)?;
let cache_engine = CacheEngine::new(model.config(), &cache_config, dtype, device)?;
let cache_engine =
CacheEngine::new(model.config(), &cache_config, dtype, device, layer_devices)?;
(Some(cache_config), Some(cache_engine))
} else {
(None, None)
8 changes: 7 additions & 1 deletion mistralrs-core/src/pipeline/vision.rs
@@ -235,6 +235,11 @@ impl Loader for VisionLoader {
device,
self.config.topology.as_ref(),
)?;
let mut layer_devices = Vec::new();
for layer in 0..self.inner.get_total_device_mapping_num_layers(&config)? {
let device = mapper.device_for(layer, false).cloned();
layer_devices.push(device);
}
let dtype = mapper.get_min_dtype(dtype)?;

let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some();
@@ -435,7 +440,8 @@
model.config(),
device,
)?;
let cache_engine = CacheEngine::new(model.config(), &cache_config, dtype, device)?;
let cache_engine =
CacheEngine::new(model.config(), &cache_config, dtype, device, layer_devices)?;
(Some(cache_config), Some(cache_engine))
} else {
(None, None)