feat(xtask): for now, infer which prompt preprocessing to apply from the model directory name
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 4, 2024
1 parent b6f6f51 commit 724b64c
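
What the change does (summary inferred from the diff below): the `--insert-bos` flag is removed, and each entry point instead picks a `Template` by inspecting the model directory name, then applies the matching prompt preprocessing. A minimal sketch of the heuristic, using hypothetical directory names:

```rust
#[derive(Clone, Copy, Debug)]
enum Template {
    Chat9G,
    ChatTinyLlama,
}

// A directory name containing "tinyllama" (case-insensitive) selects the
// TinyLlama chat template; anything else falls back to Chat9G.
fn infer_template(model_dir: &str) -> Template {
    if model_dir.to_ascii_lowercase().contains("tinyllama") {
        Template::ChatTinyLlama
    } else {
        Template::Chat9G
    }
}

fn main() {
    // Hypothetical paths, for illustration only.
    println!("{:?}", infer_template("models/TinyLlama-1.1B-Chat-v1.0")); // ChatTinyLlama
    println!("{:?}", infer_template("models/9G-base")); // Chat9G
}
```

As the title says, the name-based match is a stopgap until the template can be selected explicitly.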
Showing 6 changed files with 134 additions and 42 deletions.
2 changes: 1 addition & 1 deletion xtask/src/cast.rs
@@ -11,7 +11,7 @@ pub(crate) struct CastArgs {
     #[clap(short, long)]
     target: Option<String>,
     /// Target model type.
-    #[clap(short, long)]
+    #[clap(long)]
    dt: Option<String>,
 }

121 changes: 89 additions & 32 deletions xtask/src/generate.rs
@@ -1,7 +1,11 @@
-use crate::common::{argmax, logger_init, tokenizer};
+use crate::{
+    common::{argmax, logger_init, tokenizer},
+    Template,
+};
 use std::{
     alloc::Layout,
     collections::HashMap,
+    fs::read_to_string,
     io::Write,
     path::{Path, PathBuf},
     ptr::NonNull,
@@ -31,9 +35,6 @@ pub(crate) struct GenerateArgs {
     /// Copy model parameters inside memory.
     #[clap(long)]
     inside_mem: bool,
-    /// Add bos before first token.
-    #[clap(long)]
-    insert_bos: bool,
     /// Log level, may be "off", "trace", "debug", "info" or "error".
     #[clap(long)]
     log: Option<String>,
@@ -82,21 +83,11 @@ impl GenerateArgs {
         let tokenizer = tokenizer(self.tokenizer, &model_dir);
         info!("build tokenizer ... {:?}", time.elapsed());
 
-        let mut prompt = String::new();
-        if self.insert_bos {
-            prompt.push_str("<s>");
-        }
-        match self.prompt.chars().next() {
-            Some(c) if c.is_ascii_alphabetic() => prompt.push(' '),
-            _ => {}
-        }
-        prompt.push_str(&self.prompt);
-
         if self.nvidia {
             let preload_layers = if self.inside_mem { usize::MAX } else { 3 };
-            on_nvidia_gpu(model_dir, tokenizer, prompt, step, preload_layers)
+            on_nvidia_gpu(model_dir, tokenizer, self.prompt, step, preload_layers)
         } else {
-            on_host(model_dir, tokenizer, prompt, step, self.inside_mem)
+            on_host(model_dir, tokenizer, self.prompt, step, self.inside_mem)
         }
     }
 }
@@ -109,7 +100,18 @@ fn on_host(
     inside_mem: bool,
 ) {
     let model_dir = model_dir.as_ref();
-    let prompt = prompt.as_ref();
+    let template: Template = if model_dir
+        .as_os_str()
+        .to_str()
+        .unwrap()
+        .to_ascii_lowercase()
+        .contains("tinyllama")
+    {
+        Template::ChatTinyLlama
+    } else {
+        Template::Chat9G
+    };
+    let prompt = apply_template(prompt.as_ref(), template);
 
     let time = Instant::now();
     let mut model = Box::new(Memory::load_safetensors_from_dir(model_dir).unwrap());
@@ -121,13 +123,14 @@ fn on_host(
         model = Box::new(Memory::realloc_with(&*model, allocator));
         info!("copy model ... {:?}", time.elapsed());
     }
+    let eos = model.eos_token_id();
     let time = Instant::now();
     let mut transformer = Transformer::new(model);
     let mut kv_cache = transformer.new_cache();
     info!("build transformer ... {:?}", time.elapsed());
 
     let time = Instant::now();
-    let prompt_tokens = tokenizer.encode(&prompt.trim());
+    let prompt_tokens = tokenizer.encode(&prompt);
     info!("encode prompt ... {:?}", time.elapsed());
 
     let time = Instant::now();
@@ -137,20 +140,23 @@
     }
     info!("prefill transformer ... {:?}", time.elapsed());
 
-    print!("{prompt}");
+    print!("{}", prompt.replace('▁', " "));
 
     let mut token = *last;
     let mut pos = tokens.len();
     let time = Instant::now();
     while pos < step.min(transformer.max_seq_len()) {
         let logits = transformer.forward(token, &mut kv_cache, pos as _);
-        let next = argmax(logits);
+        token = argmax(logits);
 
-        token = next;
-        pos += 1;
-
-        print!("{}", tokenizer.decode(next).replace('▁', " "));
+        print!("{}", tokenizer.decode(token).replace('▁', " "));
         std::io::stdout().flush().unwrap();
+
+        if token == eos {
+            break;
+        }
+
+        pos += 1;
     }
     println!();
     let duration = time.elapsed();
@@ -181,7 +187,18 @@ fn on_nvidia_gpu(
     preload_layers: usize,
 ) {
     let model_dir = model_dir.as_ref();
-    let prompt = prompt.as_ref();
+    let template: Template = if model_dir
+        .as_os_str()
+        .to_str()
+        .unwrap()
+        .to_ascii_lowercase()
+        .contains("tinyllama")
+    {
+        Template::ChatTinyLlama
+    } else {
+        Template::Chat9G
+    };
+    let prompt = apply_template(prompt.as_ref(), template);
 
     use std::{
         fs::File,
@@ -224,13 +241,14 @@
 
     let time = Instant::now();
     let host = Memory::load_safetensors(config, host, false).unwrap();
+    let eos = host.eos_token_id();
     let mut transformer = Transformer::new(&host, preload_layers, &transfer);
     let kv_cache = transformer.new_cache(&compute);
     info!("build model host: {:?}", time.elapsed());
 
     let step = step.min(host.max_position_embeddings());
     let time = Instant::now();
-    let prompt_tokens = tokenizer.encode(&prompt.trim().replace(' ', "▁"));
+    let prompt_tokens = tokenizer.encode(&prompt);
     info!("encode prompt ... {:?}", time.elapsed());
 
     let time = Instant::now();
@@ -240,20 +258,23 @@
     }
     info!("prefill transformer ... {:?}", time.elapsed());
 
-    print!("{prompt}");
+    print!("{}", prompt.replace('▁', " "));
 
     let mut token = *last;
     let mut pos = tokens.len();
     let time = Instant::now();
     while pos < step {
         let logits = transformer.forward(token, &kv_cache, pos as _, &compute, &transfer);
-        let next = argmax(logits);
-
-        token = next;
-        pos += 1;
+        token = argmax(logits);
 
-        print!("{}", tokenizer.decode(next).replace('▁', " "));
+        print!("{}", tokenizer.decode(token).replace('▁', " "));
         std::io::stdout().flush().unwrap();
+
+        if token == eos {
+            break;
+        }
+
+        pos += 1;
     }
     println!();
     let duration = time.elapsed();
@@ -264,3 +285,39 @@ fn on_nvidia_gpu(
         )
     });
 }
+
+#[inline]
+fn apply_template(prompt: &str, template: Template) -> String {
+    let maybe_file = Path::new(&prompt);
+    let prompt = if maybe_file.is_file() {
+        read_to_string(maybe_file).unwrap()
+    } else {
+        prompt.to_string()
+    };
+    let prompt = prompt.trim();
+    let mut ans = String::new();
+    match template {
+        Template::Chat9G => {
+            ans.push_str("<s>");
+            match prompt.chars().next() {
+                Some(c) if c.is_ascii_alphabetic() => ans.push(' '),
+                _ => {}
+            }
+            ans.push_str(prompt);
+            ans
+        }
+        Template::ChatTinyLlama => {
+            match prompt.chars().next() {
+                Some(c) if c.is_ascii_alphabetic() => ans.push('▁'),
+                _ => {}
+            }
+            for c in prompt.chars() {
+                ans.push(match c {
+                    ' ' => '▁',
+                    c => c,
+                });
+            }
+            ans
+        }
+    }
+}
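
To illustrate what the new `apply_template` produces, here is a self-contained sketch of the two transformations (the prompt string is a hypothetical example; the rules mirror the function above):

```rust
fn main() {
    let p = "hello world";

    // Chat9G: prepend "<s>", plus a space when the prompt starts with an
    // ASCII letter, so the first word keeps its word boundary after <s>.
    let mut chat9g = String::from("<s>");
    if matches!(p.chars().next(), Some(c) if c.is_ascii_alphabetic()) {
        chat9g.push(' ');
    }
    chat9g.push_str(p);
    assert_eq!(chat9g, "<s> hello world");

    // ChatTinyLlama: a leading ▁ before an ASCII letter, and every space
    // replaced by ▁, the metaspace symbol SentencePiece tokenizers use.
    let mut tiny = String::new();
    if matches!(p.chars().next(), Some(c) if c.is_ascii_alphabetic()) {
        tiny.push('▁');
    }
    for c in p.chars() {
        tiny.push(if c == ' ' { '▁' } else { c });
    }
    assert_eq!(tiny, "▁hello▁world");
}
```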
7 changes: 7 additions & 0 deletions xtask/src/main.rs
@@ -36,3 +36,10 @@ enum Commands {
     /// Start LLM inference service
     Service(service::ServiceArgs),
 }
+
+#[derive(Clone, Copy, Debug)]
+#[repr(u8)]
+enum Template {
+    Chat9G,
+    ChatTinyLlama,
+}
12 changes: 12 additions & 0 deletions xtask/src/service/chat.rs
@@ -0,0 +1,12 @@
+use crate::Template;
+
+#[inline]
+pub(super) fn apply_chat(prompt: &str, template: Template) -> String {
+    match template {
+        Template::Chat9G => todo!(),
+        Template::ChatTinyLlama => format!(
+            "<|user|>\n{}</s>\n<|assistant|>\n",
+            prompt.replace(' ', "▁")
+        ),
+    }
+}
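
A hedged usage sketch of `apply_chat` (the prompt is a hypothetical example, and `Template::Chat9G` would panic on the `todo!()` above if selected):

```rust
// Assuming apply_chat and Template are in scope, e.g. in this module's tests:
#[test]
fn tinyllama_template_wraps_prompt() {
    let out = apply_chat("hello world", Template::ChatTinyLlama);
    // Spaces become ▁ and the turn is wrapped in TinyLlama chat markers.
    assert_eq!(out, "<|user|>\nhello▁world</s>\n<|assistant|>\n");
}
```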
33 changes: 24 additions & 9 deletions xtask/src/service/cpu.rs
@@ -1,23 +1,35 @@
-use super::{channel::channel, ServiceArgs};
+use super::{channel::channel, chat::apply_chat, ServiceArgs};
 use crate::{
     common::{argmax, tokenizer},
     service::channel::{Query, Response},
+    Template,
 };
 use common::upos;
 use std::{collections::HashMap, path::Path, time::Instant};
-use transformer_cpu::{model_parameters::Memory, LayerCache, Transformer};
+use transformer_cpu::{
+    model_parameters::{Llama2, Memory},
+    LayerCache, Transformer,
+};
 
 pub(super) fn run(args: ServiceArgs) {
+    let template = if args.model.to_ascii_lowercase().contains("tinyllama") {
+        Template::ChatTinyLlama
+    } else {
+        Template::Chat9G
+    };
     let model_dir = Path::new(&args.model);
 
     let time = Instant::now();
     let tokenizer = tokenizer(args.tokenizer, &model_dir);
     info!("build tokenizer ... {:?}", time.elapsed());
 
     let time = Instant::now();
-    let model = Box::new(Memory::load_safetensors_from_dir(model_dir).unwrap());
+    let model = Box::new(Memory::load_safetensors_from_dir(&model_dir).unwrap());
     info!("load model ... {:?}", time.elapsed());
 
+    let _bos = model.bos_token_id();
+    let eos = model.eos_token_id();
+
     let time = Instant::now();
     let mut transformer = Transformer::new(model);
     info!("build transformer ... {:?}", time.elapsed());
@@ -35,14 +47,16 @@ pub(super) fn run(args: ServiceArgs) {
 
     loop {
         let Query { id, prompt } = channel.receive().unwrap();
+        let prompt = apply_chat(&prompt, template);
+
+        let prompt_tokens = tokenizer.encode(&prompt.trim());
+        let (last, tokens) = prompt_tokens.split_last().expect("prompt is empty");
 
         let session = sessions.entry(id).or_insert_with(|| SessionContext {
             pos: 0,
             kv_cache: transformer.new_cache(),
         });
 
-        let prompt_tokens = tokenizer.encode(&prompt.trim());
-        let (last, tokens) = prompt_tokens.split_last().expect("prompt is empty");
         if !tokens.is_empty() {
             transformer.update(tokens, &mut session.kv_cache, session.pos as _);
             session.pos += tokens.len() as upos;
@@ -53,12 +67,13 @@
         let mut out = String::new();
         while session.pos < max_pos {
             let logits = transformer.forward(token, &mut session.kv_cache, session.pos as _);
-            let next = argmax(logits);
+            token = argmax(logits);
+            if token == eos {
+                break;
+            }
 
-            token = next;
+            out.push_str(&tokenizer.decode(token).replace('▁', " "));
             session.pos += 1;
-
-            out.push_str(&tokenizer.decode(next).replace('▁', " "));
         }
 
         channel.send(Response { id, prompt: out }).unwrap();
1 change: 1 addition & 0 deletions xtask/src/service/mod.rs
@@ -1,4 +1,5 @@
 mod channel;
+mod chat;
 mod cpu;
 #[cfg(detected_cuda)]
 mod nvidia;
