feat(xtask): set up the basic structure of the service
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 1, 2024
1 parent 0c2f76b commit 78c8001
Showing 14 changed files with 209 additions and 60 deletions.
1 change: 1 addition & 0 deletions .cargo/config.toml
@@ -3,3 +3,4 @@ xtask = "run --package xtask --release --"
debug = "run --package xtask --"
cast = "xtask cast"
generate = "xtask generate"
service = "xtask service"
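With this alias in place, the service can presumably be started as cargo service -- --model <model-dir>, with the optional --tokenizer, --log and --nvidia flags defined in xtask/src/service/mod.rs below, mirroring the existing cargo generate alias.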
1 change: 1 addition & 0 deletions tokenizer/src/bpe.rs
@@ -91,6 +91,7 @@ impl Tokenizer for BPE {
}

fn encode(&self, text: &str) -> Vec<utok> {
let text = text.replace(' ', "▁"); // FIXME: read the normalizer from tokenizer.json
let mut tokens = Vec::new();

text.chars().map(|c| c.to_string()).for_each(|c| {
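
The added first line of encode() applies the ▁ whitespace substitution used by SentencePiece-style tokenizers before the byte-pair merge loop runs (previously the caller in generate.rs did this). A standalone illustration of the effect, with a made-up prompt:

fn main() {
    // Same replacement encode() now performs up front.
    let prompt = "The quick brown fox";
    let normalized = prompt.replace(' ', "▁");
    assert_eq!(normalized, "The▁quick▁brown▁fox");
}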
5 changes: 5 additions & 0 deletions transformer-cpu/src/lib.rs
@@ -34,6 +34,11 @@ impl Transformer {
LayerCache::new_layers(&*self.model)
}

#[inline]
pub fn max_seq_len(&self) -> usize {
self.model.max_position_embeddings()
}

pub fn update(&self, tokens: &[utok], cache: &mut [LayerCache], pos: upos) -> Tensor<Storage> {
let seq_len = tokens.len() as udim;
let d = self.model.hidden_size() as udim;
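
The new max_seq_len() accessor exposes the model's maximum position embeddings without callers having to reach into the model directly; the generate.rs loop and the new service/cpu.rs below both use it to bound how many positions they generate.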
2 changes: 1 addition & 1 deletion transformer-nvidia/src/kernel/fused_softmax.rs
@@ -69,7 +69,7 @@ extern "C" __global__ void {folding}(
unreachable!();
};

let grid_dims = (seq_len, nh);
let grid_dims = (nh, seq_len);
let (kernel, block_dims) = if att_len <= self.block_size {
(&self.padding, att_len)
} else {
4 changes: 2 additions & 2 deletions transformer-nvidia/src/kernel/reform.rs
@@ -91,8 +91,8 @@ extern "C" __global__ void {name}(
];

let max_warp_per_block = self.block_size / self.warp_size;
let grid_dims = ((c + max_warp_per_block - 1) / max_warp_per_block, r);
let block_dims = (self.warp_size, (c + grid_dims.0 - 1) / grid_dims.0);
let grid_dims = (r, (c + max_warp_per_block - 1) / max_warp_per_block);
let block_dims = ((c + grid_dims.1 - 1) / grid_dims.1, self.warp_size);
self.f
.launch(grid_dims, block_dims, params.as_ptr(), 0, Some(stream));
}
2 changes: 1 addition & 1 deletion transformer-nvidia/src/kernel/rotary_embedding.rs
@@ -62,6 +62,6 @@ extern "C" __global__ void {name}(
];

self.f
.launch((n, nh), dh / 2, params.as_ptr(), 0, Some(stream))
.launch((nh, n), dh / 2, params.as_ptr(), 0, Some(stream))
}
}
4 changes: 2 additions & 2 deletions transformer-nvidia/src/kernel/swiglu.rs
@@ -67,8 +67,8 @@ extern "C" __global__ void {name}(
a
}

let block_dims = gcd(di, self.block_size);
let grid_dims = (di / block_dims, seq_len);
let block_dims = gcd(self.block_size, di);
let grid_dims = (seq_len, di / block_dims);
self.f
.launch(grid_dims, block_dims, params.as_ptr(), 0, Some(stream));
}
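
Across fused_softmax, reform, rotary_embedding and swiglu the launch-dimension tuples are flipped the same way, presumably to match the component order expected by the launch helper; only the ordering changes, not the launch volume. A quick standalone check of that invariant for the swiglu case, with hypothetical sizes:

fn gcd(mut a: u32, mut b: u32) -> u32 {
    while b != 0 {
        let r = a % b;
        a = b;
        b = r;
    }
    a
}

fn main() {
    // Hypothetical sizes: di = 11008 (FFN width), seq_len = 7, block limit = 1024.
    let (di, seq_len, max_block) = (11008u32, 7u32, 1024u32);
    let block_dims = gcd(max_block, di); // gcd is commutative, so identical to gcd(di, max_block)
    let old_grid = (di / block_dims, seq_len);
    let new_grid = (seq_len, di / block_dims);
    // Reordering swaps the components but launches the same number of blocks.
    assert_eq!(old_grid.0 * old_grid.1, new_grid.0 * new_grid.1);
}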
4 changes: 2 additions & 2 deletions transformer-nvidia/src/lib.rs
@@ -52,8 +52,8 @@ impl<'a> Transformer<'a> {
cublas!(cublasCreate_v2(&mut cublas_handle));

let ctx = stream.ctx();
let _dev = ctx.dev();
let block_size = 1024;
let dev = ctx.dev();
let (block_size, _) = dev.max_block_dims();
Self {
model: ModelParameters::new(host, stream),
layers: LayersParameters::new(load_layers, host, stream),
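
Instead of hardcoding a block size of 1024, the constructor now queries the device attached to the stream's context for its maximum block dimensions and takes the first component as block_size, presumably so the kernel launches above are sized against the actual hardware limit rather than an assumed one.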
51 changes: 51 additions & 0 deletions xtask/src/common.rs
@@ -0,0 +1,51 @@
use common::utok;
use log::LevelFilter;
use simple_logger::SimpleLogger;
use std::{io::ErrorKind::NotFound, path::Path};
use tokenizer::{Tokenizer, VocabTxt, BPE};

pub(crate) fn logger_init(log_level: Option<String>) {
let log = log_level
.and_then(|log| match log.to_lowercase().as_str() {
"off" | "none" => Some(LevelFilter::Off),
"trace" => Some(LevelFilter::Trace),
"debug" => Some(LevelFilter::Debug),
"info" => Some(LevelFilter::Info),
"error" => Some(LevelFilter::Error),
_ => None,
})
.unwrap_or(LevelFilter::Warn);
SimpleLogger::new().with_level(log).init().unwrap();
}

pub(crate) fn tokenizer(path: Option<String>, model_dir: impl AsRef<Path>) -> Box<dyn Tokenizer> {
match path {
Some(path) => match Path::new(&path).extension() {
Some(ext) if ext == "txt" => Box::new(VocabTxt::from_txt_file(path).unwrap()),
Some(ext) if ext == "model" => Box::new(BPE::from_model_file(path).unwrap()),
_ => panic!("Tokenizer file {path:?} not supported"),
},
None => {
match BPE::from_model_file(model_dir.as_ref().join("tokenizer.model")) {
Ok(bpe) => return Box::new(bpe),
Err(e) if e.kind() == NotFound => {}
Err(e) => panic!("{e:?}"),
}
match VocabTxt::from_txt_file(model_dir.as_ref().join("vocabs.txt")) {
Ok(voc) => return Box::new(voc),
Err(e) if e.kind() == NotFound => {}
Err(e) => panic!("{e:?}"),
}
panic!("Tokenizer file not found");
}
}
}

pub(crate) fn argmax<T: PartialOrd>(logits: &[T]) -> utok {
logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as _
}
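
These helpers were previously private to generate.rs; hoisting them into a shared module lets the new service path reuse the same logger setup, tokenizer selection and greedy sampling. A small in-crate sanity check of argmax, with illustrative values:

#[test]
fn argmax_picks_the_largest_logit() {
    let logits: &[f32] = &[0.1, 2.5, -0.3, 1.7];
    assert_eq!(argmax(logits), 1); // index of 2.5, returned as utok
}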
58 changes: 6 additions & 52 deletions xtask/src/generate.rs
@@ -1,16 +1,14 @@
use common::utok;
use log::LevelFilter;
use simple_logger::SimpleLogger;
use crate::common::{argmax, logger_init, tokenizer};
use std::{
alloc::Layout,
collections::HashMap,
io::{ErrorKind, Write},
io::Write,
path::{Path, PathBuf},
ptr::NonNull,
sync::Mutex,
time::Instant,
};
use tokenizer::{Tokenizer, VocabTxt, BPE};
use tokenizer::Tokenizer;
use transformer_cpu::{
model_parameters::{Allocator, Llama2, Memory},
Transformer,
@@ -75,18 +73,7 @@ impl Allocator for NormalAllocator {

impl GenerateArgs {
pub fn invoke(self) {
let log = self
.log
.and_then(|log| match log.to_lowercase().as_str() {
"off" | "none" => Some(LevelFilter::Off),
"trace" => Some(LevelFilter::Trace),
"debug" => Some(LevelFilter::Debug),
"info" => Some(LevelFilter::Info),
"error" => Some(LevelFilter::Error),
_ => None,
})
.unwrap_or(LevelFilter::Warn);
SimpleLogger::new().with_level(log).init().unwrap();
logger_init(self.log);

let model_dir = PathBuf::from(self.model);
let step = self.step.unwrap_or(usize::MAX);
@@ -114,29 +101,6 @@ impl GenerateArgs {
}
}

fn tokenizer(path: Option<String>, model_dir: impl AsRef<Path>) -> Box<dyn Tokenizer> {
match path {
Some(path) => match Path::new(&path).extension() {
Some(ext) if ext == "txt" => Box::new(VocabTxt::from_txt_file(path).unwrap()),
Some(ext) if ext == "model" => Box::new(BPE::from_model_file(path).unwrap()),
_ => panic!("Tokenizer file {path:?} not supported"),
},
None => {
match BPE::from_model_file(model_dir.as_ref().join("tokenizer.model")) {
Ok(bpe) => return Box::new(bpe),
Err(e) if e.kind() == ErrorKind::NotFound => {}
Err(e) => panic!("{e:?}"),
}
match VocabTxt::from_txt_file(model_dir.as_ref().join("vocabs.txt")) {
Ok(voc) => return Box::new(voc),
Err(e) if e.kind() == ErrorKind::NotFound => {}
Err(e) => panic!("{e:?}"),
}
panic!("Tokenizer file not found");
}
}
}

fn on_host(
model_dir: impl AsRef<Path>,
tokenizer: Box<dyn Tokenizer>,
@@ -157,14 +121,13 @@ fn on_host(
model = Box::new(Memory::realloc_with(&*model, allocator));
info!("copy model ... {:?}", time.elapsed());
}
let step = step.min(model.max_position_embeddings());
let time = Instant::now();
let mut transformer = Transformer::new(model);
let mut kv_cache = transformer.new_cache();
info!("build transformer ... {:?}", time.elapsed());

let time = Instant::now();
let prompt_tokens = tokenizer.encode(&prompt.trim().replace(' ', "▁"));
let prompt_tokens = tokenizer.encode(&prompt.trim());
info!("encode prompt ... {:?}", time.elapsed());

let time = Instant::now();
@@ -179,7 +142,7 @@ fn on_host(
let mut token = *last;
let mut pos = tokens.len();
let time = Instant::now();
while pos < step {
while pos < step.min(transformer.max_seq_len()) {
let logits = transformer.forward(token, &mut kv_cache, pos as _);
let next = argmax(logits);

@@ -301,12 +264,3 @@ fn on_nvidia_gpu(
)
});
}

fn argmax<T: PartialOrd>(logits: &[T]) -> utok {
logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as _
}
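
The remaining deletions in this file are pure moves: the logger setup, tokenizer selection and argmax now live in xtask/src/common.rs, the ▁ replacement moved into BPE::encode, and the generation loop bounds itself with transformer.max_seq_len() instead of pre-clamping step against the raw model.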
5 changes: 5 additions & 0 deletions xtask/src/main.rs
@@ -1,5 +1,7 @@
mod cast;
mod common;
mod generate;
mod service;

use clap::Parser;

@@ -13,6 +15,7 @@ fn main() {
match Cli::parse().command {
Cast(cast) => cast.invode(),
Generate(generate) => generate.invoke(),
Service(service) => service.launch(),
}
}

@@ -30,4 +33,6 @@ enum Commands {
Cast(cast::CastArgs),
/// Generate following text
Generate(generate::GenerateArgs),
/// Start LLM inference service
Service(service::ServiceArgs),
}
80 changes: 80 additions & 0 deletions xtask/src/service/cpu.rs
@@ -0,0 +1,80 @@
use super::ServiceArgs;
use crate::common::{argmax, logger_init, tokenizer};
use common::upos;
use std::{collections::HashMap, path::Path, time::Instant};
use tokenizer::Tokenizer;
use transformer_cpu::{model_parameters::Memory, LayerCache, Transformer};

pub(super) struct CpuService {
transformer: Transformer,
sessions: HashMap<usize, SessionContext>,
tokenizer: Box<dyn Tokenizer>,
}

struct SessionContext {
pos: upos,
kv_cache: Vec<LayerCache>,
}

impl From<ServiceArgs> for CpuService {
fn from(args: ServiceArgs) -> Self {
logger_init(args.log);

let model_dir = Path::new(&args.model);

let time = Instant::now();
let tokenizer = tokenizer(args.tokenizer, &model_dir);
info!("build tokenizer ... {:?}", time.elapsed());

let time = Instant::now();
let model = Box::new(Memory::load_safetensors_from_dir(model_dir).unwrap());
info!("load model ... {:?}", time.elapsed());

let time = Instant::now();
let transformer = Transformer::new(model);
info!("build transformer ... {:?}", time.elapsed());

Self {
transformer,
sessions: HashMap::new(),
tokenizer,
}
}
}

impl CpuService {
pub fn run(mut self) {
loop {
let id = 0;
let prompt = "The quick brown fox jumps over the lazy dog";

let session = self.sessions.entry(id).or_insert_with(|| SessionContext {
pos: 0,
kv_cache: self.transformer.new_cache(),
});

let prompt_tokens = self.tokenizer.encode(&prompt.trim());
let (last, tokens) = prompt_tokens.split_last().expect("prompt is empty");
if !tokens.is_empty() {
self.transformer
.update(tokens, &mut session.kv_cache, session.pos as _);
session.pos += tokens.len() as upos;
}

let mut token = *last;
let max_pos = self.transformer.max_seq_len() as upos;
let mut out = String::new();
while session.pos < max_pos {
let logits =
self.transformer
.forward(token, &mut session.kv_cache, session.pos as _);
let next = argmax(logits);

token = next;
session.pos += 1;

out.push_str(&self.tokenizer.decode(next).replace('▁', " "));
}
}
}
}
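
As committed, run() is a scaffold rather than a real server: it always reuses session id 0, encodes the same hard-coded prompt, and collects the decoded text into out without ever returning or printing it. What it does exercise is the per-session state the service will need, a SessionContext holding the KV cache and the running position, fed through update() for the prompt and forward() for each generated token. Presumably later commits connect this loop to real requests.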
37 changes: 37 additions & 0 deletions xtask/src/service/mod.rs
@@ -0,0 +1,37 @@
mod cpu;
#[cfg(detected_cuda)]
mod nvidia;

#[derive(Args, Default)]
pub(crate) struct ServiceArgs {
/// Model directory.
#[clap(short, long)]
model: String,
/// Tokenizer file.
#[clap(short, long)]
tokenizer: Option<String>,
/// Log level, may be "off", "trace", "debug", "info" or "error".
#[clap(long)]
log: Option<String>,

/// Use Nvidia GPU.
#[clap(long)]
nvidia: bool,
}

impl ServiceArgs {
pub fn launch(self) {
if self.nvidia {
#[cfg(detected_cuda)]
{
nvidia::NvidiaService::from(self).run();
}
#[cfg(not(detected_cuda))]
{
panic!("Nvidia GPU is not available");
}
} else {
cpu::CpuService::from(self).run();
}
}
}
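
For orientation, a minimal sketch of driving ServiceArgs from inside the crate without going through clap; the model path is hypothetical, and the Default derive is assumed to leave nvidia as false and both Options as None:

fn main() {
    // Hypothetical in-process launch; normally clap fills these fields from `cargo service -- ...`.
    let args = ServiceArgs {
        model: "models/llama2-7b".into(), // hypothetical model directory
        ..Default::default()              // tokenizer: None, log: None, nvidia: false
    };
    args.launch(); // nvidia == false, so this dispatches to cpu::CpuService::from(args).run()
}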
15 changes: 15 additions & 0 deletions xtask/src/service/nvidia.rs
@@ -0,0 +1,15 @@
use super::ServiceArgs;

pub(super) struct NvidiaService {}

impl From<ServiceArgs> for NvidiaService {
fn from(_: ServiceArgs) -> Self {
todo!()
}
}

impl NvidiaService {
pub fn run(self) {
todo!()
}
}
