Support BF16 kvcache, rope and attentions for inference of GGUF/GGML models (#1009)

* Support BF16 kvcache & attention for GGUF/GGML quantization

* Fix clippy

* Pass dtype to xlora gguf/ggml model

* Remove the hardcoded fix for the literal chat template (side effect: the model cannot terminate itself when running a GGUF file)

* Pass dtype to Lora GGUF/GGML models
guoqingbao authored Dec 30, 2024
1 parent d8fa819 commit 1880c0b
Showing 16 changed files with 253 additions and 107 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -393,7 +393,7 @@ Mistral.rs uses subcommands to control the model type. They are generally of for

### Architecture for plain models

> Note: for plain models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. This is specified in the `--dtype`/`-d` parameter after the model architecture (`plain`).
> Note: for plain models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. This is specified in the `--dtype`/`-d` parameter after the model architecture (`plain`). For quantized models (gguf/ggml), you may specify a data type of `f32` or `bf16` (`f16` is not recommended due to its lower precision in quantized inference).
If you do not specify the architecture, an attempt will be made to use the model's config. If this fails, please raise an issue.

@@ -455,6 +455,13 @@ And even diffusion models:
./mistralrs-server -i diffusion-plain -m black-forest-labs/FLUX.1-schnell -a flux
```

On Apple Silicon (`Metal`), run with throughput logging, paged attention (with a maximum of 4 GB for the KV cache), and `bf16` dtype for the KV cache and attention:

```bash
cargo build --release --features metal
./target/release/mistralrs-server -i --throughput --paged-attn --pa-gpu-mem 4096 gguf --dtype bf16 -m /Users/Downloads/ -f Phi-3.5-mini-instruct-Q4_K_M.gguf
```
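
The `auto` value of `--dtype`/`-d` resolves to a concrete dtype based on the device. The sketch below is only an illustration of such a rule and assumes the `candle_core` device API; the crate's actual resolution logic is not shown in this diff.

```rust
// Hypothetical sketch: prefer BF16 on GPU backends, fall back to F32 on CPU.
// `resolve_auto_dtype` is an illustrative helper, not a mistral.rs API.
use candle_core::{DType, Device};

fn resolve_auto_dtype(device: &Device) -> DType {
    if device.is_cuda() || device.is_metal() {
        DType::BF16
    } else {
        DType::F32
    }
}

fn main() {
    assert_eq!(resolve_auto_dtype(&Device::Cpu), DType::F32);
}
```

Under that assumption, the Metal command above behaves the same whether `--dtype bf16` is passed explicitly or left to `auto`.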

### OpenAI HTTP server

You can launch an HTTP server
20 changes: 13 additions & 7 deletions mistralrs-core/src/model_loader.rs
@@ -85,13 +85,13 @@ pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
| ModelSelected::Lora { dtype, .. }
| ModelSelected::XLora { dtype, .. }
| ModelSelected::VisionPlain { dtype, .. }
| ModelSelected::DiffusionPlain { dtype, .. } => Ok(*dtype),
ModelSelected::GGUF { .. }
| ModelSelected::LoraGGUF { .. }
| ModelSelected::GGML { .. }
| ModelSelected::LoraGGML { .. }
| ModelSelected::XLoraGGUF { .. }
| ModelSelected::XLoraGGML { .. } => Ok(ModelDType::Auto),
| ModelSelected::DiffusionPlain { dtype, .. }
| ModelSelected::GGML { dtype, .. }
| ModelSelected::GGUF { dtype, .. }
| ModelSelected::XLoraGGUF { dtype, .. }
| ModelSelected::XLoraGGML { dtype, .. }
| ModelSelected::LoraGGUF { dtype, .. }
| ModelSelected::LoraGGML { dtype, .. } => Ok(*dtype),
ModelSelected::Toml { file } => {
let selector: TomlSelector = toml::from_str(
&fs::read_to_string(file.clone())
@@ -222,6 +222,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
quantized_model_id,
quantized_filename,
topology,
..
} => GGUFLoaderBuilder::new(
args.chat_template,
tok_model_id,
@@ -244,6 +245,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
order,
tgt_non_granular_index,
topology,
..
} => GGUFLoaderBuilder::new(
args.chat_template,
tok_model_id,
@@ -275,6 +277,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
adapters_model_id,
order,
topology,
..
} => GGUFLoaderBuilder::new(
args.chat_template,
tok_model_id,
@@ -304,6 +307,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
quantized_filename,
gqa,
topology,
..
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig {
gqa,
@@ -328,6 +332,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tgt_non_granular_index,
gqa,
topology,
..
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig {
gqa,
@@ -360,6 +365,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
order,
gqa,
topology,
..
} => GGMLLoaderBuilder::new(
GGMLSpecificConfig {
gqa,
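
The remaining hunks in this file only add `..` to the destructuring patterns so the new `dtype` field can be ignored where the loader builders do not need it. A self-contained sketch of the extraction pattern used in `get_model_dtype` above (simplified stand-in types, not the crate's real definitions):

```rust
// Every quantized variant now carries `dtype`, so a single or-pattern with
// `..` rest bindings extracts it uniformly across variants.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum ModelDType {
    Auto,
    BF16,
    F32,
}

#[allow(dead_code)]
enum ModelSelected {
    Plain { dtype: ModelDType },
    Gguf { dtype: ModelDType, quantized_filename: String },
    Ggml { dtype: ModelDType, gqa: usize },
}

fn get_model_dtype(model: &ModelSelected) -> ModelDType {
    match model {
        ModelSelected::Plain { dtype, .. }
        | ModelSelected::Gguf { dtype, .. }
        | ModelSelected::Ggml { dtype, .. } => *dtype,
    }
}

fn main() {
    let m = ModelSelected::Gguf {
        dtype: ModelDType::BF16,
        quantized_filename: "model.gguf".into(),
    };
    assert_eq!(get_model_dtype(&m), ModelDType::BF16);
}
```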
24 changes: 24 additions & 0 deletions mistralrs-core/src/model_selected.rs
@@ -179,6 +179,10 @@ pub enum ModelSelected {
#[arg(short = 'f', long)]
quantized_filename: String,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,

/// Path to a topology YAML file.
#[arg(long)]
topology: Option<String>,
@@ -215,6 +219,10 @@ pub enum ModelSelected {
#[arg(long)]
tgt_non_granular_index: Option<usize>,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,

/// Path to a topology YAML file.
#[arg(long)]
topology: Option<String>,
@@ -246,6 +254,10 @@ pub enum ModelSelected {
#[arg(short, long)]
order: String,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,

/// Path to a topology YAML file.
#[arg(long)]
topology: Option<String>,
@@ -274,6 +286,10 @@ pub enum ModelSelected {
#[arg(short, long, default_value_t = 1)]
gqa: usize,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,

/// Path to a topology YAML file.
#[arg(long)]
topology: Option<String>,
@@ -315,6 +331,10 @@ pub enum ModelSelected {
#[arg(short, long, default_value_t = 1)]
gqa: usize,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,

/// Path to a topology YAML file.
#[arg(long)]
topology: Option<String>,
@@ -351,6 +371,10 @@ pub enum ModelSelected {
#[arg(short, long, default_value_t = 1)]
gqa: usize,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,

/// Path to a topology YAML file.
#[arg(long)]
topology: Option<String>,
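
Each quantized variant gains the same `-d`/`--dtype` clap argument. A minimal sketch of how such a flag's value parser could look (hypothetical names and variants; the crate's actual `ModelDType` and `parse_model_dtype` may differ):

```rust
// Illustrative parser for a dtype CLI flag accepting auto|bf16|f16|f32.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ModelDType {
    Auto,
    BF16,
    F16,
    F32,
}

fn parse_model_dtype(s: &str) -> Result<ModelDType, String> {
    match s.to_lowercase().as_str() {
        "auto" => Ok(ModelDType::Auto),
        "bf16" => Ok(ModelDType::BF16),
        "f16" => Ok(ModelDType::F16),
        "f32" => Ok(ModelDType::F32),
        other => Err(format!("unknown dtype `{other}`, expected auto|bf16|f16|f32")),
    }
}

fn main() {
    assert_eq!(parse_model_dtype("bf16"), Ok(ModelDType::BF16));
    assert!(parse_model_dtype("f64").is_err());
}
```

A function of this `fn(&str) -> Result<T, String>` shape is what clap's `value_parser =` attribute accepts for function-based parsers, which is how the diff wires it into each variant.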
29 changes: 21 additions & 8 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -135,6 +135,7 @@ struct LayerWeights {
rotary: Arc<RotaryEmbedding>,
paged_attn: Option<PagedAttention>,
sdpa_params: SdpaParams,
dtype: DType,
}

impl LayerWeights {
@@ -149,9 +150,15 @@
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;

let q = MatMul.qmethod_matmul(x, &*self.attention_wq)?;
let k = MatMul.qmethod_matmul(x, &*self.attention_wk)?;
let v = MatMul.qmethod_matmul(x, &*self.attention_wv)?;
let q = MatMul
.qmethod_matmul(x, &*self.attention_wq)?
.to_dtype(self.dtype)?;
let k = MatMul
.qmethod_matmul(x, &*self.attention_wk)?
.to_dtype(self.dtype)?;
let v = MatMul
.qmethod_matmul(x, &*self.attention_wv)?
.to_dtype(self.dtype)?;

let mut q = q.reshape((b_sz * seq_len, self.n_head, self.head_dim))?;
let mut k = k.reshape((b_sz * seq_len, self.n_kv_head, self.head_dim))?;
@@ -212,7 +219,7 @@ impl LayerWeights {
y.reshape(&[b_sz, seq_len, n_embd])?
};

let y = MatMul.qmethod_matmul(&y, &*self.attention_wo)?;
let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
Ok(y)
}
}
@@ -226,10 +233,11 @@ pub struct ModelWeights {
pub cache: EitherCache,
pub max_seq_len: usize,
mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
dtype: DType,
}

impl ModelConfig::FromGGML for ModelWeights {
fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
fn from_ggml(mut ct: ggml_file::Content, gqa: usize, dtype: DType) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let rotary = RotaryEmbedding::new_partial(
10000.,
@@ -238,7 +246,7 @@ impl ModelConfig::FromGGML for ModelWeights {
MAX_SEQ_LEN as usize,
&ct.device,
false,
DType::F32,
dtype,
)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
@@ -307,6 +315,7 @@ impl ModelConfig::FromGGML for ModelWeights {
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
sliding_window: None,
},
dtype,
})
}
Ok(Self {
@@ -324,6 +333,7 @@
)),
max_seq_len: MAX_SEQ_LEN as usize, // Cannot determine from ggml.
mapper: None,
dtype,
})
}
}
@@ -405,6 +415,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
mapper: DeviceMapMetadata,
topology: Option<&'_ Topology>,
attention_mechanism: AttentionImplementation,
dtype: DType,
) -> Result<Self> {
// Parameter extraction from metadata.
let metadata = ContentMetadata {
@@ -456,7 +467,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
max_seq_len,
device,
false,
DType::F32,
dtype,
)?),
);
}
Expand Down Expand Up @@ -632,6 +643,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
sliding_window: None,
},
dtype,
})
}
Ok(Self {
@@ -646,6 +658,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
max_seq_len,
mapper: Some(mapper),
dtype,
})
}
}
@@ -667,7 +680,7 @@ impl ModelWeights {
.as_ref()
.map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
.unwrap_or(cache as &dyn PastKvLenCache),
DType::F32,
self.dtype,
self.layers[0].n_head,
)?;
for (i, layer) in self.layers.iter().enumerate() {
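
The heart of the model change: the quantized Q/K/V projections still produce `F32` activations, which are now cast to the configured dtype (e.g. `BF16`) before RoPE, attention, and the KV cache, and the attention output is cast back before the quantized output projection. A minimal sketch of that cast-in/cast-out pattern, assuming `candle_core` and using placeholder arithmetic in place of the real projections and `Sdpa` call:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn attention_cast_sketch(x: &Tensor, dtype: DType) -> Result<Tensor> {
    // Stand-ins for the quantized wq/wk/wv matmuls, whose outputs are cast
    // into the attention dtype (e.g. BF16) for RoPE, SDPA and the KV cache.
    let q = x.to_dtype(dtype)?;
    let k = x.to_dtype(dtype)?;
    let v = x.to_dtype(dtype)?;

    // Placeholder for the actual attention computation, carried out in `dtype`.
    let y = ((q * 0.5)? + (k * 0.25)?)?;
    let y = (y + (v * 0.25)?)?;

    // Cast back to the activations' dtype before the quantized output projection.
    y.to_dtype(x.dtype())
}

fn main() -> Result<()> {
    let x = Tensor::randn(0f32, 1f32, (1, 4, 8), &Device::Cpu)?;
    let y = attention_cast_sketch(&x, DType::BF16)?;
    assert_eq!(y.dtype(), DType::F32); // matches the F32 activations again
    Ok(())
}
```

The same `dtype` is also threaded into the rotary embedding and the cache/mask setup in the diff, so every stage of attention runs at a consistent precision.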
25 changes: 16 additions & 9 deletions mistralrs-core/src/models/quantized_phi2.rs
@@ -56,6 +56,7 @@ struct LayerWeights {
rope_dim: usize,
paged_attn: Option<PagedAttention>,
sdpa_params: SdpaParams,
dtype: DType,
}

impl LayerWeights {
@@ -83,10 +84,11 @@ impl LayerWeights {
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let qkv =
self.attn_qkv
.forward(x)?
.reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?;
let qkv = self
.attn_qkv
.forward(x)?
.reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?
.to_dtype(self.dtype)?;

let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
@@ -125,7 +127,7 @@ impl LayerWeights {
} else {
y.reshape(&[b_sz, seq_len, n_embd])?
};
let y = self.attn_output.forward(&y)?;
let y = self.attn_output.forward(&y.to_dtype(x.dtype())?)?;
Ok(y)
}
}
@@ -139,13 +141,15 @@ pub struct ModelWeights {
pub cache: EitherCache,
pub max_seq_len: usize,
mapper: Box<dyn DeviceMapper + Send + Sync>,
dtype: DType,
}

fn precomput_freqs_cis(
head_dim: usize,
freq_base: f32,
device: &Device,
max_seq_len: usize,
dtype: DType,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
@@ -156,8 +160,8 @@ fn precomput_freqs_cis(
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok((cos, sin))
}

@@ -224,6 +228,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
mapper: DeviceMapMetadata,
topology: Option<&'_ Topology>,
attention_mechanism: AttentionImplementation,
dtype: DType,
) -> Result<Self> {
// Parameter extraction from metadata.
let metadata = ContentMetadata {
Expand All @@ -240,7 +245,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
max_seq_len,
} = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len)?;
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device, max_seq_len, dtype)?;

let tok_embeddings = ct.tensor("token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
@@ -326,6 +331,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
sliding_window: None,
},
dtype,
})
}
Ok(Self {
Expand All @@ -337,6 +343,7 @@ impl ModelConfig::FromGGUF for ModelWeights {
cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
max_seq_len,
mapper,
dtype,
})
}
}
@@ -357,7 +364,7 @@ impl ModelWeights {
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(cache as &dyn PastKvLenCache),
DType::F32,
self.dtype,
self.layers[0].n_head,
)?;
for (i, layer) in self.layers.iter().enumerate() {
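
`precomput_freqs_cis` now takes the target dtype so the RoPE cos/sin tables match the dtype of Q and K. A self-contained sketch of the same idea (assuming `candle_core`; the crate's actual function differs in details such as partial rotary dimensions): build the tables in `F32` for accuracy, then cast once at the end.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn rope_tables(
    head_dim: usize,
    freq_base: f32,
    max_seq_len: usize,
    device: &Device,
    dtype: DType,
) -> Result<(Tensor, Tensor)> {
    // Inverse frequencies for each pair of rotary dimensions.
    let theta: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1.0 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let theta = Tensor::new(theta.as_slice(), device)?;
    // Outer product of positions and frequencies, computed in F32 for accuracy.
    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?
        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
    // Cast once so the tables match the Q/K dtype used during RoPE (e.g. BF16).
    Ok((idx_theta.cos()?.to_dtype(dtype)?, idx_theta.sin()?.to_dtype(dtype)?))
}

fn main() -> Result<()> {
    let (cos, sin) = rope_tables(64, 10_000.0, 8, &Device::Cpu, DType::BF16)?;
    assert_eq!((cos.dtype(), sin.dtype()), (DType::BF16, DType::BF16));
    Ok(())
}
```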
(Diff for the remaining 11 changed files is not shown.)
