Commit
refactor(transformer-nvidia): use the context-spore pattern to completely remove Nvidia Transformer special-casing
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 28, 2024
1 parent 745948f commit aa34b20
Showing 3 changed files with 196 additions and 143 deletions.
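The "context spore" (上下文孢子) pattern named in the title: in these cuda bindings, a resource created inside a context guard borrows the guard's lifetime, and "sporulating" it erases that lifetime so the resource can be stored in owned structures and revived later under another guard for the same context. Below is a minimal sketch of the idea only; every name in it is an assumption made for illustration, modeled on the `sporulate` call visible in the removed code, not the actual `cuda` crate API.

// Illustrative sketch of the context-spore idea; all names here are
// assumptions for exposition, not the real `cuda` crate API.
struct Context;
struct ContextGuard<'ctx>(&'ctx Context);

// A device resource tied to the guard that created it.
struct DevMem<'ctx> {
    ptr: u64,
    _guard: &'ctx ContextGuard<'ctx>,
}

// The "spore": the same payload with its lifetime erased, so it can be
// stored in owned structs (e.g. inside Transformer) across guard scopes.
struct DevMemSpore {
    ptr: u64,
}

impl<'ctx> DevMem<'ctx> {
    // Detach from the guard; the caller promises to revive the resource
    // in the same context before using it again.
    fn sporulate(self) -> DevMemSpore {
        DevMemSpore { ptr: self.ptr }
    }
}

impl DevMemSpore {
    // Re-bind to a live guard of the owning context.
    fn sprout<'ctx>(self, guard: &'ctx ContextGuard<'ctx>) -> DevMem<'ctx> {
        DevMem {
            ptr: self.ptr,
            _guard: guard,
        }
    }
}

fn main() {
    let context = Context;
    let guard = ContextGuard(&context);
    let mem = DevMem { ptr: 0xdead_beef, _guard: &guard };
    let spore = mem.sporulate(); // lifetime erased; storable anywhere
    let _revived = spore.sprout(&guard); // usable again under a live guard
}

With spores, the transformer can own an Arc of the context and manage its own streams internally, which is what removes the Nvidia-specific context and stream plumbing from the service code in the diff below.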
117 changes: 45 additions & 72 deletions service/src/nvidia.rs
@@ -3,16 +3,11 @@ use common::utok
 use std::{
     collections::HashMap,
     fs::File,
-    io::Read,
     path::Path,
     sync::{mpsc::Receiver, Arc, Mutex},
     time::Instant,
 };
-use transformer_cpu::{Llama2, Memory, SampleArgs};
-use transformer_nvidia::{
-    cuda::{ContextResource, Device, Stream},
-    LayerCache, Request, Transformer,
-};
+use transformer_nvidia::{cuda::Device, LayerCache, Llama2, Request, SampleArgs, Transformer};

 pub fn task(
     device: Device,
@@ -25,89 +20,67 @@ pub fn task(

     let time = Instant::now();
     let config = File::open(model_dir.join("config.json")).unwrap();
-    let mut safetensors = File::open(model_dir.join("model.safetensors")).unwrap();
+    let safetensors = File::open(model_dir.join("model.safetensors")).unwrap();
     info!("open file {:?}", time.elapsed());

-    device.context().apply(|ctx| {
-        let time = Instant::now();
-        let host = ctx.malloc_host::<u8>(safetensors.metadata().unwrap().len() as _);
-        let mut host = host.sporulate();
-        safetensors.read_exact(&mut host).unwrap();
-        drop(safetensors);
-        info!("read to host {:?}", time.elapsed());
-
-        let compute = ctx.stream();
-        let transfer = ctx.stream();
-
-        let time = Instant::now();
-        let host = Memory::load_safetensors(config, host, false).unwrap();
-        let max_seq_len = host.max_position_embeddings();
-        let eos = host.eos_token_id();
-        let transformer = Transformer::new(Box::new(host), usize::MAX, &transfer);
-        info!("build model host: {:?}", time.elapsed());
-
-        let mut sessions = HashMap::new();
-
-        while let Ok(cmd) = receiver.recv() {
-            match cmd {
-                Command::Infer {
-                    id,
-                    prompt,
-                    responsing,
-                } => {
-                    let ctx = sessions
-                        .entry(id)
-                        .or_insert_with_key(|&id| SessionContext::new(&transformer, id, &transfer));
-
-                    let t0 = Instant::now();
-                    let mut token = transformer.decode(
-                        vec![ctx.request(&prompt, max_seq_len)],
-                        &sample.lock().unwrap(),
-                        &compute,
-                        &transfer,
-                    )[0]
-                    .1;
-                    let t1 = Instant::now();
-                    let mut len = 0;
-                    while token != eos {
-                        responsing.send(token).unwrap();
-                        token = transformer.decode(
-                            vec![ctx.request(&[token], max_seq_len)],
-                            &sample.lock().unwrap(),
-                            &compute,
-                            &transfer,
-                        )[0]
-                        .1;
-                        len += 1;
-                    }
-                    let t2 = Instant::now();
-                    info!(
-                        "First token delay: {:?}, average speed = {:?}/tok",
-                        t1 - t0,
-                        (t2 - t1).div_f32(len as _)
-                    );
-                }
-                Command::Drop { id } => {
-                    sessions.remove(&id);
-                }
-            }
-        }
-    });
+    let context = Arc::new(device.context());
+    let transformer = Transformer::new(config, safetensors, usize::MAX, context.clone());
+    let mut sessions = HashMap::new();
+
+    let max_seq_len = transformer.model().max_position_embeddings();
+    let eos = transformer.model().eos_token_id();
+    while let Ok(cmd) = receiver.recv() {
+        match cmd {
+            Command::Infer {
+                id,
+                prompt,
+                responsing,
+            } => {
+                let ctx = sessions
+                    .entry(id)
+                    .or_insert_with_key(|&id| SessionContext::new(&transformer, id));
+
+                let t0 = Instant::now();
+                let mut token = transformer.decode(
+                    vec![ctx.request(&prompt, max_seq_len)],
+                    &sample.lock().unwrap(),
+                )[0]
+                .1;
+                let t1 = Instant::now();
+                let mut len = 0;
+                while token != eos {
+                    responsing.send(token).unwrap();
+                    token = transformer.decode(
+                        vec![ctx.request(&[token], max_seq_len)],
+                        &sample.lock().unwrap(),
+                    )[0]
+                    .1;
+                    len += 1;
+                }
+                let t2 = Instant::now();
+                info!(
+                    "First token delay: {:?}, average speed = {:?}/tok",
+                    t1 - t0,
+                    (t2 - t1).div_f32(len as _)
+                );
+            }
+            Command::Drop { id } => {
+                sessions.remove(&id);
+            }
+        }
+    }
 }

-struct SessionContext<'a>(session::SessionContext<LayerCache<'a>>);
+struct SessionContext(session::SessionContext<LayerCache>);

-impl<'a> SessionContext<'a> {
+impl SessionContext {
     #[inline]
-    fn new(transformer: &Transformer, id: usize, stream: &'a Stream) -> Self {
-        Self(session::SessionContext::new(
-            transformer.new_cache(stream),
-            id,
-        ))
+    fn new(transformer: &Transformer, id: usize) -> Self {
+        Self(session::SessionContext::new(transformer.new_cache(), id))
     }

     #[inline]
-    fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> Request<'_, 'a, usize> {
+    fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> Request<'_, usize> {
         let pos = self.0.request(tokens, max_seq_len);
         Request::new(
             self.0.id,
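The payoff shows up in the session table above: the old SessionContext<'a> held a LayerCache<'a> that borrowed a Stream, so the sessions HashMap could only live inside the device.context().apply(..) closure; once the cache owns a handle to its context, the lifetime parameter disappears and sessions can be stored anywhere. A self-contained sketch of the owned-handle shape (illustrative types, not the real crate):

use std::collections::HashMap;
use std::sync::Arc;

struct Context; // stands in for the CUDA context

// The cache owns a handle to its context instead of borrowing a Stream,
// so no lifetime parameter propagates into containers that store it.
struct LayerCache {
    _ctx: Arc<Context>,
}

struct SessionContext(LayerCache);

fn main() {
    let ctx = Arc::new(Context);
    let mut sessions: HashMap<usize, SessionContext> = HashMap::new();
    // Sessions can be created and kept across any scope now.
    sessions.insert(0, SessionContext(LayerCache { _ctx: ctx.clone() }));
    println!("live sessions: {}", sessions.len());
}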
14 changes: 11 additions & 3 deletions transformer-cpu/src/lib.rs
@@ -56,14 +56,22 @@ impl Transformer
         for layer in 0..self.0.num_hidden_layers() {
             let (q, k, v) = self.before_att(layer, &x0, &mut x1, &mut buf.qkv, &pos);
             let o = &mut x1;
-            #[rustfmt::skip]
-            self.attention(layer, &mut requests, q, k, v, o, &mut buf.q_buf, &mut buf.att_buf);
+            self.attention(
+                layer,
+                &mut requests,
+                q,
+                k,
+                v,
+                o,
+                &mut buf.q_buf,
+                &mut buf.att_buf,
+            );
             self.after_att(layer, &mut x0, &mut x1, &mut buf.gate_up);
         }
         // decode
         if requests[0].decode() {
             let x = self.move_decode(&requests, x0);
-            let requests = requests.into_iter().map(Request::id).collect::<Vec<_>>();
+            let requests = requests.into_iter().map(Request::id).collect();
             Sample.sample(sample, requests, self.logits(x))
         } else {
             vec![]
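Besides letting rustfmt format the attention call normally (the old one-liner needed #[rustfmt::skip]), the only other change here drops the ::<Vec<_>> turbofish: the collection type is now inferred from the call that consumes requests. A standalone illustration of that inference:

fn consume(ids: Vec<usize>) -> usize {
    ids.len()
}

fn main() {
    let requests = [10usize, 20, 30];
    // No turbofish needed: `collect()` infers `Vec<usize>` from
    // `consume`'s parameter type.
    let ids = requests.iter().copied().collect();
    println!("{}", consume(ids));
}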
