diff --git a/Cargo.lock b/Cargo.lock
index ec6f9ac5..dd58a13b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -916,9 +916,9 @@ dependencies = [
 [[package]]
 name = "flate2"
-version = "1.0.29"
+version = "1.0.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4556222738635b7a3417ae6130d8f52201e45a0c4d1a907f0826383adb5f85e7"
+checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
 dependencies = [
  "crc32fast",
  "miniz_oxide",
@@ -1450,7 +1450,9 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
 name = "llama"
 version = "0.0.0"
 dependencies = [
+ "causal-lm",
  "common",
+ "itertools",
  "rayon",
  "serde",
  "serde_json",
@@ -2273,9 +2275,9 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
 [[package]]
 name = "socket2"
-version = "0.5.6"
+version = "0.5.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871"
+checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c"
 dependencies = [
  "libc",
  "windows-sys 0.52.0",
@@ -2552,7 +2554,6 @@ dependencies = [
  "gemm",
  "intel-mkl-src",
  "intel-mkl-tool",
- "itertools",
  "kernel-lib",
  "llama",
  "tensor",
@@ -2565,6 +2566,7 @@ dependencies = [
  "causal-lm",
  "common-nv",
  "itertools",
+ "llama",
  "log",
  "search-cuda-tools",
  "transformer",
diff --git a/models/llama/Cargo.toml b/models/llama/Cargo.toml
index 99add97d..7c90810f 100644
--- a/models/llama/Cargo.toml
+++ b/models/llama/Cargo.toml
@@ -9,6 +9,8 @@ authors = ["YdrMaster "]
 [dependencies]
 common = { path = "../../common" }
 tensor = { path = "../../tensor" }
+causal-lm = { path = "../../causal-lm" }
+itertools.workspace = true
 serde = { workspace = true, features = ["derive"] }
 serde_json.workspace = true
 rayon.workspace = true
diff --git a/models/llama/src/compute.rs b/models/llama/src/compute.rs
new file mode 100644
index 00000000..7e60fadd
--- /dev/null
+++ b/models/llama/src/compute.rs
@@ -0,0 +1,202 @@
+use causal_lm::QueryContext;
+use itertools::izip;
+use std::ops::{Deref, DerefMut};
+use tensor::{slice, split, udim, LocalSplitable, Tensor};
+
+pub trait ComputeStream {
+    type Byte;
+    type Storage;
+    type Buf<'m>: DerefMut<Target = [Self::Byte]>;
+    type Pos<'m>: Deref<Target = [Self::Byte]>;
+
+    fn malloc(&self, len: usize) -> Self::Buf<'_>;
+    fn free(&self, _mem: Self::Buf<'_>) {}
+    fn map_pos<'p>(&self, pos: &'p [u32]) -> Self::Pos<'p>
+    where
+        Self: 'p;
+    fn free_pos(&self, _mem: Self::Pos<'_>) {}
+    fn map_storage(&self, storage: &mut Self::Storage) -> impl DerefMut<Target = [Self::Byte]>;
+
+    fn rms_norm<O, X, W>(&self, o: &mut Tensor<O>, x: &Tensor<X>, w: &Tensor<W>)
+    where
+        O: DerefMut<Target = [Self::Byte]>,
+        X: Deref<Target = [Self::Byte]>,
+        W: Deref<Target = [Self::Byte]>;
+
+    fn mat_mul<O, A, B>(
+        &self,
+        o: &mut Tensor<O>,
+        beta: f32,
+        a: &Tensor<A>,
+        b: &Tensor<B>,
+        alpha: f32,
+    ) where
+        O: DerefMut<Target = [Self::Byte]>,
+        A: Deref<Target = [Self::Byte]>,
+        B: Deref<Target = [Self::Byte]>;
+
+    fn rotary_embedding<X>(&self, x: &mut Tensor<X>, pos: &Tensor<Self::Pos<'_>>)
+    where
+        X: DerefMut<Target = [Self::Byte]>;
+
+    fn reform<Y, X>(&self, y: &mut Tensor<Y>, x: &Tensor<X>)
+    where
+        Y: DerefMut<Target = [Self::Byte]>,
+        X: Deref<Target = [Self::Byte]>;
+
+    fn softmax<X>(&self, x: &mut Tensor<X>)
+    where
+        X: DerefMut<Target = [Self::Byte]>;
+
+    fn swiglu<A, B>(&self, a: &mut Tensor<A>, b: &Tensor<B>)
+    where
+        A: DerefMut<Target = [Self::Byte]>,
+        B: Deref<Target = [Self::Byte]>;
+
+    fn nh(&self) -> udim;
+    fn nkvh(&self) -> udim;
+    fn di(&self) -> udim;
+    fn layers(&self) -> impl Iterator<Item = impl LLamaLayer<Byte = Self::Byte>>;
+
+    fn forward<'q>(
+        &self,
+        queries: impl IntoIterator<Item = QueryContext<'q, Self::Storage>>,
+        mut token_embedded: Tensor<Self::Storage>,
+    ) -> Tensor<Self::Storage>
+    where
+        Self::Storage: 'q,
+    {
+        let mut queries = queries.into_iter().collect::<Vec<_>>();
+        let mut nt = 0;
+        let mut max_seq_len = 0;
+        let mut max_att_len = 0;
+        let seq_len = queries
+            .iter()
+            .map(|q| {
+                let seq = q.seq_len();
+                let att = q.att_len();
+                nt += seq;
+                max_seq_len = max_seq_len.max(seq);
+                max_att_len = max_att_len.max(att);
+                seq
+            })
+            .collect::<Vec<_>>();
+
+        let dt = token_embedded.data_type();
+        let d = token_embedded.shape()[1];
+        let nh = self.nh();
+        let nkvh = self.nkvh();
+        let dh = d / nh;
+        let dkv = nkvh * dh;
+        let di = self.di();
+        let head_group = nh / nkvh;
+        let head_div = (dh as f32).sqrt().recip();
+
+        let mut x = token_embedded
+            .as_mut()
+            .map_physical(|u| self.map_storage(u));
+        let reusing = (d + dkv + dkv).max(di + di);
+        let mut state_buf = Tensor::alloc(dt, &[nt, d + reusing], |len| self.malloc(len));
+
+        let mut q_buf = self.malloc((nh * max_seq_len * dh) as usize * dt.size());
+        let mut att_buf = self.malloc((nh * max_seq_len * max_att_len) as usize * dt.size());
+        let pos = causal_lm::pos(&queries, nt);
+        let pos = pos.as_ref().map_physical(|u| self.map_pos(u));
+
+        for (layer, params) in self.layers().into_iter().enumerate() {
+            let (mut x1, qkv) = split!(state_buf.as_mut().map_physical(|u| LocalSplitable::from(&mut **u)); [1]: d, reusing);
+            let mut qkv = qkv.slice(&[slice![=>], slice![=> d + dkv + dkv]]);
+
+            self.rms_norm(&mut x1, &x, &params.att_layernorm());
+            self.mat_mul(&mut qkv, 0., &x1, &params.att_qkv(), 1.);
+
+            let (q, k, v) = split!(qkv; [1]: d, dkv, dkv);
+            let mut q = q.reshape(&[nt, nh, dh]);
+            let mut k = k.reshape(&[nt, nkvh, dh]);
+            let v = v.reshape(&[nt, nkvh, dh]);
+            let o = x1.reshape(&[nt, nh, dh]);
+
+            self.rotary_embedding(&mut q, &pos);
+            self.rotary_embedding(&mut k, &pos);
+
+            let q = q.transpose(&[1, 0, 2]).split(1, &seq_len);
+            let k = k.transpose(&[1, 0, 2]).split(1, &seq_len);
+            let v = v.transpose(&[1, 0, 2]).split(1, &seq_len);
+            let o = o.transpose(&[1, 0, 2]).split(1, &seq_len);
+
+            for (query, q, k, v, mut o) in izip!(&mut queries, q, k, v, o) {
+                let pos = query.pos();
+                let seq_len = query.seq_len();
+                let att_len = query.att_len();
+                let mut cache = query
+                    .cache
+                    .as_mut()
+                    .map(|t| t.as_mut().map_physical(|u| self.map_storage(u)));
+                let mut query = QueryContext {
+                    cache: cache.as_mut(),
+                    range: query.range.clone(),
+                };
+                let Some((mut k_cache, mut v_cache)) = query.cache(layer as _) else {
+                    continue;
+                };
+
+                let slice_cat = &[slice![=>], slice![pos =>=> seq_len], slice![=>]];
+                let slice_att = &[slice![=>], slice![ => att_len], slice![=>]];
+                let shape_q0 = &[nkvh * head_group, seq_len, dh];
+                let shape_q1 = &[nkvh, head_group * seq_len, dh];
+                let shape_att0 = &[nkvh, head_group * seq_len, att_len];
+                let shape_att1 = &[nkvh * head_group, seq_len, att_len];
+
+                let mut q_att = Tensor::new(dt, shape_q0, &mut q_buf[..]);
+                let mut k_cat = k_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u);
+                let mut v_cat = v_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u);
+                self.reform(&mut q_att, &q);
+                self.reform(&mut k_cat, &k);
+                self.reform(&mut v_cat, &v);
+
+                let q_att = q_att.reshape(shape_q1);
+                let k_att = k_cache.slice(slice_att).transpose(&[0, 2, 1]);
+                let v_att = v_cache.slice(slice_att);
+
+                let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]);
+                self.mat_mul(&mut att, 0., &q_att, &k_att, head_div);
+                let mut att = att.reshape(shape_att1);
+                self.softmax(&mut att);
+                let mut x2 = q_att;
+                self.mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att, 1.);
+
+                self.reform(&mut o, &x2.reshape(shape_q0));
+            }
+
+            let (mut x1, gate_up) = split!(state_buf.as_mut().map_physical(|u| LocalSplitable::from(&mut **u)); [1]: d, reusing);
+            let mut gate_up = gate_up.slice(&[slice![=>], slice![=> di + di]]);
+
+            self.mat_mul(&mut x, 1., &x1, &params.att_o(), 1.);
+            self.rms_norm(&mut x1, &x, &params.mlp_layernorm());
+            self.mat_mul(&mut gate_up, 0., &x1, &params.mlp_gate_up(), 1.);
+            let (mut gate, up) = split!(gate_up; [1]: di, di);
+            self.swiglu(&mut gate, &up);
+            self.mat_mul(&mut x, 1., &gate, &params.mlp_down(), 1.);
+        }
+        self.free_pos(pos.take_physical());
+        self.free(state_buf.take_physical());
+        self.free(q_buf);
+        self.free(att_buf);
+        drop(x);
+        token_embedded
+    }
+}
+
+pub trait LLamaLayer {
+    type Byte;
+    type Storage<'m>: Deref<Target = [Self::Byte]>
+    where
+        Self: 'm;
+
+    fn att_layernorm(&self) -> Tensor<Self::Storage<'_>>;
+    fn att_qkv(&self) -> Tensor<Self::Storage<'_>>;
+    fn att_o(&self) -> Tensor<Self::Storage<'_>>;
+    fn mlp_layernorm(&self) -> Tensor<Self::Storage<'_>>;
+    fn mlp_gate_up(&self) -> Tensor<Self::Storage<'_>>;
+    fn mlp_down(&self) -> Tensor<Self::Storage<'_>>;
+}
diff --git a/models/llama/src/lib.rs b/models/llama/src/lib.rs
index 3e6a36fb..2c5e3d6c 100644
--- a/models/llama/src/lib.rs
+++ b/models/llama/src/lib.rs
@@ -1,4 +1,5 @@
 mod cast;
+mod compute;
 mod json;
 mod load;
 mod save;
@@ -7,6 +8,8 @@ use common::{safe_tensors::SharedTensor, utok, Blob};
 use std::{ops::Deref, sync::Arc};
 use tensor::{udim, DataType, Tensor};
 
+pub use compute::{ComputeStream, LLamaLayer};
+
 pub struct Storage {
     pub config: InferenceConfig,
diff --git a/nvidia/transformer/Cargo.toml b/nvidia/transformer/Cargo.toml
index d868363e..b903b439 100644
--- a/nvidia/transformer/Cargo.toml
+++ b/nvidia/transformer/Cargo.toml
@@ -8,6 +8,7 @@ authors = ["YdrMaster "]
 [dependencies]
 causal-lm = { path = "../../causal-lm" }
+llama = { path = "../../models/llama" }
 transformer = { path = "../../models/llama-legacy" }
 common-nv = { path = "../common" }
 log.workspace = true
diff --git a/nvidia/transformer/src/lib.rs b/nvidia/transformer/src/lib.rs
index 94c3bf33..7a7134a7 100644
--- a/nvidia/transformer/src/lib.rs
+++ b/nvidia/transformer/src/lib.rs
@@ -7,32 +7,36 @@ extern crate log;
 
 use causal_lm::{CausalLM, DecodingMeta, Model, QueryContext, SampleMeta};
 use common_nv::{
-    cuda::{memcpy_d2h, DevMemSpore},
-    f16, slice, split, udim, upos, utok, DataType, FileLoadError, Kernels, LocalSplitable,
-    NvidiaKernels, NvidiaKernelsPtx, Tensor,
+    cuda::{memcpy_d2h, DevByte, DevMem, DevMemSpore, EventSpore, HostMemSpore, Stream},
+    f16, slice, udim, upos, utok, DataType, FileLoadError, KernelRuntime, Kernels, NvidiaKernels,
+    NvidiaKernelsPtx, Tensor,
 };
 use cuda::{Context, ContextResource, ContextSpore, Device, StreamSpore};
-use itertools::izip;
-use parameters::{LayersParameters, ModelParameters};
+use parameters::ModelParameters;
 use std::{
+    cell::RefCell,
+    collections::VecDeque,
     iter::repeat,
+    ops::{Deref, DerefMut},
     path::Path,
+    rc::Rc,
     slice::from_raw_parts,
-    sync::{Arc, Mutex},
+    sync::{Arc, Mutex, MutexGuard},
     time::Instant,
 };
-use transformer::{Llama2, Memory};
+use transformer::{Llama2, Memory, Storage};
 
 pub use common_nv::{cuda, synchronize};
 
 pub struct Transformer {
     host: Memory,
     model: ModelParameters,
-    layers: Mutex<LayersParameters>,
     context: Arc<Context>,
     transfer: StreamSpore,
     compute: StreamSpore,
     kernels: NvidiaKernels,
+    host_: Vec<LayerStorage<HostMemSpore>>,
+    layers: Mutex<VecDeque<(LayerStorage<DevMemSpore>, EventSpore)>>,
 }
 
 impl Model for Transformer {
@@ -50,27 +54,79 @@ info!("load host: {:?}", time.elapsed());
         let load_layers = host.num_hidden_layers();
 
-        let (model, layers, kernels, transfer, compute) = context.apply(|ctx| {
+        context.apply(|ctx| {
             let transfer = ctx.stream();
             let compute = ctx.stream();
             let block_size = ctx.dev().max_block_dims().0;
-            (
-                ModelParameters::new(&host, &compute),
-                Mutex::new(LayersParameters::new(load_layers, &host, &transfer)),
-                NvidiaKernelsPtx::new(&host, block_size).load(ctx),
-                transfer.sporulate(),
-                compute.sporulate(),
-            )
-        });
-
-        Ok(Self {
-            host,
-            model,
-            layers,
-            context,
-            transfer,
-            compute,
-            kernels,
+            let host_ = (0..host.num_hidden_layers())
+                .map(|l| {
+                    let memcpy = |u: Storage| {
+                        let mut host = ctx.malloc_host::<u8>(u.len());
+                        host.copy_from_slice(&u);
+                        host.sporulate()
+                    };
+                    macro_rules! memcpy {
+                        ($($f:ident => $ident:ident)+) => {
+                            LayerStorage{
+                                $(
+                                    $ident: host.$f(l).map_physical(memcpy),
+                                )+
+                            }
+                        };
+                    }
+                    memcpy! {
+                        input_layernorm          => att_layernorm
+                        w_qkv                    => att_qkv
+                        self_attn_o_proj         => att_o
+                        post_attention_layernorm => mlp_layernorm
+                        mlp_gate_up              => mlp_gate_up
+                        mlp_down                 => mlp_down
+                    }
+                })
+                .collect::<Vec<_>>();
+            let pool = Mutex::new(
+                (0..load_layers)
+                    .map(|l| {
+                        macro_rules! memcpy {
+                            ($($ident:ident)+) => {
+                                LayerStorage {
+                                    $(
+                                        $ident: {
+                                            host_[l]
+                                                .$ident
+                                                .as_ref()
+                                                .map_physical(|u| transfer.from_host(u).sporulate())
+                                        },
+                                    )+
+                                }
+                            };
+                        }
+                        (
+                            memcpy! {
+                                att_layernorm
+                                att_qkv
+                                att_o
+                                mlp_layernorm
+                                mlp_gate_up
+                                mlp_down
+                            },
+                            transfer.record().sporulate(),
+                        )
+                    })
+                    .collect(),
+            );
+            let model = ModelParameters::new(&host, &compute);
+            let kernels = NvidiaKernelsPtx::new(&host, block_size).load(ctx);
+            Ok(Self {
+                host,
+                model,
+                context: context.clone(),
+                transfer: transfer.sporulate(),
+                compute: compute.sporulate(),
+                kernels,
+                host_,
+                layers: pool,
+            })
+        })
     }
 }
@@ -157,134 +213,21 @@
     where
         Self: 'a,
     {
-        let mut queries = queries.into_iter().collect::<Vec<_>>();
-        let mut nt = 0;
-        let mut max_seq_len = 0;
-        let mut max_att_len = 0;
-        let seq_len = queries
-            .iter()
-            .map(|q| {
-                let seq = q.seq_len();
-                let att = q.att_len();
-                nt += seq;
-                max_seq_len = max_seq_len.max(seq);
-                max_att_len = max_att_len.max(att);
-                seq
-            })
-            .collect::<Vec<_>>();
-
-        let dt = self.host.data_type();
-        let d = self.host.hidden_size() as udim;
-        let nh = self.host.num_attention_heads() as udim;
-        let nkvh = self.host.num_key_value_heads() as udim;
-        let dh = d / nh;
-        let dkv = nkvh * dh;
-        let di = self.host.intermediate_size() as udim;
-        let head_group = nh / nkvh;
-        let head_div = (dh as f32).sqrt().recip();
-
-        let mut x_ = token_embedded;
         self.context.apply(|ctx| {
             let compute = unsafe { self.compute.sprout(ctx) };
-            let kernels = self.kernels.on(&compute);
-
-            let reusing = (d + dkv + dkv).max(di + di);
-            let mut state_buf = Tensor::alloc(dt, &[nt, d + reusing],|len| compute.malloc::<u8>(len));
-            macro_rules! state {
-                () => {
-                    split!(state_buf.as_mut().map_physical(|u| LocalSplitable::from(&mut **u)); [1]: d, reusing)
-                };
-            }
-
-            let mut q_buf = compute.malloc::<u8>((nh * max_seq_len * dh) as usize * dt.size());
-            let mut att_buf = compute.malloc::<u8>((nh * max_seq_len * max_att_len) as usize * dt.size());
-            let pos = causal_lm::pos(&queries, nt);
-            let pos = pos.as_ref().map_physical(|u| compute.from_host(u));
-
-            let mut x = x_.as_mut().map_physical(|u| unsafe { u.mem.sprout(ctx) });
             let transfer = unsafe { self.transfer.sprout(ctx) };
-            // Rolling layer-parameter loading is stateful and must be owned by a single control flow; the rest is stateless and can run on multiple streams concurrently.
-            let mut layers = self.layers.lock().unwrap();
-            for layer in 0..self.host.num_hidden_layers() {
-                let params = {
-                    layers.load(layer, &self.host, &transfer);
-                    layers.sync(layer, &compute)
-                };
-
-                let (mut x1, qkv) = state!();
-                let mut qkv = qkv.slice(&[slice![=>], slice![=> d + dkv + dkv]]);
-
-                kernels.rms_norm(&mut x1, &x, &params.input_layernorm(ctx));
-                kernels.mat_mul(&mut qkv, 0., &x1, &params.w_qkv(ctx), 1.);
-
-                let (q, k, v) = split!(qkv; [1]: d, dkv, dkv);
-                let mut q = q.reshape(&[nt, nh, dh]);
-                let mut k = k.reshape(&[nt, nkvh, dh]);
-                let v = v.reshape(&[nt, nkvh, dh]);
-                let o = x1.reshape(&[nt, nh, dh]);
-
-                kernels.rotary_embedding(&mut q, &pos);
-                kernels.rotary_embedding(&mut k, &pos);
-
-                let q = q.transpose(&[1, 0, 2]).split(1, &seq_len);
-                let k = k.transpose(&[1, 0, 2]).split(1, &seq_len);
-                let v = v.transpose(&[1, 0, 2]).split(1, &seq_len);
-                let o = o.transpose(&[1, 0, 2]).split(1, &seq_len);
-
-                for (query, q, k, v, mut o) in izip!(&mut queries, q, k, v, o) {
-                    let pos = query.pos();
-                    let seq_len = query.seq_len();
-                    let att_len = query.att_len();
-                    let mut cache = query.cache.as_mut().map(|t|t.as_mut().map_physical(|u| unsafe { u.mem.sprout(ctx) }));
-                    let mut query = QueryContext{ cache:cache.as_mut(), range: query.range.clone() };
-                    let Some((mut k_cache, mut v_cache)) = query.cache(layer) else {
-                        continue;
-                    };
-
-                    let slice_cat = &[slice![=>], slice![pos =>=> seq_len], slice![=>]];
-                    let slice_att = &[slice![=>], slice![ => att_len], slice![=>]];
-                    let shape_q0 = &[nkvh * head_group, seq_len, dh];
-                    let shape_q1 = &[nkvh, head_group * seq_len, dh];
-                    let shape_att0 = &[nkvh, head_group * seq_len, att_len];
-                    let shape_att1 = &[nkvh * head_group, seq_len, att_len];
-
-                    let mut q_att = Tensor::new(dt, shape_q0, &mut q_buf[..]);
-                    let mut k_cat = k_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u);
-                    let mut v_cat = v_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u);
-                    kernels.reform(&mut q_att, &q);
-                    kernels.reform(&mut k_cat, &k);
-                    kernels.reform(&mut v_cat, &v);
-
-                    let q_att = q_att.reshape(shape_q1);
-                    let k_att = k_cache.slice(slice_att).transpose(&[0, 2, 1]);
-                    let v_att = v_cache.slice(slice_att);
-
-                    let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]);
-                    kernels.mat_mul(&mut att, 0., &q_att, &k_att, head_div);
-                    let mut att = att.reshape(shape_att1);
-                    kernels.softmax(&mut att);
-                    let mut x2 = q_att;
-                    kernels.mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att, 1.);
-
-                    kernels.reform(&mut o, &x2.reshape(shape_q0));
-                }
-
-                let (mut x1, gate_up) = state!();
-                let mut gate_up = gate_up.slice(&[slice![=>], slice![=> di + di]]);
-
-                kernels.mat_mul(&mut x, 1., &x1, &params.w_o(ctx), 1.);
-                kernels.rms_norm(&mut x1, &x, &params.post_attention_layernorm(ctx));
-                kernels.mat_mul(&mut gate_up, 0., &x1, &params.mlp_gate_up(ctx), 1.);
-                let (mut gate, up) = split!(gate_up; [1]: di, di);
-                kernels.swiglu(&mut gate, &up);
-                kernels.mat_mul(&mut x, 1., &gate, &params.mlp_down(ctx), 1.);
-            }
-            pos.take_physical().drop_on(&compute);
-            state_buf.take_physical().drop_on(&compute);
-            q_buf.drop_on(&compute);
-            att_buf.drop_on(&compute);
-        });
-        x_
+            let stream = ComputeStream {
+                nh: self.host.num_attention_heads() as _,
+                nkvh: self.host.num_key_value_heads() as _,
+                di: self.host.intermediate_size() as _,
+                kernels: self.kernels.on(&compute),
+                compute: &compute,
+                transfer: &transfer,
+                host: Rc::new(&*self.host_),
+                dev: Rc::new(RefCell::new(self.layers.lock().unwrap())),
+            };
+            <ComputeStream as llama::ComputeStream>::forward(&stream, queries, token_embedded)
+        })
     }
 
     fn decode(
@@ -394,11 +337,10 @@ impl Drop for Transformer {
     #[inline]
     fn drop(&mut self) {
        self.context.apply(|ctx| unsafe {
-            self.model.kill(ctx);
-            self.layers.lock().unwrap().kill(ctx);
             self.transfer.kill(ctx);
             self.compute.kill(ctx);
             self.kernels.kill(ctx);
+            self.model.kill(ctx);
        });
     }
 }
@@ -415,6 +357,247 @@ impl Drop for Cache {
     }
 }
 
