From aa34b205bda0daa125f9b7b7c0f856d31824bcae Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Thu, 28 Mar 2024 15:45:39 +0800
Subject: [PATCH] refactor(transformer-nvidia): use the context-spore pattern
 to completely remove the Nvidia Transformer's special-casing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster

---
 service/src/nvidia.rs         | 117 +++++++------------
 transformer-cpu/src/lib.rs    |  14 ++-
 transformer-nvidia/src/lib.rs | 208 +++++++++++++++++++++++-----------
 3 files changed, 196 insertions(+), 143 deletions(-)

diff --git a/service/src/nvidia.rs b/service/src/nvidia.rs
index 2fc97602..234857c7 100644
--- a/service/src/nvidia.rs
+++ b/service/src/nvidia.rs
@@ -3,16 +3,11 @@ use common::utok;
 use std::{
     collections::HashMap,
     fs::File,
-    io::Read,
     path::Path,
     sync::{mpsc::Receiver, Arc, Mutex},
     time::Instant,
 };
-use transformer_cpu::{Llama2, Memory, SampleArgs};
-use transformer_nvidia::{
-    cuda::{ContextResource, Device, Stream},
-    LayerCache, Request, Transformer,
-};
+use transformer_nvidia::{cuda::Device, LayerCache, Llama2, Request, SampleArgs, Transformer};
 
 pub fn task(
     device: Device,
@@ -25,89 +20,67 @@ pub fn task(
     let time = Instant::now();
 
     let config = File::open(model_dir.join("config.json")).unwrap();
-    let mut safetensors = File::open(model_dir.join("model.safetensors")).unwrap();
+    let safetensors = File::open(model_dir.join("model.safetensors")).unwrap();
     info!("open file {:?}", time.elapsed());
 
-    device.context().apply(|ctx| {
-        let time = Instant::now();
-        let host = ctx.malloc_host::<u8>(safetensors.metadata().unwrap().len() as _);
-        let mut host = host.sporulate();
-        safetensors.read_exact(&mut host).unwrap();
-        drop(safetensors);
-        info!("read to host {:?}", time.elapsed());
-
-        let compute = ctx.stream();
-        let transfer = ctx.stream();
-
-        let time = Instant::now();
-        let host = Memory::load_safetensors(config, host, false).unwrap();
-        let max_seq_len = host.max_position_embeddings();
-        let eos = host.eos_token_id();
-        let transformer = Transformer::new(Box::new(host), usize::MAX, &transfer);
-        info!("build model host: {:?}", time.elapsed());
-
-        let mut sessions = HashMap::new();
+    let context = Arc::new(device.context());
+    let transformer = Transformer::new(config, safetensors, usize::MAX, context.clone());
+    let mut sessions = HashMap::new();
 
-        while let Ok(cmd) = receiver.recv() {
-            match cmd {
-                Command::Infer {
-                    id,
-                    prompt,
-                    responsing,
-                } => {
-                    let ctx = sessions
-                        .entry(id)
-                        .or_insert_with_key(|&id| SessionContext::new(&transformer, id, &transfer));
+    let max_seq_len = transformer.model().max_position_embeddings();
+    let eos = transformer.model().eos_token_id();
+    while let Ok(cmd) = receiver.recv() {
+        match cmd {
+            Command::Infer {
+                id,
+                prompt,
+                responsing,
+            } => {
+                let ctx = sessions
+                    .entry(id)
+                    .or_insert_with_key(|&id| SessionContext::new(&transformer, id));
 
-                    let t0 = Instant::now();
-                    let mut token = transformer.decode(
-                        vec![ctx.request(&prompt, max_seq_len)],
+                let t0 = Instant::now();
+                let mut token = transformer.decode(
+                    vec![ctx.request(&prompt, max_seq_len)],
+                    &sample.lock().unwrap(),
+                )[0]
+                .1;
+                let t1 = Instant::now();
+                let mut len = 0;
+                while token != eos {
+                    responsing.send(token).unwrap();
+                    token = transformer.decode(
+                        vec![ctx.request(&[token], max_seq_len)],
                         &sample.lock().unwrap(),
-                        &compute,
-                        &transfer,
                     )[0]
                     .1;
-                    let t1 = Instant::now();
-                    let mut len = 0;
-                    while token != eos {
-                        responsing.send(token).unwrap();
-                        token = transformer.decode(
-                            vec![ctx.request(&[token], max_seq_len)],
-                            &sample.lock().unwrap(),
-                            &compute,
-                            &transfer,
-                        )[0]
-                        .1;
-                        len += 1;
-                    }
-                    let t2 = Instant::now();
-                    info!(
-                        "First token delay: {:?}, average speed = {:?}/tok",
-                        t1 - t0,
-                        (t2 - t1).div_f32(len as _)
-                    );
-                }
-                Command::Drop { id } => {
-                    sessions.remove(&id);
+                    len += 1;
                 }
+                let t2 = Instant::now();
+                info!(
+                    "First token delay: {:?}, average speed = {:?}/tok",
+                    t1 - t0,
+                    (t2 - t1).div_f32(len as _)
+                );
+            }
+            Command::Drop { id } => {
+                sessions.remove(&id);
             }
         }
-    });
+    }
 }
 
-struct SessionContext<'a>(session::SessionContext<LayerCache<'a>>);
+struct SessionContext(session::SessionContext<LayerCache>);
 
-impl<'a> SessionContext<'a> {
+impl SessionContext {
     #[inline]
-    fn new(transformer: &Transformer, id: usize, stream: &'a Stream) -> Self {
-        Self(session::SessionContext::new(
-            transformer.new_cache(stream),
-            id,
-        ))
+    fn new(transformer: &Transformer, id: usize) -> Self {
+        Self(session::SessionContext::new(transformer.new_cache(), id))
     }
 
     #[inline]
-    fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> Request<'_, 'a, usize> {
+    fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> Request<'_, usize> {
         let pos = self.0.request(tokens, max_seq_len);
         Request::new(
             self.0.id,
diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs
index 33f00468..f5a8e19b 100644
--- a/transformer-cpu/src/lib.rs
+++ b/transformer-cpu/src/lib.rs
@@ -56,14 +56,22 @@ impl Transformer {
         for layer in 0..self.0.num_hidden_layers() {
             let (q, k, v) = self.before_att(layer, &x0, &mut x1, &mut buf.qkv, &pos);
             let o = &mut x1;
-            #[rustfmt::skip]
-            self.attention(layer, &mut requests, q, k, v, o, &mut buf.q_buf, &mut buf.att_buf);
+            self.attention(
+                layer,
+                &mut requests,
+                q,
+                k,
+                v,
+                o,
+                &mut buf.q_buf,
+                &mut buf.att_buf,
+            );
             self.after_att(layer, &mut x0, &mut x1, &mut buf.gate_up);
         }
         // Decode
         if requests[0].decode() {
             let x = self.move_decode(&requests, x0);
-            let requests = requests.into_iter().map(Request::id).collect::<Vec<_>>();
+            let requests = requests.into_iter().map(Request::id).collect();
             Sample.sample(sample, requests, self.logits(x))
         } else {
             vec![]
diff --git a/transformer-nvidia/src/lib.rs b/transformer-nvidia/src/lib.rs
index cf64fbef..5b78e835 100644
--- a/transformer-nvidia/src/lib.rs
+++ b/transformer-nvidia/src/lib.rs
@@ -10,24 +10,29 @@ extern crate log;
 use common::utok;
 use cublas::{Cublas, CublasSpore};
-use cuda::{AsRaw, ContextResource, ContextSpore, CudaDataType::half, Stream};
+use cuda::{
+    AsRaw, Context, ContextResource, ContextSpore, CudaDataType::half, DevMemSpore, Stream,
+    StreamSpore,
+};
 use kernel::{gather, mat_mul, FusedSoftmax, Reform, RmsNormalization, RotaryEmbedding, Swiglu};
 use parameters::{LayerParameter, LayersParameters, ModelParameters};
-use std::sync::Mutex;
+use std::{cell::RefCell, fs::File, io::Read, sync::Arc, time::Instant};
 use storage::Storage;
 use tensor::{slice, udim, DataType, Tensor};
-use transformer::{pos, LayerBuffer, Sample as _, SampleArgs};
+use transformer::{pos, LayerBuffer, Sample as _};
 
-pub type Request<'a, 'b, Id> = transformer::Request<'a, Id, Storage<'b>>;
-pub type LayerCache<'a> = transformer::LayerCache<Storage<'a>>;
+pub type Request<'a, Id> = transformer::Request<'a, Id, DevMemSpore>;
+pub type LayerCache = transformer::LayerCache<DevMemSpore>;
 
 pub use sample::Sample;
-pub use transformer::{Llama2, Memory};
+pub use transformer::{Llama2, Memory, SampleArgs};
 pub extern crate cuda;
 
 pub struct Transformer {
-    host: Box<dyn Llama2>,
+    context: Arc<Context>,
+    transfer: StreamSpore,
+    host: Memory<'static>,
     model: ModelParameters,
-    layers: Mutex<LayersParameters>,
+    layers: RefCell<LayersParameters>,
     cublas: CublasSpore,
     rms_norm: RmsNormalization,
     rotary_embedding: RotaryEmbedding,
@@ -37,76 +42,141 @@ pub struct Transformer {
 }
 
 impl Transformer {
-    pub fn new(host: Box<dyn Llama2>, preload_layers: usize, stream: &Stream) -> Self {
+    pub fn new(
+        config: File,
+        mut safetensors: File,
+        preload_layers: usize,
+        context: Arc<Context>,
+    ) -> Self {
+        let time = Instant::now();
+        let mut host = context.apply(|ctx| {
+            ctx.malloc_host::<u8>(safetensors.metadata().unwrap().len() as _)
+                .sporulate()
+        });
+        safetensors.read_exact(&mut host).unwrap();
+        drop(safetensors);
+        info!("read to host {:?}", time.elapsed());
+
+        let host = Memory::load_safetensors(config, host, false).unwrap();
         let load_layers = preload_layers.min(host.num_hidden_layers());
-        let ctx = stream.ctx();
-        let dev = ctx.dev();
-        let (block_size, _) = dev.max_block_dims();
+
+        let (
+            model,
+            layers,
+            cublas,
+            rms_norm,
+            rotary_embedding,
+            reform,
+            fused_softmax,
+            swiglu,
+            transfer,
+        ) = context.apply(|ctx| {
+            let dev = ctx.dev();
+            let (block_size, _) = dev.max_block_dims();
+            let stream = ctx.stream();
+
+            (
+                ModelParameters::new(&host, &stream),
+                RefCell::new(LayersParameters::new(load_layers, &host, &stream)),
+                Cublas::new(ctx).sporulate(),
+                RmsNormalization::new(half, host.hidden_size(), block_size, ctx),
+                RotaryEmbedding::new(block_size, ctx),
+                Reform::new(block_size, 32, ctx),
+                FusedSoftmax::new(half, host.max_position_embeddings(), block_size, ctx),
+                Swiglu::new(half, block_size, ctx),
+                stream.sporulate(),
+            )
+        });
+
         Self {
-            model: ModelParameters::new(&*host, stream),
-            layers: Mutex::new(LayersParameters::new(load_layers, &*host, stream)),
-            cublas: Cublas::new(ctx).sporulate(),
-            rms_norm: RmsNormalization::new(half, host.hidden_size(), block_size, ctx),
-            rotary_embedding: RotaryEmbedding::new(block_size, ctx),
-            reform: Reform::new(block_size, 32, ctx),
-            fused_softmax: FusedSoftmax::new(half, host.max_position_embeddings(), block_size, ctx),
-            swiglu: Swiglu::new(half, block_size, ctx),
+            context,
+            transfer,
             host,
+            model,
+            layers,
+            cublas,
+            rms_norm,
+            rotary_embedding,
+            reform,
+            fused_softmax,
+            swiglu,
         }
     }
 
     #[inline]
-    pub fn new_cache<'b>(&self, stream: &'b Stream) -> Vec<LayerCache<'b>> {
-        LayerCache::new_layers(&*self.host, |dt, shape| tensor(dt, shape, stream))
+    pub fn new_cache(&self) -> Vec<LayerCache> {
+        self.context.apply(|ctx| {
+            let stream = unsafe { self.transfer.sprout(ctx) };
+            LayerCache::new_layers(&self.host, |dt, shape| {
+                let len = shape.iter().product::<udim>() as usize * dt.size();
+                Tensor::new(dt, shape, stream.malloc::<u8>(len).sporulate())
+            })
+        })
+    }
+
+    #[inline]
+    pub fn model(&self) -> &impl Llama2 {
+        &self.host
     }
 
     pub fn decode<'ctx, Id>(
         &self,
         mut requests: Vec<Request<Id>>,
         sample: &SampleArgs,
-        compute: &Stream<'ctx>,
-        transfer: &Stream<'ctx>,
     ) -> Vec<(Id, utok)> {
-        let ctx = compute.ctx();
-        unsafe { self.cublas.sprout(ctx) }.set_stream(compute);
-        // Gather all purely-decode requests at the front to reduce the copy overhead of batched decoding
-        requests.sort_unstable_by_key(Request::purely_decode);
-        // Build token embeddings and pre-allocate buffers
-        let mut x0 = self.token_embed(&requests, compute);
-        let mut x1 = tensor(x0.data_type(), x0.shape(), transfer);
-        let mut buf =
-            LayerBuffer::alloc(&*self.host, &requests, |size| Storage::new(size, transfer));
-        // Build the position tensor
-        let nt = x0.shape()[0]; // `nt` for number of tokens
-        let pos_ = pos(&requests, nt);
-        let mut pos = tensor(DataType::U32, &[nt], transfer);
-        pos.physical_mut().copy_in_async(&pos_, transfer);
-        // Inference
-        compute.wait_for(&transfer.record());
-        {
-            // Rolling loading of layer parameters is stateful and must be owned by a single control flow; the rest is stateless and can run concurrently on multiple streams
-            let mut layers = self.layers.lock().unwrap();
-            for layer in 0..self.host.num_hidden_layers() {
-                let params = {
-                    layers.load(layer, &*self.host, transfer);
-                    layers.sync(layer, compute)
-                };
-
-                let (q, k, v) = self.before_att(params, &x0, &mut x1, &mut buf.qkv, &pos, compute);
-                let o = &mut x1;
-                #[rustfmt::skip]
-                self.attention(layer, &mut requests, q, k, v, o, &mut buf.q_buf, &mut buf.att_buf, compute);
-                self.after_att(params, &mut x0, &mut x1, &mut buf.gate_up, compute);
+        self.context.apply(|ctx| {
+            let transfer = unsafe { self.transfer.sprout(ctx) };
+            let compute = ctx.stream();
+            unsafe { self.cublas.sprout(ctx) }.set_stream(&compute);
+            // Gather all purely-decode requests at the front to reduce the copy overhead of batched decoding
+            requests.sort_unstable_by_key(Request::purely_decode);
+            // Build token embeddings and pre-allocate buffers
+            let mut x0 = self.token_embed(&requests, &compute);
+            let mut x1 = tensor(x0.data_type(), x0.shape(), &transfer);
+            let mut buf =
+                LayerBuffer::alloc(&self.host, &requests, |size| Storage::new(size, &transfer));
+            // Build the position tensor
+            let nt = x0.shape()[0]; // `nt` for number of tokens
+            let pos_ = pos(&requests, nt);
+            let mut pos = tensor(DataType::U32, &[nt], &transfer);
+            pos.physical_mut().copy_in_async(&pos_, &transfer);
+            // Inference
+            compute.wait_for(&transfer.record());
+            {
+                // Rolling loading of layer parameters is stateful and must be owned by a single control flow; the rest is stateless and can run concurrently on multiple streams
+                let mut layers = self.layers.borrow_mut();
+                for layer in 0..self.host.num_hidden_layers() {
+                    let params = {
+                        layers.load(layer, &self.host, &transfer);
+                        layers.sync(layer, &compute)
+                    };
+
+                    let (q, k, v) =
+                        self.before_att(params, &x0, &mut x1, &mut buf.qkv, &pos, &compute);
+                    let o = &mut x1;
+                    self.attention(
+                        layer,
+                        &mut requests,
+                        q,
+                        k,
+                        v,
+                        o,
+                        &mut buf.q_buf,
+                        &mut buf.att_buf,
+                        &compute,
+                    );
+                    self.after_att(params, &mut x0, &mut x1, &mut buf.gate_up, &compute);
+                }
             }
-        }
-        // Decode
-        if requests[0].decode() {
-            let x = self.move_decode(&requests, x0, compute);
-            let requests = requests.into_iter().map(Request::id).collect::<Vec<_>>();
-            Sample.sample(sample, requests, self.logits(x, compute))
-        } else {
-            vec![]
-        }
+            // Decode
+            if requests[0].decode() {
+                let x = self.move_decode(&requests, x0, &compute);
+                let requests = requests.into_iter().map(Request::id).collect();
+                Sample.sample(sample, requests, self.logits(x, &compute))
+            } else {
+                vec![]
+            }
+        })
     }
     fn token_embed<'ctx, Id>(
         &self,
@@ -199,7 +269,8 @@ impl Transformer {
         let dh = d / nh;
         let head_group = nh / nkvh;
         let head_div = (dh as f32).sqrt().recip();
-        let cublas = unsafe { self.cublas.sprout(compute.ctx()) };
+        let ctx = compute.ctx();
+        let cublas = unsafe { self.cublas.sprout(ctx) };
 
         let q = q.as_ref().transpose(&[1, 0, 2]);
         let k = k.as_ref().transpose(&[1, 0, 2]);
@@ -229,6 +300,9 @@ impl Transformer {
             let mut q_att = Tensor::new(dt, &[nh, seq_len, dh], &mut **q_buf);
             let (k_cache, v_cache) = r.cache(layer);
 
+            let mut k_cache = unsafe { k_cache.as_mut().map_physical(|s| s.sprout(ctx)) };
+            let mut v_cache = unsafe { v_cache.as_mut().map_physical(|s| s.sprout(ctx)) };
+
             let k_cat = k_cache.as_mut().slice(cat_slice);
             let v_cat = v_cache.as_mut().slice(cat_slice);
             let mut k_cat = unsafe { k_cat.map_physical(|u| &mut **u) };
@@ -238,10 +312,8 @@ impl Transformer {
             self.reform.launch(&mut v_cat, &v, compute);
 
             let q_att = q_att.reshape(&[nkvh, head_group * seq_len, dh]);
-            let k_att = k_cache.as_ref().slice(att_slice).transpose(&[0, 2, 1]);
-            let v_att = v_cache.as_ref().slice(att_slice);
-            let k_att = unsafe { k_att.map_physical(|u| &**u) };
-            let v_att = unsafe { v_att.map_physical(|u| &**u) };
+            let k_att = k_cache.slice(att_slice).transpose(&[0, 2, 1]);
+            let v_att = v_cache.slice(att_slice);
             // println!("layer {layer} q attention:\n{}", q_att);
             // println!("layer {layer} k attention:\n{}", k_att.access());
             // println!("layer {layer} v attention:\n{}", v_att.access());
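
The core of this refactor is the context-spore pattern named in the subject: each context-bound CUDA resource (stream, device memory, cuBLAS handle) is detached from its context with sporulate(), stored in an ordinary struct field with no borrowed lifetime, and temporarily re-attached with sprout(ctx) inside a later context.apply(..) call, so callers no longer have to thread Stream/Context handles through the API. Below is a minimal sketch of the pattern using only the calls that appear in this patch; the Worker type, its methods, and the exact cuda-crate signatures are illustrative assumptions, not part of the codebase.

    use cuda::{Context, ContextResource, ContextSpore, DevMemSpore, StreamSpore};
    use std::sync::Arc;

    /// Illustrative only: owns context-free "spores", so the struct itself
    /// carries no lifetime tied to a live CUDA context.
    struct Worker {
        context: Arc<Context>,
        stream: StreamSpore,
        buffer: DevMemSpore,
    }

    impl Worker {
        /// Create resources inside `apply`, then detach ("sporulate") them
        /// so they can be stored in plain fields.
        fn new(context: Arc<Context>, len: usize) -> Self {
            let (stream, buffer) = context.apply(|ctx| {
                let stream = ctx.stream();
                let buffer = stream.malloc::<u8>(len).sporulate();
                (stream.sporulate(), buffer)
            });
            Self { context, stream, buffer }
        }

        /// Re-enter the context and "sprout" the spores back into live,
        /// context-bound resources for the duration of the closure
        /// (mirrors how `Transformer::decode` re-binds `self.transfer`).
        fn step(&self) {
            self.context.apply(|ctx| {
                let stream = unsafe { self.stream.sprout(ctx) };
                let buffer = unsafe { self.buffer.sprout(ctx) };
                // ... enqueue kernel launches on `stream` that use `buffer` ...
                let _ = (&stream, &buffer);
            });
        }
    }

The payoff is visible in service/src/nvidia.rs above: the service no longer creates or owns streams, SessionContext loses its 'a lifetime, and the CPU and Nvidia transformers now expose the same stream-free decode interface.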