diff --git a/xtask/src/cast.rs b/xtask/src/cast.rs
index b53b7698..15449b71 100644
--- a/xtask/src/cast.rs
+++ b/xtask/src/cast.rs
@@ -11,7 +11,7 @@ pub(crate) struct CastArgs {
     #[clap(short, long)]
     target: Option<String>,
     /// Target model type.
-    #[clap(short, long)]
+    #[clap(long)]
     dt: Option<String>,
 }

diff --git a/xtask/src/generate.rs b/xtask/src/generate.rs
index bbe2c11a..f8636eef 100644
--- a/xtask/src/generate.rs
+++ b/xtask/src/generate.rs
@@ -1,7 +1,11 @@
-use crate::common::{argmax, logger_init, tokenizer};
+use crate::{
+    common::{argmax, logger_init, tokenizer},
+    Template,
+};
 use std::{
     alloc::Layout,
     collections::HashMap,
+    fs::read_to_string,
     io::Write,
     path::{Path, PathBuf},
     ptr::NonNull,
@@ -31,9 +35,6 @@ pub(crate) struct GenerateArgs {
     /// Copy model parameters inside memory.
     #[clap(long)]
     inside_mem: bool,
-    /// Add bos before first token.
-    #[clap(long)]
-    insert_bos: bool,
     /// Log level, may be "off", "trace", "debug", "info" or "error".
     #[clap(long)]
     log: Option<String>,
@@ -82,21 +83,11 @@ impl GenerateArgs {
         let tokenizer = tokenizer(self.tokenizer, &model_dir);
         info!("build tokenizer ... {:?}", time.elapsed());

-        let mut prompt = String::new();
-        if self.insert_bos {
-            prompt.push_str("<s>");
-        }
-        match self.prompt.chars().next() {
-            Some(c) if c.is_ascii_alphabetic() => prompt.push(' '),
-            _ => {}
-        }
-        prompt.push_str(&self.prompt);
-
         if self.nvidia {
             let preload_layers = if self.inside_mem { usize::MAX } else { 3 };
-            on_nvidia_gpu(model_dir, tokenizer, prompt, step, preload_layers)
+            on_nvidia_gpu(model_dir, tokenizer, self.prompt, step, preload_layers)
         } else {
-            on_host(model_dir, tokenizer, prompt, step, self.inside_mem)
+            on_host(model_dir, tokenizer, self.prompt, step, self.inside_mem)
         }
     }
 }
@@ -109,7 +100,18 @@ fn on_host(
     inside_mem: bool,
 ) {
     let model_dir = model_dir.as_ref();
-    let prompt = prompt.as_ref();
+    let template: Template = if model_dir
+        .as_os_str()
+        .to_str()
+        .unwrap()
+        .to_ascii_lowercase()
+        .contains("tinyllama")
+    {
+        Template::ChatTinyLlama
+    } else {
+        Template::Chat9G
+    };
+    let prompt = apply_template(prompt.as_ref(), template);

     let time = Instant::now();
     let mut model = Box::new(Memory::load_safetensors_from_dir(model_dir).unwrap());
@@ -121,13 +123,14 @@ fn on_host(
         model = Box::new(Memory::realloc_with(&*model, allocator));
         info!("copy model ... {:?}", time.elapsed());
     }
+    let eos = model.eos_token_id();
     let time = Instant::now();
     let mut transformer = Transformer::new(model);
     let mut kv_cache = transformer.new_cache();
     info!("build transformer ... {:?}", time.elapsed());

     let time = Instant::now();
-    let prompt_tokens = tokenizer.encode(&prompt.trim());
+    let prompt_tokens = tokenizer.encode(&prompt);
     info!("encode prompt ... {:?}", time.elapsed());

     let time = Instant::now();
@@ -137,20 +140,23 @@ fn on_host(
     }
     info!("prefill transformer ... {:?}", time.elapsed());
{:?}", time.elapsed()); - print!("{prompt}"); + print!("{}", prompt.replace('▁', " ")); let mut token = *last; let mut pos = tokens.len(); let time = Instant::now(); while pos < step.min(transformer.max_seq_len()) { let logits = transformer.forward(token, &mut kv_cache, pos as _); - let next = argmax(logits); + token = argmax(logits); - token = next; - pos += 1; - - print!("{}", tokenizer.decode(next).replace('▁', " ")); + print!("{}", tokenizer.decode(token).replace('▁', " ")); std::io::stdout().flush().unwrap(); + + if token == eos { + break; + } + + pos += 1; } println!(); let duration = time.elapsed(); @@ -181,7 +187,18 @@ fn on_nvidia_gpu( preload_layers: usize, ) { let model_dir = model_dir.as_ref(); - let prompt = prompt.as_ref(); + let template: Template = if model_dir + .as_os_str() + .to_str() + .unwrap() + .to_ascii_lowercase() + .contains("tinyllama") + { + Template::ChatTinyLlama + } else { + Template::Chat9G + }; + let prompt = apply_template(prompt.as_ref(), template); use std::{ fs::File, @@ -224,13 +241,14 @@ fn on_nvidia_gpu( let time = Instant::now(); let host = Memory::load_safetensors(config, host, false).unwrap(); + let eos = host.eos_token_id(); let mut transformer = Transformer::new(&host, preload_layers, &transfer); let kv_cache = transformer.new_cache(&compute); info!("build model host: {:?}", time.elapsed()); let step = step.min(host.max_position_embeddings()); let time = Instant::now(); - let prompt_tokens = tokenizer.encode(&prompt.trim().replace(' ', "▁")); + let prompt_tokens = tokenizer.encode(&prompt); info!("encode prompt ... {:?}", time.elapsed()); let time = Instant::now(); @@ -240,20 +258,23 @@ fn on_nvidia_gpu( } info!("prefill transformer ... {:?}", time.elapsed()); - print!("{prompt}"); + print!("{}", prompt.replace('▁', " ")); let mut token = *last; let mut pos = tokens.len(); let time = Instant::now(); while pos < step { let logits = transformer.forward(token, &kv_cache, pos as _, &compute, &transfer); - let next = argmax(logits); - - token = next; - pos += 1; + token = argmax(logits); - print!("{}", tokenizer.decode(next).replace('▁', " ")); + print!("{}", tokenizer.decode(token).replace('▁', " ")); std::io::stdout().flush().unwrap(); + + if token == eos { + break; + } + + pos += 1; } println!(); let duration = time.elapsed(); @@ -264,3 +285,39 @@ fn on_nvidia_gpu( ) }); } + +#[inline] +fn apply_template(prompt: &str, template: Template) -> String { + let maybe_file = Path::new(&prompt); + let prompt = if maybe_file.is_file() { + read_to_string(maybe_file).unwrap() + } else { + prompt.to_string() + }; + let prompt = prompt.trim(); + let mut ans = String::new(); + match template { + Template::Chat9G => { + ans.push_str(""); + match prompt.chars().next() { + Some(c) if c.is_ascii_alphabetic() => ans.push(' '), + _ => {} + } + ans.push_str(prompt); + ans + } + Template::ChatTinyLlama => { + match prompt.chars().next() { + Some(c) if c.is_ascii_alphabetic() => ans.push('▁'), + _ => {} + } + for c in prompt.chars() { + ans.push(match c { + ' ' => '▁', + c => c, + }); + } + ans + } + } +} diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 6dc70f17..8c2c6f09 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -36,3 +36,10 @@ enum Commands { /// Start LLM inference service Service(service::ServiceArgs), } + +#[derive(Clone, Copy, Debug)] +#[repr(u8)] +enum Template { + Chat9G, + ChatTinyLlama, +} diff --git a/xtask/src/service/chat.rs b/xtask/src/service/chat.rs new file mode 100644 index 00000000..4c23ea54 --- /dev/null +++ 
@@ -0,0 +1,12 @@
+use crate::Template;
+
+#[inline]
+pub(super) fn apply_chat(prompt: &str, template: Template) -> String {
+    match template {
+        Template::Chat9G => todo!(),
+        Template::ChatTinyLlama => format!(
+            "<|user|>\n{}\n<|assistant|>\n",
+            prompt.replace(' ', "▁")
+        ),
+    }
+}

diff --git a/xtask/src/service/cpu.rs b/xtask/src/service/cpu.rs
index cda0b548..d3bfa26f 100644
--- a/xtask/src/service/cpu.rs
+++ b/xtask/src/service/cpu.rs
@@ -1,13 +1,22 @@
-use super::{channel::channel, ServiceArgs};
+use super::{channel::channel, chat::apply_chat, ServiceArgs};
 use crate::{
     common::{argmax, tokenizer},
     service::channel::{Query, Response},
+    Template,
 };
 use common::upos;
 use std::{collections::HashMap, path::Path, time::Instant};
-use transformer_cpu::{model_parameters::Memory, LayerCache, Transformer};
+use transformer_cpu::{
+    model_parameters::{Llama2, Memory},
+    LayerCache, Transformer,
+};

 pub(super) fn run(args: ServiceArgs) {
+    let template = if args.model.to_ascii_lowercase().contains("tinyllama") {
+        Template::ChatTinyLlama
+    } else {
+        Template::Chat9G
+    };
     let model_dir = Path::new(&args.model);

     let time = Instant::now();
@@ -15,9 +24,12 @@ pub(super) fn run(args: ServiceArgs) {
     info!("build tokenizer ... {:?}", time.elapsed());

     let time = Instant::now();
-    let model = Box::new(Memory::load_safetensors_from_dir(model_dir).unwrap());
+    let model = Box::new(Memory::load_safetensors_from_dir(&model_dir).unwrap());
     info!("load model ... {:?}", time.elapsed());

+    let _bos = model.bos_token_id();
+    let eos = model.eos_token_id();
+
     let time = Instant::now();
     let mut transformer = Transformer::new(model);
     info!("build transformer ... {:?}", time.elapsed());
@@ -35,14 +47,16 @@ pub(super) fn run(args: ServiceArgs) {
     loop {
         let Query { id, prompt } = channel.receive().unwrap();
+        let prompt = apply_chat(&prompt, template);
+
+        let prompt_tokens = tokenizer.encode(&prompt.trim());
+        let (last, tokens) = prompt_tokens.split_last().expect("prompt is empty");

         let session = sessions.entry(id).or_insert_with(|| SessionContext {
             pos: 0,
             kv_cache: transformer.new_cache(),
         });

-        let prompt_tokens = tokenizer.encode(&prompt.trim());
-        let (last, tokens) = prompt_tokens.split_last().expect("prompt is empty");
         if !tokens.is_empty() {
             transformer.update(tokens, &mut session.kv_cache, session.pos as _);
             session.pos += tokens.len() as upos;
         }
@@ -53,12 +67,13 @@ pub(super) fn run(args: ServiceArgs) {
         let mut out = String::new();
         while session.pos < max_pos {
             let logits = transformer.forward(token, &mut session.kv_cache, session.pos as _);
-            let next = argmax(logits);
+            token = argmax(logits);
+            if token == eos {
+                break;
+            }

-            token = next;
+            out.push_str(&tokenizer.decode(token).replace('▁', " "));
             session.pos += 1;
-
-            out.push_str(&tokenizer.decode(next).replace('▁', " "));
         }

         channel.send(Response { id, prompt: out }).unwrap();

diff --git a/xtask/src/service/mod.rs b/xtask/src/service/mod.rs
index b5f860be..132baba5 100644
--- a/xtask/src/service/mod.rs
+++ b/xtask/src/service/mod.rs
@@ -1,4 +1,5 @@
 mod channel;
+mod chat;
 mod cpu;
 #[cfg(detected_cuda)]
 mod nvidia;