+struct ComputeStream<'a> {
+    nh: udim,
+    nkvh: udim,
+    di: udim,
+    kernels: KernelRuntime<'a>,
+    compute: &'a Stream<'a>,
+    transfer: &'a Stream<'a>,
+    host: Rc<&'a [LayerStorage<HostMemSpore>]>,
+    dev: DevMemPool<'a>,
+}
+
+type DevMemPool<'a> =
+    Rc<RefCell<MutexGuard<'a, VecDeque<(LayerStorage<DevMemSpore>, EventSpore)>>>>;
+
+struct LayerStorage<T> {
+    att_layernorm: Tensor<T>,
+    att_qkv: Tensor<T>,
+    att_o: Tensor<T>,
+    mlp_layernorm: Tensor<T>,
+    mlp_gate_up: Tensor<T>,
+    mlp_down: Tensor<T>,
+}
+
+impl<'a> llama::ComputeStream for ComputeStream<'a> {
+    type Byte = DevByte;
+    type Storage = Cache;
+    type Buf<'m> = DevMem<'m>;
+    type Pos<'m> = DevMem<'m>;
+
+    fn malloc(&self, len: usize) -> Self::Buf<'_> {
+        self.compute.malloc::<u8>(len)
+    }
+    fn free(&self, mem: Self::Buf<'_>) {
+        mem.drop_on(self.compute);
+    }
+    fn map_pos<'b>(&self, pos: &'b [u32]) -> Self::Pos<'b>
+    where
+        Self: 'b,
+    {
+        self.compute.from_host(pos)
+    }
+    fn free_pos(&self, mem: Self::Pos<'_>) {
+        mem.drop_on(self.compute);
+    }
+    fn map_storage(&self, storage: &mut Self::Storage) -> impl DerefMut<Target = [Self::Byte]> {
+        unsafe { storage.mem.sprout(self.compute.ctx()) }
+    }
+    fn rms_norm<O, X, W>(&self, o: &mut Tensor<O>, x: &Tensor<X>, w: &Tensor<W>)
+    where
+        O: DerefMut<Target = [Self::Byte]>,
+        X: Deref<Target = [Self::Byte]>,
+        W: Deref<Target = [Self::Byte]>,
+    {
+        self.kernels.rms_norm(o, x, w);
+    }
+    fn mat_mul<O, A, B>(
+        &self,
+        o: &mut Tensor<O>,
+        beta: f32,
+        a: &Tensor<A>,
+        b: &Tensor<B>,
+        alpha: f32,
+    ) where
+        O: DerefMut<Target = [Self::Byte]>,
+        A: Deref<Target = [Self::Byte]>,
+        B: Deref<Target = [Self::Byte]>,
+    {
+        self.kernels.mat_mul(o, beta, a, b, alpha);
+    }
+    fn rotary_embedding<X>(&self, x: &mut Tensor<X>, pos: &Tensor<Self::Pos<'_>>)
+    where
+        X: DerefMut<Target = [Self::Byte]>,
+    {
+        self.kernels.rotary_embedding(x, pos);
+    }
+    fn reform<Y, X>(&self, y: &mut Tensor<Y>, x: &Tensor<X>)
+    where
+        Y: DerefMut<Target = [Self::Byte]>,
+        X: Deref<Target = [Self::Byte]>,
+    {
+        self.kernels.reform(y, x);
+    }
+    fn softmax<X>(&self, x: &mut Tensor<X>)
+    where
+        X: DerefMut<Target = [Self::Byte]>,
+    {
+        self.kernels.softmax(x);
+    }
+    fn swiglu<A, B>(&self, a: &mut Tensor<A>, b: &Tensor<B>)
+    where
+        A: DerefMut<Target = [Self::Byte]>,
+        B: Deref<Target = [Self::Byte]>,
+    {
+        self.kernels.swiglu(a, b);
+    }
+    fn nh(&self) -> udim {
+        self.nh
+    }
+    fn nkvh(&self) -> udim {
+        self.nkvh
+    }
+    fn di(&self) -> udim {
+        self.di
+    }
+    fn layers(&self) -> impl Iterator<Item = impl llama::LLamaLayer<Byte = Self::Byte>> {
+        Iter::new(
+            self.host.clone(),
+            self.dev.clone(),
+            self.compute,
+            self.transfer,
+        )
+    }
+}
+
+struct Iter<'a> {
+    host: Rc<&'a [LayerStorage<HostMemSpore>]>,
+    pool: DevMemPool<'a>,
+    compute: &'a Stream<'a>,
+    transfer: &'a Stream<'a>,
+    layer: usize,
+}
+
+impl<'a> Iter<'a> {
+    pub fn new(
+        host: Rc<&'a [LayerStorage<HostMemSpore>]>,
+        pool: DevMemPool<'a>,
+        compute: &'a Stream,
+        transfer: &'a Stream,
+    ) -> Self {
+        Self {
+            host,
+            pool,
+            compute,
+            transfer,
+            layer: 0,
+        }
+    }
+}
+
+impl<'a> Iterator for Iter<'a> {
+    type Item = LayerLoader<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.layer >= self.host.len() {
+            return None;
+        }
+
+        let mut pool = self.pool.borrow_mut();
+        let load = if pool.len() < self.host.len() {
+            Some((self.layer + pool.len()) % self.host.len())
+        } else {
+            None
+        };
+        self.layer += 1;
+
+        let (lll, mut event) = pool.pop_front().unwrap();
+        let ctx = self.compute.ctx();
+        self.compute.wait_for(&unsafe { event.sprout(ctx) });
+        unsafe { ctx.kill(&mut event) };
+
+        Some(Self::Item {
+            host: self.host.clone(),
+            pool: self.pool.clone(),
+            load,
+            transfer: self.transfer,
+            storage: Some(lll),
+        })
+    }
+}
+
+struct LayerLoader<'a> {
+    host: Rc<&'a [LayerStorage<HostMemSpore>]>,
+    pool: DevMemPool<'a>,
+    load: Option<usize>,
+    transfer: &'a Stream<'a>,
+    storage: Option<LayerStorage<DevMemSpore>>,
+}
+
+macro_rules! access {
+    ($self:expr, $name:ident) => {
+        $self
+            .storage
+            .as_ref()
+            .unwrap()
+            .$name
+            .as_ref()
+            .map_physical(|u| unsafe { u.sprout($self.transfer.ctx()) })
+    };
+}
+
+impl<'a> llama::LLamaLayer for LayerLoader<'a> {
+    type Byte = DevByte;
+    type Storage<'m> = DevMem<'m> where Self: 'm;
+
+    fn att_layernorm(&self) -> Tensor<Self::Storage<'_>> {
+        access!(self, att_layernorm)
+    }
+    fn att_qkv(&self) -> Tensor<Self::Storage<'_>> {
+        access!(self, att_qkv).transpose(&[1, 0])
+    }
+    fn att_o(&self) -> Tensor<Self::Storage<'_>> {
+        access!(self, att_o).transpose(&[1, 0])
+    }
+    fn mlp_layernorm(&self) -> Tensor<Self::Storage<'_>> {
+        access!(self, mlp_layernorm)
+    }
+    fn mlp_gate_up(&self) -> Tensor<Self::Storage<'_>> {
+        access!(self, mlp_gate_up).transpose(&[1, 0])
+    }
+    fn mlp_down(&self) -> Tensor<Self::Storage<'_>> {
+        access!(self, mlp_down).transpose(&[1, 0])
+    }
+}
+
+impl Drop for LayerLoader<'_> {
+    fn drop(&mut self) {
+        let lll = self.storage.take().unwrap();
+        if let Some(load) = self.load {
+            macro_rules! exchange {
+                ($($name:ident)+) => {
+                    $(
+                        let host = self.host[load].$name.physical();
+                        let mut dev = unsafe { lll.$name.physical().sprout(self.transfer.ctx()) };
+                        self.transfer.memcpy_h2d(&mut dev, host);
+                    )+
+                };
+            }
+            exchange! {
+                att_layernorm
+                att_qkv
+                att_o
+                mlp_layernorm
+                mlp_gate_up
+                mlp_down
+            }
+        }
+        self.pool
+            .borrow_mut()
+            .push_back((lll, self.transfer.record().sporulate()));
+    }
+}
+
 #[test]
 fn test_infer() {
     cuda::init();
diff --git a/nvidia/transformer/src/parameters.rs b/nvidia/transformer/src/parameters.rs
index 3a095382..4f9da4d5 100644
--- a/nvidia/transformer/src/parameters.rs
+++ b/nvidia/transformer/src/parameters.rs
@@ -43,153 +43,3 @@ impl ModelParameters {
         self.sync_event.kill(ctx);
     }
 }
