Support chat template from GGUF (#416)
* Support loading chat template from gguf

* Implement sourcing chat template from gguf

* Add some logging

* Handle case where using gguf chat template

* Handle case where using gguf chat template

* Update docs
EricLBuehler authored Jun 10, 2024
1 parent 46b0364 commit ac1537d
Showing 13 changed files with 187 additions and 75 deletions.
10 changes: 9 additions & 1 deletion README.md
@@ -309,7 +309,13 @@ Throughout mistral.rs, any model ID argument or option may be a local path and s

### Running GGUF models locally

To run GGUF models fully locally, you do not need to specify the tokenizer model ID argument and instead should pass a path to the
To run GGUF models fully locally, the only mandatory arguments are the quantized model ID and the quantized filename.

#### Chat template

The chat template can be automatically detected and loaded from the GGUF file if no other chat template source is specified, including the tokenizer model ID.

Alternatively, you do not need to specify the tokenizer model ID argument and should instead pass a path to the
chat template JSON file (examples [here](chat_templates); you will need to create your own by specifying the chat template and `bos`/`eos` tokens), as well as specifying a local model ID. For example:

```bash
@@ -318,6 +324,8 @@ chat template JSON file (examples [here](chat_templates), you will need to creat

If you do not specify a chat template, then the `--tok-model-id`/`-t` tokenizer model ID argument is expected; it should point to a model ID that provides the `tokenizer_config.json` file. If that model ID contains a `tokenizer.json`, then that will be used over the GGUF tokenizer.

#### Tokenizer

The following tokenizer model types are currently supported. If you would like one to be added, please raise an issue. Otherwise,
please consider using the method demonstrated in examples below, where the tokenizer is sourced from Hugging Face.

6 changes: 5 additions & 1 deletion docs/CHAT_TOK.md
@@ -11,6 +11,8 @@ For example, to use the `chatml` template, `--chat-template` is specified *befor
./mistralrs-server --port 1234 --log output.log --chat-template ./chat_templates/chatml.json llama
```

> Note: For GGUF models, the chat template may be loaded directly from the GGUF file by omitting any other chat template sources.

## Tokenizer

Some models do not provide a `tokenizer.json` file although mistral.rs expects one. To solve this, please run [this](../scripts/get_tokenizers_json.py) script. It will output the `tokenizer.json` file for your specific model. This may be used by passing the `--tokenizer-json` flag *after* the model architecture. For example:
@@ -24,4 +26,6 @@ $ ./mistralrs_server --port 1234 --log output.log plain -m microsoft/Orca-2-13b
Putting it all together, to run, for example, an [Orca](https://huggingface.co/microsoft/Orca-2-13b) model (which does not come with a `tokenizer.json` or chat template):
1) Generate the `tokenizer.json` by running the script at `scripts/get_tokenizers_json.py`. This will output some files including `tokenizer.json` in the working directory.
2) Find and copy the correct chat template from `chat_templates` to the working directory (e.g., `cp chat_templates/chatml.json .`)
3) Run `mistralrs-server`, specifying the tokenizer and chat template: `cargo run --release --features cuda -- --port 1234 --log output.txt --chat-template chatml.json plain -m microsoft/Orca-2-13b -t tokenizer.json -a llama`

> Note: For GGUF models, the tokenizer may be loaded directly from the GGUF file by omitting the tokenizer model ID.
39 changes: 39 additions & 0 deletions mistralrs-core/src/gguf/chat_template.rs
@@ -0,0 +1,39 @@
use anyhow::Result;
use candle_core::quantized::gguf_file::Content;
use tracing::info;

use crate::utils::gguf_metadata::ContentMetadata;

struct PropsGGUFTemplate {
    chat_template: Option<String>,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUFTemplate {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> Result<Self, Self::Error> {
        // No required keys

        let props = Self {
            chat_template: c.get_option_value("chat_template")?,
        };

        Ok(props)
    }
}

// Get chat template from GGUF metadata if it exists
pub fn get_gguf_chat_template(content: &Content) -> Result<Option<String>> {
    let metadata = ContentMetadata {
        path_prefix: "tokenizer",
        metadata: &content.metadata,
    };
    let props = PropsGGUFTemplate::try_from(metadata)?;
    if let Some(ref chat_template) = props.chat_template {
        info!(
            "Discovered and using GGUF chat template: `{}`",
            chat_template.replace('\n', "\\n")
        );
    }
    Ok(props.chat_template)
}
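For reference, the `chat_template` key read here is ordinary GGUF metadata under the `tokenizer` prefix, so it can also be inspected without going through the loader. A minimal standalone sketch (not part of this commit; the file path is illustrative) using candle's `gguf_file` API:

```rust
use std::fs::File;

use anyhow::Result;
use candle_core::quantized::gguf_file::{Content, Value};

/// Print the chat template embedded in a GGUF file, if one is present.
fn main() -> Result<()> {
    let mut file = File::open("model.gguf")?; // illustrative path
    let content = Content::read(&mut file)?;
    // get_gguf_chat_template reaches this key via ContentMetadata with the
    // "tokenizer" path prefix; here it is read directly from the metadata map.
    match content.metadata.get("tokenizer.chat_template") {
        Some(Value::String(template)) => println!("{}", template.replace('\n', "\\n")),
        _ => println!("no chat template stored in this GGUF file"),
    }
    Ok(())
}
```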
2 changes: 2 additions & 0 deletions mistralrs-core/src/gguf/mod.rs
@@ -1,3 +1,5 @@
mod chat_template;
mod gguf_tokenizer;

pub use chat_template::get_gguf_chat_template;
pub(crate) use gguf_tokenizer::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion};
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/chat_template.rs
@@ -37,7 +37,7 @@ pub struct BeginEndUnkTok(
);

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Default)]
/// Template for chat models including bos/eos/unk as well as the chat template.
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
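Deriving `Default` is what makes the GGUF-only path workable: when there is no `tokenizer_config.json` to deserialize, the loader can start from an empty `ChatTemplate` and let the template string recovered from GGUF metadata fill it in. A rough sketch of that fallback (illustrative, not the exact code inside `get_chat_template`):

```rust
use std::path::PathBuf;

// Sketch (not the exact get_chat_template code): deserialize the template file
// when one was resolved, otherwise fall back to an empty ChatTemplate that the
// GGUF-sourced template string can populate.
fn load_or_default(template_filename: Option<PathBuf>) -> anyhow::Result<ChatTemplate> {
    Ok(match template_filename {
        Some(path) => serde_json::from_str(&std::fs::read_to_string(path)?)?,
        None => ChatTemplate::default(),
    })
}
```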
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/ggml.rs
@@ -320,7 +320,7 @@ impl Loader for GGMLLoader {
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template = get_chat_template(paths, &self.chat_template);
        let chat_template = get_chat_template(paths, &self.chat_template, None);

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
21 changes: 17 additions & 4 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -10,18 +10,23 @@ use super::{
};
use crate::aici::bintokens::build_tok_trie;
use crate::aici::toktree::TokTrie;
use crate::gguf::{convert_gguf_to_hf_tokenizer, GgufTokenizerConversion};
use crate::gguf::{
    get_gguf_chat_template, {convert_gguf_to_hf_tokenizer, GgufTokenizerConversion},
};
use crate::lora::Ordering;
use crate::pipeline::chat_template::{calculate_eos_tokens, BeginEndUnkTok, GenerationConfig};
use crate::pipeline::ChatTemplate;
use crate::pipeline::{get_chat_template, Cache};
use crate::pipeline::{ChatTemplate, LocalModelPaths};
use crate::prefix_cacher::PrefixCacheManager;
use crate::sequence::Sequence;
use crate::utils::debug::setup_logger_and_debug;
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, Pipeline, DEBUG};
use crate::{
    do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, LocalModelPaths, Pipeline,
    DEBUG,
};
use crate::{
    models::quantized_llama::ModelWeights as QLlama,
    models::quantized_phi2::ModelWeights as QPhi,
@@ -378,6 +383,14 @@ impl Loader for GGUFLoader {
}
};

        // Only load gguf chat template if there is nothing else
        let gguf_chat_template =
            if paths.get_template_filename().is_none() && self.chat_template.is_none() {
                get_gguf_chat_template(&model)?
            } else {
                None
            };

        let has_adapter = self.kind.is_adapted();
        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

@@ -421,7 +434,7 @@
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let mut chat_template = get_chat_template(paths, &self.chat_template);
        let mut chat_template = get_chat_template(paths, &self.chat_template, gguf_chat_template);

        let max_seq_len = match model {
            Model::Llama(ref l) => l.max_seq_len,
34 changes: 24 additions & 10 deletions mistralrs-core/src/pipeline/macros.rs
@@ -168,8 +168,17 @@ macro_rules! get_paths {
            None
        };

        info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
        let template_filename = $crate::api_get_file!(api, "tokenizer_config.json", model_id);
        let template_filename = if let Some(ref p) = $this.chat_template {
            info!("Using chat template file at `{p}`");
            Some(PathBuf::from_str(p)?)
        } else {
            info!("Loading `tokenizer_config.json` at `{}`", $this.model_id);
            Some($crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            ))
        };

        Ok(Box::new($path_name {
            tokenizer_filename,
@@ -209,17 +218,22 @@ macro_rules! get_paths_gguf {
        let chat_template = if let Some(ref p) = $this.chat_template {
            if p.ends_with(".json") {
                info!("Using chat template file at `{p}`");
                PathBuf::from_str(p)?
                Some(PathBuf::from_str(p)?)
            } else {
                PathBuf::from_str("")?
                panic!("Specified chat template file must end with .json");
            }
        } else {
            info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
            $crate::api_get_file!(
                api,
                "tokenizer_config.json",
                model_id
            ) // Will be loaded from inside gguf file
            if $this.model_id.is_none() {
                None
            } else {
                info!("Loading `tokenizer_config.json` at `{}` because no chat template file was specified.", this_model_id);
                let res = $crate::api_get_file!(
                    api,
                    "tokenizer_config.json",
                    model_id
                );
                Some(res)
            }
        };

        let filenames = get_model_paths(
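Written out as a plain function, the resolution that `get_paths_gguf!` now performs looks roughly like the sketch below. The helper `fetch_tokenizer_config` is hypothetical and stands in for the `api_get_file!` call; everything else mirrors the macro's branches:

```rust
use std::path::PathBuf;
use std::str::FromStr;

use anyhow::Result;

// Hypothetical stand-in for api_get_file!(api, "tokenizer_config.json", model_id).
fn fetch_tokenizer_config(tok_model_id: &str) -> Result<PathBuf> {
    unimplemented!("download tokenizer_config.json for `{tok_model_id}`")
}

/// Sketch of the template-file resolution above: an explicit `.json` chat
/// template wins, a non-JSON path is rejected, and when no tokenizer model ID
/// was given the result is `None` so the GGUF-embedded template can be used.
fn resolve_template_filename(
    chat_template: Option<&str>,
    tok_model_id: Option<&str>,
) -> Result<Option<PathBuf>> {
    match (chat_template, tok_model_id) {
        (Some(p), _) if p.ends_with(".json") => Ok(Some(PathBuf::from_str(p)?)),
        (Some(_), _) => panic!("Specified chat template file must end with .json"),
        (None, Some(id)) => Ok(Some(fetch_tokenizer_config(id)?)),
        (None, None) => Ok(None),
    }
}
```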
10 changes: 5 additions & 5 deletions mistralrs-core/src/pipeline/mod.rs
@@ -71,8 +71,8 @@ pub trait ModelPaths {
    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// Content expected to deserialize to [`ChatTemplate`].
    fn get_template_filename(&self) -> &PathBuf;
    /// File where the content is expected to deserialize to [`ChatTemplate`].
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Optional adapter files. `(String, PathBuf)` is of the form `(id name, path)`.
    fn get_adapter_filenames(&self) -> &Option<Vec<(String, PathBuf)>>;
@@ -107,7 +107,7 @@ pub trait ModelPaths {
pub struct LocalModelPaths<P> {
    tokenizer_filename: P,
    config_filename: P,
    template_filename: P,
    template_filename: Option<P>,
    filenames: Vec<P>,
    xlora_adapter_filenames: Option<Vec<(String, P)>>,
    xlora_adapter_configs: Option<Vec<((String, String), LoraConfig)>>,
@@ -140,7 +140,7 @@ impl<P> LocalModelPaths<P> {
        Self {
            tokenizer_filename,
            config_filename,
            template_filename,
            template_filename: Some(template_filename),
            filenames,
            xlora_adapter_filenames,
            xlora_adapter_configs,
@@ -180,7 +180,7 @@ impl ModelPaths for LocalModelPaths<PathBuf> {
    fn get_ordering(&self) -> &Option<Ordering> {
        &self.xlora_ordering
    }
    fn get_template_filename(&self) -> &PathBuf {
    fn get_template_filename(&self) -> &Option<PathBuf> {
        &self.template_filename
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
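Because `get_template_filename` now returns `&Option<PathBuf>`, callers only read the file when a path was actually resolved. A small sketch of the consuming pattern (the helper is illustrative, not part of this commit):

```rust
use std::path::PathBuf;

/// Sketch: read the chat template JSON when a template file was resolved,
/// otherwise return None so the GGUF-provided template (or ChatTemplate::default())
/// can take over downstream.
fn read_template_json(template_filename: &Option<PathBuf>) -> Option<String> {
    template_filename
        .as_ref()
        .and_then(|p| std::fs::read_to_string(p).ok())
}
```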
2 changes: 1 addition & 1 deletion mistralrs-core/src/pipeline/normal.rs
@@ -292,7 +292,7 @@ impl Loader for NormalLoader {
        let gen_conf: Option<GenerationConfig> = paths
            .get_gen_conf_filename()
            .map(|f| serde_json::from_str(&fs::read_to_string(f).unwrap()).unwrap());
        let chat_template = get_chat_template(paths, &self.chat_template);
        let chat_template = get_chat_template(paths, &self.chat_template, None);

        if let Some(in_situ_quant) = in_situ_quant {
            model.quantize(in_situ_quant, device.clone())?;