-
-pub(crate) struct LayersParameters {
-    layers: Vec<LayerParameter>,
-    current: usize,
-}
-
-impl LayersParameters {
-    pub fn new(load_layers: usize, host: &dyn Llama2, stream: &Stream) -> Self {
-        Self {
-            layers: (0..host.num_hidden_layers().min(load_layers))
-                .map(|layer| LayerParameter::new(host, layer, stream))
-                .collect(),
-            current: 0,
-        }
-    }
-
-    #[inline]
-    pub fn load(&mut self, layer: usize, host: &dyn Llama2, stream: &Stream) {
-        let step = self.layers.len() - 1;
-        let i = (self.current + step) % self.layers.len();
-        let layer = (layer + step) % host.num_hidden_layers();
-        self.layers[i].load(host, layer, stream);
-    }
-
-    #[inline]
-    pub fn sync(&mut self, layer: usize, stream: &Stream) -> &LayerParameter {
-        let i = self.current;
-        self.current = (i + 1) % self.layers.len();
-
-        let params = &self.layers[i];
-        assert_eq!(params.layer, layer);
-        stream.wait_for(unsafe { &params.sync_event.sprout(stream.ctx()) });
-
-        params
-    }
-
-    pub unsafe fn kill(&mut self, ctx: &ContextGuard) {
-        for layer in &mut self.layers {
-            layer.input_layernorm.physical_mut().kill(ctx);
-            layer.w_qkv.physical_mut().kill(ctx);
-            layer.self_attn_o_proj.physical_mut().kill(ctx);
-            layer.post_attention_layernorm.physical_mut().kill(ctx);
-            layer.mlp_gate_up.physical_mut().kill(ctx);
-            layer.mlp_down.physical_mut().kill(ctx);
-            layer.sync_event.kill(ctx);
-        }
-    }
-}
-
-pub(crate) struct LayerParameter {
-    pub input_layernorm: Tensor<DevMemSpore>,
-    pub w_qkv: Tensor<DevMemSpore>,
-    pub self_attn_o_proj: Tensor<DevMemSpore>,
-    pub post_attention_layernorm: Tensor<DevMemSpore>,
-    pub mlp_gate_up: Tensor<DevMemSpore>,
-    pub mlp_down: Tensor<DevMemSpore>,
-
-    layer: usize,
-    sync_event: EventSpore,
-}
-
-impl LayerParameter {
-    #[inline]
-    pub fn input_layernorm<'ctx>(&self, ctx: &'ctx ContextGuard) -> Tensor<DevMem<'ctx>> {
-        unsafe {
-            self.input_layernorm
-                .as_ref()
-                .map_physical(|s| s.sprout(ctx))
-        }
-    }
-
-    #[inline]
-    pub fn w_qkv<'ctx>(&self, ctx: &'ctx ContextGuard) -> Tensor<DevMem<'ctx>> {
-        unsafe { self.w_qkv.as_ref().map_physical(|s| s.sprout(ctx)) }
-    }
-
-    #[inline]
-    pub fn w_o<'ctx>(&self, ctx: &'ctx ContextGuard) -> Tensor<DevMem<'ctx>> {
-        unsafe {
-            self.self_attn_o_proj
-                .as_ref()
-                .map_physical(|s| s.sprout(ctx))
-        }
-    }
-
-    #[inline]
-    pub fn post_attention_layernorm<'ctx>(&self, ctx: &'ctx ContextGuard) -> Tensor<DevMem<'ctx>> {
-        unsafe {
-            self.post_attention_layernorm
-                .as_ref()
-                .map_physical(|s| s.sprout(ctx))
-        }
-    }
-
-    #[inline]
-    pub fn mlp_gate_up<'ctx>(&self, ctx: &'ctx ContextGuard) -> Tensor<DevMem<'ctx>> {
-        unsafe { self.mlp_gate_up.as_ref().map_physical(|s| s.sprout(ctx)) }
-    }
-
-    #[inline]
-    pub fn mlp_down<'ctx>(&self, ctx: &'ctx ContextGuard) -> Tensor<DevMem<'ctx>> {
-        unsafe { self.mlp_down.as_ref().map_physical(|s| s.sprout(ctx)) }
-    }
-
-    fn new(host: &dyn Llama2, layer: usize, stream: &Stream) -> Self {
-        macro_rules! map {
-            ($param:ident) => {
-                host.$param(layer)
-                    .as_ref()
-                    .map_physical(|slice| stream.from_host(slice).sporulate())
-            };
-        }
-        Self {
-            input_layernorm: map!(input_layernorm),
-            w_qkv: map!(w_qkv).transpose(&[1, 0]),
-            self_attn_o_proj: map!(self_attn_o_proj).transpose(&[1, 0]),
-            post_attention_layernorm: map!(post_attention_layernorm),
-            mlp_gate_up: map!(mlp_gate_up).transpose(&[1, 0]),
-            mlp_down: map!(mlp_down).transpose(&[1, 0]),
-            layer,
-            sync_event: stream.record().sporulate(),
-        }
-    }
-
-    fn load(&mut self, host: &dyn Llama2, layer: usize, stream: &Stream) {
-        if self.layer == layer {
-            return;
-        }
-
-        let ctx = stream.ctx();
-        macro_rules! update {
-            ($param:ident) => {
-                stream.memcpy_h2d(
-                    unsafe { &mut self.$param.physical_mut().sprout(ctx) },
-                    host.$param(layer).as_slice(),
-                )
-            };
-        }
-        update!(input_layernorm);
-        update!(w_qkv);
-        update!(self_attn_o_proj);
-        update!(post_attention_layernorm);
-        update!(mlp_gate_up);
-        update!(mlp_down);
-
-        unsafe { self.sync_event.kill(stream.ctx()) };
-        self.sync_event = stream.record().sporulate();
-        self.layer = layer;
-    }
-}
diff --git a/transformer-cpu/Cargo.toml b/transformer-cpu/Cargo.toml
index 9cdf2bf6..48ba800c 100644
--- a/transformer-cpu/Cargo.toml
+++ b/transformer-cpu/Cargo.toml
@@ -12,7 +12,6 @@ tensor = { path = "../tensor" }
 causal-lm = { path = "../causal-lm" }
 llama = { path = "../models/llama" }
 kernel-lib = { path = "../kernel-lib" }
-itertools.workspace = true
 gemm = "0.17"
 intel-mkl-src = { version = "0.8", features = ["mkl-dynamic-lp64-iomp"] }
diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs
index e85c70fb..16b01ad0 100644
--- a/transformer-cpu/src/lib.rs
+++ b/transformer-cpu/src/lib.rs
@@ -3,14 +3,18 @@ mod kernel;
 use causal_lm::{CausalLM, DecodingMeta, Model, QueryContext, SampleMeta};
 use common::{upos, utok, Blob, FileLoadError};
 use gemm::f16;
-use itertools::izip;
 use kernel::{
     fused_softmax::softmax, gather::gather, mat_mul::mat_mul, rms_norm::rms_norm,
     rotary_embedding::rotary_embedding, swiglu::swiglu,
 };
-use llama::Storage;
-use std::{iter::repeat, path::Path, slice::from_raw_parts};
-use tensor::{reslice, slice, split, udim, LocalSplitable, Tensor};
+use llama::{ComputeStream, LayerStorage, Storage, Weight};
+use std::{
+    iter::repeat,
+    ops::{Deref, DerefMut},
+    path::Path,
+    slice::from_raw_parts,
+};
+use tensor::{reslice, slice, udim, Tensor};
 
 pub struct Transformer(Storage);
 
@@ -24,6 +28,130 @@ impl Model for Transformer {
     }
 }
 
+impl ComputeStream for Transformer {
+    type Byte = u8;
+    type Storage = Blob;
+    type Buf<'m> = Blob;
+    type Pos<'m> = &'m [u8];
+
+    #[inline]
+    fn malloc(&self, len: usize) -> Self::Buf<'_> {
+        Blob::new(len)
+    }
+    #[inline]
+    fn map_pos<'p>(&self, pos: &'p [u32]) -> Self::Pos<'p>
+    where
+        Self: 'p,
+    {
+        reslice(pos)
+    }
+    fn map_storage(&self, storage: &mut Self::Storage) -> impl DerefMut<Target = [Self::Byte]> {
+        &mut **storage
+    }
+    #[inline]
+    fn rms_norm<O, X, W>(&self, o: &mut Tensor<O>, x: &Tensor<X>, w: &Tensor<W>)
+    where
+        O: DerefMut<Target = [Self::Byte]>,
+        X: Deref<Target = [Self::Byte]>,
+        W: Deref<Target = [Self::Byte]>,
+    {
+        rms_norm(o, x, w, self.0.config.epsilon);
+    }
+    #[inline]
+    fn mat_mul<O, A, B>(
+        &self,
+        o: &mut Tensor<O>,
+        beta: f32,
+        a: &Tensor<A>,
+        b: &Tensor<B>,
+        alpha: f32,
+    ) where
+        O: DerefMut<Target = [Self::Byte]>,
+        A: Deref<Target = [Self::Byte]>,
+        B: Deref<Target = [Self::Byte]>,
+    {
+        mat_mul(o, beta, a, b, alpha);
+    }
+    #[inline]
+    fn rotary_embedding<X>(&self, x: &mut Tensor<X>, pos: &Tensor<Self::Pos<'_>>)
+    where
+        X: DerefMut<Target = [Self::Byte]>,
+    {
+        rotary_embedding(x, pos, self.0.config.theta);
+    }
+    #[inline]
+    fn reform<Y, X>(&self, y: &mut Tensor<Y>, x: &Tensor<X>)
+    where
+        Y: DerefMut<Target = [Self::Byte]>,
+        X: Deref<Target = [Self::Byte]>,
+    {
+        x.reform_to(y);
+    }
+    #[inline]
+    fn softmax<X>(&self, x: &mut Tensor<X>)
+    where
+        X: DerefMut<Target = [Self::Byte]>,
+    {
+        softmax(x);
+    }
+    #[inline]
+    fn swiglu<A, B>(&self, a: &mut Tensor<A>, b: &Tensor<B>)
+    where
+        A: DerefMut<Target = [Self::Byte]>,
+        B: Deref<Target = [Self::Byte]>,
+    {
+        swiglu(a, b);
+    }
+    #[inline]
+    fn nh(&self) -> udim {
+        self.0.config.nh
+    }
+    #[inline]
+    fn nkvh(&self) -> udim {
+        self.0.config.nkvh
+    }
+    #[inline]
+    fn di(&self) -> udim {
+        self.0.config.di
+    }
+    #[inline]
+    fn layers(&self) -> impl Iterator<Item = impl llama::LLamaLayer<Byte = Self::Byte>> {
+        self.0.layers.iter().map(|layer| LlamaLayer(&layer))
+    }
+}
+
+struct LlamaLayer<'a>(&'a LayerStorage<Weight>);
+
+impl<'a> llama::LLamaLayer for LlamaLayer<'a> {
+    type Byte = u8;
+    type Storage<'m> = Weight where Self: 'm;
+
+    #[inline]
+    fn att_layernorm(&self) -> Tensor<Self::Storage<'_>> {
+        self.0.att_layernorm.clone()
+    }
+    #[inline]
+    fn att_qkv(&self) -> Tensor<Self::Storage<'_>> {
+        self.0.att_qkv.clone()
+    }
+    #[inline]
+    fn att_o(&self) -> Tensor<Self::Storage<'_>> {
+        self.0.att_o.clone()
+    }
+    #[inline]
+    fn mlp_layernorm(&self) -> Tensor<Self::Storage<'_>> {
+        self.0.mlp_layernorm.clone()
+    }
+    #[inline]
+    fn mlp_gate_up(&self) -> Tensor<Self::Storage<'_>> {
+        self.0.mlp_gate_up.clone()
+    }
+    #[inline]
+    fn mlp_down(&self) -> Tensor<Self::Storage<'_>> {
+        self.0.mlp_down.clone()
+    }
+}
+
 impl CausalLM for Transformer {
     type Storage = Blob;
 
@@ -80,119 +208,8 @@
         &self,
         queries: impl IntoIterator<Item = QueryContext<'a, Self::Storage>>,
         token_embedded: Tensor<Self::Storage>,
-    ) -> Tensor<Self::Storage>
-    where
-        Self: 'a,
-    {
-        let mut queries = queries.into_iter().collect::<Vec<_>>();
-        let mut nt = 0;
-        let mut max_seq_len = 0;
-        let mut max_att_len = 0;
-        let seq_len = queries
-            .iter()
-            .map(|q| {
-                let seq = q.seq_len();
-                let att = q.att_len();
-                nt += seq;
-                max_seq_len = max_seq_len.max(seq);
-                max_att_len = max_att_len.max(att);
-                seq
-            })
-            .collect::<Vec<_>>();
-
-        let dt = self.0.config.dt;
-        let d = self.0.config.d;
-        let nh = self.0.config.nh;
-        let nkvh = self.0.config.nkvh;
-        let dh = d / nh;
-        let dkv = nkvh * dh;
-        let di = self.0.config.di;
-        let head_group = nh / nkvh;
-        let head_div = (dh as f32).sqrt().recip();
-
-        let reusing = (d + dkv + dkv).max(di + di);
-        let mut state_buf = Tensor::alloc(dt, &[nt, d + reusing], Blob::new);
-        macro_rules! state {
-            () => {
-                split!(state_buf.as_mut().map_physical(|u| LocalSplitable::from(&mut **u)); [1]: d, reusing)
-            };
-        }
-
-        let mut q_buf = Blob::new((nh * max_seq_len * dh) as usize * dt.size());
-        let mut att_buf = Blob::new((nh * max_seq_len * max_att_len) as usize * dt.size());
-        let pos = causal_lm::pos(&queries, nt);
-        let pos = pos.as_ref().map_physical(|u| reslice(u));
-
-        let mut x = token_embedded;
-        for (layer, params) in self.0.layers.iter().enumerate() {
-            let (mut x1, qkv) = state!();
-            let mut qkv = qkv.slice(&[slice![=>], slice![=> d + dkv + dkv]]);
-
-            rms_norm(&mut x1, &x, &params.att_layernorm, self.0.config.epsilon);
-            mat_mul(&mut qkv, 0., &x1, &params.att_qkv, 1.);
-
-            let (q, k, v) = split!(qkv; [1]: d, dkv, dkv);
-            let mut q = q.reshape(&[nt, nh, dh]);
-            let mut k = k.reshape(&[nt, nkvh, dh]);
-            let v = v.reshape(&[nt, nkvh, dh]);
-            let o = x1.reshape(&[nt, nh, dh]);
-
-            rotary_embedding(&mut q, &pos, self.0.config.theta);
-            rotary_embedding(&mut k, &pos, self.0.config.theta);
-
-            let q = q.transpose(&[1, 0, 2]).split(1, &seq_len);
-            let k = k.transpose(&[1, 0, 2]).split(1, &seq_len);
-            let v = v.transpose(&[1, 0, 2]).split(1, &seq_len);
-            let o = o.transpose(&[1, 0, 2]).split(1, &seq_len);
-
-            for (query, q, k, v, mut o) in izip!(&mut queries, q, k, v, o) {
-                let pos = query.pos();
-                let seq_len = query.seq_len();
-                let att_len = query.att_len();
-                let Some((mut k_cache, mut v_cache)) = query.cache(layer as _) else {
-                    continue;
-                };
-
-                let slice_cat = &[slice![=>], slice![pos =>=> seq_len], slice![=>]];
-                let slice_att = &[slice![=>], slice![ => att_len], slice![=>]];
-                let shape_q0 = &[nkvh * head_group, seq_len, dh];
-                let shape_q1 = &[nkvh, head_group * seq_len, dh];
-                let shape_att0 = &[nkvh, head_group * seq_len, att_len];
-                let shape_att1 = &[nkvh * head_group, seq_len, att_len];
-
-                let mut q_att = Tensor::new(dt, shape_q0, &mut q_buf[..]);
-                let mut k_cat = k_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u);
-                let mut v_cat = v_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u);
-                q.reform_to(&mut q_att);
-                k.reform_to(&mut k_cat);
-                v.reform_to(&mut v_cat);
-
-                let q_att = q_att.reshape(shape_q1);
-                let k_att = k_cache.slice(slice_att).transpose(&[0, 2, 1]);
-                let v_att = v_cache.slice(slice_att);
-
-                let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]);
-                mat_mul(&mut att, 0., &q_att, &k_att, head_div);
-                let mut att = att.reshape(shape_att1);
-                softmax(&mut att);
-                let mut x2 = q_att;
-                mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att, 1.);
-
-                x2.reshape(shape_q0).reform_to(&mut o);
-            }
-
-            let (mut x1, gate_up) = state!();
-            let mut gate_up = gate_up.slice(&[slice![=>], slice![=> di + di]]);
-
-            mat_mul(&mut x, 1., &x1, &params.att_o, 1.);
-            rms_norm(&mut x1, &x, &params.mlp_layernorm, self.0.config.epsilon);
-            mat_mul(&mut gate_up, 0., &x1, &params.mlp_gate_up, 1.);
-            let (mut gate, up) = split!(gate_up; [1]: di, di);
-            swiglu(&mut gate, &up);
-            mat_mul(&mut x, 1., &gate, &params.mlp_down, 1.);
-        }
-
-        x
+    ) -> Tensor<Self::Storage> {
+        <Self as ComputeStream>::forward(self, queries, token_embedded)
     }
 
     fn decode(