diff --git a/Cargo.lock b/Cargo.lock
index 309fbc98..161b1fb0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -829,8 +829,10 @@ dependencies = [
 name = "distributed"
 version = "0.0.0"
 dependencies = [
+ "causal-lm",
  "common-nv",
  "half",
+ "itertools",
  "log",
  "nccl",
  "search-cuda-tools",
@@ -2546,8 +2548,8 @@ dependencies = [
  "common-nv",
  "half",
  "itertools",
+ "log",
  "search-cuda-tools",
- "tokenizer",
  "transformer",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index f5ea9e02..9360200b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,6 +18,7 @@ resolver = "2"
 [workspace.dependencies]
 half = "2.4"
 rayon = "1.9"
+itertools = "0.12"
 serde = "1.0"
 serde_json = "1.0"
 log = "0.4"
diff --git a/nvidia/distributed/Cargo.toml b/nvidia/distributed/Cargo.toml
index 1bb45720..46b6a1d0 100644
--- a/nvidia/distributed/Cargo.toml
+++ b/nvidia/distributed/Cargo.toml
@@ -7,11 +7,13 @@ authors = ["YdrMaster "]
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+causal-lm = { path = "../../causal-lm" }
 transformer = { path = "../../transformer" }
 common-nv = { path = "../common" }
 nccl.workspace = true
 log.workspace = true
 half.workspace = true
+itertools.workspace = true
 
 [dev-dependencies]
 simple_logger = "4.3"
diff --git a/nvidia/distributed/src/lib.rs b/nvidia/distributed/src/lib.rs
index d295b610..8a275211 100644
--- a/nvidia/distributed/src/lib.rs
+++ b/nvidia/distributed/src/lib.rs
@@ -5,212 +5,318 @@ mod parameters;
 #[macro_use]
 extern crate log;
 
-pub use common_nv::cuda;
-
+use causal_lm::{CausalLM, DecodingMeta, Model, QueryContext, SampleMeta};
 use common_nv::{
     cast_dt,
-    cuda::{memcpy_d2h, AsRaw, Context, ContextResource, ContextSpore, DevMemSpore, Device},
-    slice, split, udim, utok, DataType, NvidiaKernels, NvidiaKernelsPtx, Tensor,
+    cuda::{
+        memcpy_d2h, AsRaw, Context, ContextResource, ContextSpore, DevMemSpore, Device, StreamSpore,
+    },
+    slice, split, udim, upos, utok, DataType, LocalSplitable, NvidiaKernels, NvidiaKernelsPtx,
+    SafeTensorsError, Tensor,
 };
 use half::f16;
+use itertools::izip;
 use nccl::CommunicatorGroup;
 use parameters::ParameterMatrix;
-use std::{iter::zip, path::Path, slice::from_raw_parts, sync::Arc, time::Instant};
-use transformer::{pos, Kernels, LayerBuffer, LayerCache, Llama2, Memory, Request};
+use std::{
+    iter::{repeat, zip},
+    path::Path,
+    slice::from_raw_parts,
+    sync::Arc,
+    time::Instant,
+};
+use transformer::{Kernels, Llama2, Memory};
+
+pub use common_nv::cuda;
 
 pub struct Transformer {
     host: Memory,
     comms: CommunicatorGroup,
+    streams: Vec<StreamSpore>,
     kernels: Vec<NvidiaKernels>,
     model_norm: Tensor<DevMemSpore>,
     lm_head: Tensor<DevMemSpore>,
     matrix: ParameterMatrix,
 }
 
-impl transformer::Transformer for Transformer {
-    type Cache = Cache;
+impl Model for Transformer {
+    type Meta = Vec<Device>;
+    type Error = SafeTensorsError;
 
     #[inline]
-    fn max_position_embeddings(&self) -> usize {
-        self.host.max_position_embeddings()
+    fn load(model_dir: impl AsRef<Path>, meta: Self::Meta) -> Result<Self, Self::Error> {
+        let time = Instant::now();
+        let host = Memory::load_safetensors(model_dir)?;
+        info!("load host: {:?}", time.elapsed());
+
+        let block_size = meta.iter().map(|dev| dev.max_block_dims().0).min().unwrap();
+        let contexts = meta.iter().map(Device::retain_primary).collect::<Vec<_>>();
+        let kernels = NvidiaKernelsPtx::new(&host, block_size);
+
+        let comms = CommunicatorGroup::new(
+            &meta
+                .iter()
+                .map(|dev| unsafe { dev.as_raw() })
+                .collect::<Vec<_>>(),
+        );
+        let (model_norm, lm_head) = comms.contexts().next().unwrap().apply(|ctx| {
+            (
+                ctx.from_host(host.model_norm().as_slice()).sporulate(),
+                ctx.from_host(host.lm_head().as_slice()).sporulate(),
+            )
+        });
+        Ok(Self {
+            comms,
+            streams: contexts
+                .iter()
+                .map(|context| context.apply(|ctx| ctx.stream().sporulate()))
+                .collect(),
+            kernels: contexts
+                .iter()
+                .map(|context| context.apply(|ctx| kernels.load(ctx)))
+                .collect(),
+            matrix: ParameterMatrix::load(&host, &contexts),
+            model_norm: host.model_norm().map_physical(|_| model_norm),
+            lm_head: host.lm_head().map_physical(|_| lm_head).transpose(&[1, 0]),
+            host,
+        })
     }
+}
+
+impl CausalLM for Transformer {
+    type Storage = Cache;
 
     #[inline]
     fn eos_token(&self) -> utok {
         self.host.eos_token_id()
     }
 
-    #[inline]
-    fn new_cache(&self) -> Vec<LayerCache<Self::Cache>> {
+    fn new_cache(&self) -> Tensor<Self::Storage> {
+        let dt = self.host.data_type();
+        let nlayers = self.host.num_hidden_layers() as udim;
+        let nkvh = self.host.num_key_value_heads() as udim;
+        let max_seq_len = self.host.max_position_embeddings() as udim;
+        let d = self.host.hidden_size() as udim;
+        let nh = self.host.num_attention_heads() as udim;
+
         let contexts = Arc::new(self.comms.contexts().collect::<Vec<_>>());
-        LayerCache::new_layers(&self.host, |dt, shape| {
-            let &[nkvh, max_seq_len, d] = shape else {
-                panic!()
-            };
-            Tensor::alloc(
-                dt,
-                &[nkvh / self.comms.len() as udim, max_seq_len, d],
-                |len| Cache {
-                    mem: contexts
-                        .iter()
-                        .map(|context| context.apply(|ctx| ctx.malloc::<u8>(len).sporulate()))
-                        .collect(),
-                    contexts: contexts.clone(),
-                },
-            )
+        let n = contexts.len() as udim;
+        Tensor::alloc(dt, &[nlayers, 2, nkvh / n, max_seq_len, d / nh], |len| {
+            Cache {
+                mem: contexts
+                    .iter()
+                    .map(|context| context.apply(|ctx| ctx.malloc::<u8>(len).sporulate()))
+                    .collect(),
+                contexts: contexts.clone(),
+            }
         })
     }
 
-    fn decode<Id>(
-        &self,
-        mut requests: Vec<Request<Id, Self::Cache>>,
-    ) -> (Vec<Id>, Tensor<Self::Cache>) {
-        // 归拢所有纯解码的请求到前面,减少批量解码的拷贝开销
-        requests.sort_unstable_by_key(Request::purely_decode);
+    fn duplicate_cache(&self, cache: &Tensor<Self::Storage>, pos: upos) -> Tensor<Self::Storage> {
+        let &[_nlayers, 2, _nkvh, max_seq_len, _dh] = cache.shape() else {
+            panic!()
+        };
+        assert!(pos <= max_seq_len);
+        let slice = [
+            slice![=>],
+            slice![=>],
+            slice![=>],
+            slice![=>pos],
+            slice![=>],
+        ];
+        let contexts = Arc::new(self.comms.contexts().collect::<Vec<_>>());
+        let mem = contexts
+            .iter()
+            .enumerate()
+            .map(|(i, context)| {
+                context.apply(|ctx| {
+                    let stream = ctx.stream();
+                    let mut ans = Tensor::alloc(cache.data_type(), cache.shape(), |len| {
+                        stream.malloc::<u8>(len)
+                    });
+                    let kernels = self.kernels[i].on(&stream);
+                    kernels.reform(
+                        &mut ans.as_mut().slice(&slice).map_physical(|u| &mut **u),
+                        &cache
+                            .as_ref()
+                            .slice(&slice)
+                            .map_physical(|u| unsafe { u.mem[i].sprout(ctx) }),
+                    );
+                    ans.take_physical().sporulate()
+                })
+            })
+            .collect();
+
+        Tensor::new(cache.data_type(), cache.shape(), Cache { contexts, mem })
+    }
+
+    fn token_embed(&self, queries: impl IntoIterator<Item = utok>) -> Tensor<Self::Storage> {
+        let tokens = queries.into_iter().collect::<Vec<_>>();
+        let nt = tokens.len() as udim;
+
+        let contexts = Arc::new(self.comms.contexts().collect::<Vec<_>>());
         let dt = self.host.data_type();
-        let nt = requests.iter().map(Request::seq_len).sum::<udim>();
         let d = self.host.hidden_size() as udim;
-        let nh = self.host.num_attention_heads() as udim;
-        let dh = d / nh;
-        let nkvh = self.host.num_key_value_heads() as udim;
-        let dkv = nkvh * dh;
-        let head_group = nh / nkvh;
-        let di = self.host.intermediate_size() as udim;
-        let head_div = (dh as f32).sqrt().recip();
-        let contexts = self.comms.contexts().collect::<Vec<_>>();
-        let mut streams = contexts
-            .iter()
-            .map(|ctx| ctx.apply(|c| c.stream().sporulate()))
-            .collect::<Vec<_>>();
-        let n = contexts.len();
-        // token embedding
-        let mut x0 = Tensor::alloc(dt, &[nt, d], |len| malloc_all(&contexts, len));
+        let mut x = Tensor::alloc(dt, &[nt, d], |len| malloc_all(&contexts, len));
         contexts[0].apply(|ctx| {
-            let stream = unsafe { ctx.sprout(&streams[0]) };
+            let stream = unsafe { ctx.sprout(&self.streams[0]) };
             let kernels = self.kernels[0].on(&stream);
-            let mut x = unsafe { x0.as_mut().map_physical(|u| ctx.sprout(&u[0])) };
-            kernels.gather(
-                &mut x,
-                &self.host.embed_tokens(),
-                requests.iter().flat_map(Request::tokens).copied(),
-            );
+            let mut x = x.as_mut().map_physical(|u| unsafe { ctx.sprout(&u[0]) });
+            kernels.gather(&mut x, &self.host.embed_tokens(), tokens);
         });
 
         for (i, comm) in self.comms.call().iter().enumerate() {
             contexts[i].apply(|ctx| {
-                let stream = unsafe { ctx.sprout(&streams[i]) };
-                let mut dst = unsafe { ctx.sprout(&x0.physical_mut()[i]) };
+                let stream = unsafe { ctx.sprout(&self.streams[i]) };
+                let mut dst = unsafe { ctx.sprout(&x.physical_mut()[i]) };
                 comm.broadcast(&mut dst, None, 0, &stream);
             });
         }
-        let mut x1 = Tensor::alloc(dt, &[nt, d], |len| malloc_all(&contexts, len));
-        let LayerBuffer {
-            qkv,
-            gate_up,
-            q_buf,
-            att_buf,
-        } = LayerBuffer::alloc(&self.host, &requests, |len| malloc_all(&contexts, len / n));
-        let mut buf = LayerBuffer {
-            qkv: {
-                let &[a, b] = qkv.shape() else { panic!() };
-                Tensor::new(dt, &[a, b / n as udim], qkv.take_physical())
-            },
-            gate_up: {
-                let &[a, b] = gate_up.shape() else { panic!() };
-                Tensor::new(dt, &[a, b / n as udim], gate_up.take_physical())
-            },
-            q_buf,
-            att_buf,
-        };
-        // 生成位置张量
-        let nt = x0.shape()[0]; // `nt` for number of tokens
-        let pos_ = pos(&requests, nt);
-        let mut pos = Tensor::new(
-            DataType::U32,
-            &[nt],
+        x.map_physical(|mem| Cache { contexts, mem })
+    }
+
+    fn forward<'a>(
+        &self,
+        queries: impl IntoIterator<Item = QueryContext<'a, Self::Storage>>,
+        token_embedded: Tensor<Self::Storage>,
+    ) -> Tensor<Self::Storage>
+    where
+        Self: 'a,
+    {
+        let mut queries = queries.into_iter().collect::<Vec<_>>();
+        let mut nt = 0;
+        let mut max_seq_len = 0;
+        let mut max_att_len = 0;
+        let seq_len = queries
+            .iter()
+            .map(|q| {
+                let seq = q.seq_len();
+                let att = q.att_len();
+                nt += seq;
+                max_seq_len = max_seq_len.max(seq);
+                max_att_len = max_att_len.max(att);
+                seq
+            })
+            .collect::<Vec<_>>();
+
+        let dt = self.host.data_type();
+        let d = self.host.hidden_size() as udim;
+        let nh = self.host.num_attention_heads() as udim;
+        let nkvh = self.host.num_key_value_heads() as udim;
+        let dh = d / nh;
+        let dkv = nkvh * dh;
+        let di = self.host.intermediate_size() as udim;
+        let head_group = nh / nkvh;
+        let head_div = (dh as f32).sqrt().recip();
+
+        let contexts = self.comms.contexts().collect::<Vec<_>>();
+        let n = contexts.len() as udim;
+
+        let reusing = (d + dkv + dkv).max(di + di);
+        let mut state_buf =
+            Tensor::alloc(dt, &[nt, d + reusing / n], |len| malloc_all(&contexts, len));
+        macro_rules! state {
+            () => {
+                split!(state_buf.as_mut().map_physical(|u| LocalSplitable::from(&mut **u)); [1]: d, reusing / n)
+            };
+        }
+
+        let mut q_buf = malloc_all(&contexts, (nh / n * max_seq_len * dh) as usize * dt.size());
+        let mut att_buf = malloc_all(
+            &contexts,
+            (nh / n * max_seq_len * max_att_len) as usize * dt.size(),
+        );
+        let pos = causal_lm::pos(&queries, nt);
+        let mut pos = pos.as_ref().map_physical(|u| {
             contexts
                 .iter()
                 .enumerate()
                 .map(|(i, context)| {
                     context.apply(|ctx| {
-                        unsafe { ctx.sprout(&streams[i]) }
-                            .from_host(&pos_)
+                        unsafe { ctx.sprout(&self.streams[i]) }
+                            .from_host(u)
                             .sporulate()
                     })
                 })
-                .collect::<Vec<_>>(),
-        );
+                .collect::<Vec<_>>()
+        });
+        let mut x = token_embedded;
 
         for layer in 0..self.host.num_hidden_layers() {
-            // before attention
+            let (mut x1, qkv) = state!();
+            let mut qkv = qkv.slice(&[slice![=>], slice![=> (d + dkv + dkv) / n]]);
+
             for (i, context) in contexts.iter().enumerate() {
                 context.apply(|ctx| {
-                    let stream = unsafe { ctx.sprout(&streams[i]) };
+                    let stream = unsafe { ctx.sprout(&self.streams[i]) };
                     let kernels = self.kernels[i].on(&stream);
                     let params = self.matrix.get(layer, i, ctx);
-                    let x0 = unsafe { x0.as_ref().map_physical(|u| ctx.sprout(&u[i])) };
-                    let mut x1 = unsafe { x1.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
-                    let mut qkv = unsafe { buf.qkv.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
-                    kernels.rms_norm(&mut x1, &x0, &params.input_layernorm());
+                    let x = x
+                        .as_ref()
+                        .map_physical(|u| unsafe { ctx.sprout(&u.mem[i]) });
+                    let mut x1 = x1.as_mut().map_physical(|u| unsafe { ctx.sprout(&u[i]) });
+                    let mut qkv = qkv.as_mut().map_physical(|u| unsafe { ctx.sprout(&u[i]) });
+                    kernels.rms_norm(&mut x1, &x, &params.input_layernorm());
                     kernels.mat_mul(&mut qkv, 0., &x1, &params.w_qkv(), 1.);
                 });
             }
-            let (q, k, v) =
-                split!(buf.qkv.as_ref(); [1]: d / n as udim, dkv / n as udim, dkv / n as udim);
-            let mut q = q.reshape(&[nt, nh / n as udim, dh]);
-            let mut k = k.reshape(&[nt, nkvh / n as udim, dh]);
-            let v = v.reshape(&[nt, nkvh / n as udim, dh]);
+
+            let (q, k, v) = split!(qkv; [1]: d / n, dkv / n, dkv / n);
+            let mut q = q.reshape(&[nt, nh / n, dh]);
+            let mut k = k.reshape(&[nt, nkvh / n, dh]);
+            let v = v.reshape(&[nt, nkvh / n, dh]);
+            let o = x1.reshape(&[nt, nh, dh]);
+            let o = o.slice(&[slice![=>], slice![=> nh / n], slice![=>]]);
+
             for (i, context) in contexts.iter().enumerate() {
                 context.apply(|ctx| {
-                    let stream = unsafe { ctx.sprout(&streams[i]) };
+                    let stream = unsafe { ctx.sprout(&self.streams[i]) };
                     let kernels = self.kernels[i].on(&stream);
-                    let pos = unsafe { pos.as_ref().map_physical(|u| ctx.sprout(&u[i])) };
-                    let mut q = unsafe { q.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
-                    let mut k = unsafe { k.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
+                    let pos = pos.as_ref().map_physical(|u| unsafe { ctx.sprout(&u[i]) });
+                    let mut q = q.as_mut().map_physical(|u| unsafe { ctx.sprout(&u[i]) });
+                    let mut k = k.as_mut().map_physical(|u| unsafe { ctx.sprout(&u[i]) });
                     kernels.rotary_embedding(&mut q, &pos);
                     kernels.rotary_embedding(&mut k, &pos);
                 });
             }
-            let o = &mut x1;
-            // attention
-            let q = q.as_ref().transpose(&[1, 0, 2]);
-            let k = k.as_ref().transpose(&[1, 0, 2]);
-            let v = v.as_ref().transpose(&[1, 0, 2]);
-            let mut o = o
-                .as_mut()
-                .reshape(&[nt, nh, dh])
-                .transpose(&[1, 0, 2])
-                .slice(&[slice![=> nh/n as udim], slice![=>], slice![=>]]);
-
-            let mut req = 0;
-            for r in requests.iter_mut() {
-                let pos = r.pos();
-                let seq_len = r.seq_len();
-                let att_len = r.att_len();
-
-                let req_slice = &[slice![=>], slice![req =>=> seq_len], slice![=>]];
-                let cat_slice = &[slice![=>], slice![pos =>=> seq_len], slice![=>]];
-                let att_slice = &[slice![=>], slice![ => att_len], slice![=>]];
-                req += seq_len;
-                let q = q.clone().slice(req_slice);
-                let k = k.clone().slice(req_slice);
-                let v = v.clone().slice(req_slice);
-                let mut o = o.as_mut().slice(req_slice);
+            let q = q.transpose(&[1, 0, 2]).split(1, &seq_len);
+            let k = k.transpose(&[1, 0, 2]).split(1, &seq_len);
+            let v = v.transpose(&[1, 0, 2]).split(1, &seq_len);
+            let o = o.transpose(&[1, 0, 2]).split(1, &seq_len);
+
+            for (query, q, k, v, mut o) in izip!(&mut queries, q, k, v, o) {
+                let pos = query.pos();
+                let seq_len = query.seq_len();
+                let att_len = query.att_len();
+                let mut cache = query
+                    .cache
+                    .as_mut()
+                    .map(|t| t.as_mut().map_physical(|u| &mut *u.mem));
+                let mut query = QueryContext {
+                    cache: cache.as_mut(),
+                    range: query.range.clone(),
+                };
+                let Some((mut k_cache, mut v_cache)) = query.cache(layer) else {
+                    continue;
+                };
 
-                let shape_att0 = &[nkvh / n as udim, head_group * seq_len, att_len];
-                let shape_att1 = &[nkvh / n as udim * head_group, seq_len, att_len];
+                let slice_cat = &[slice![=>], slice![pos =>=> seq_len], slice![=>]];
+                let slice_att = &[slice![=>], slice![ => att_len], slice![=>]];
+                let shape_q0 = &[nkvh / n * head_group, seq_len, dh];
+                let shape_q1 = &[nkvh / n, head_group * seq_len, dh];
+                let shape_att0 = &[nkvh / n, head_group * seq_len, att_len];
+                let shape_att1 = &[nkvh / n * head_group, seq_len, att_len];
 
-                let mut q_att = Tensor::new(dt, &[nh / n as udim, seq_len, dh], &mut *buf.q_buf);
-                let mut att = Tensor::new(dt, shape_att0, &mut *buf.att_buf);
-                let (k_cache, v_cache) = r.cache(layer);
+                let mut q_att = Tensor::new(dt, shape_q0, &mut q_buf[..]);
+                let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]);
 
                 for (i, context) in contexts.iter().enumerate() {
                     context.apply(|ctx| {
-                        let stream = unsafe { ctx.sprout(&streams[i]) };
+                        let stream = unsafe { ctx.sprout(&self.streams[i]) };
                         let kernels = self.kernels[i].on(&stream);
                         let q = unsafe { q.as_ref().map_physical(|u| ctx.sprout(&u[i])) };
@@ -220,22 +326,22 @@ impl transformer::Transformer for Transformer {
                         let mut q_att = unsafe { q_att.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
                         let mut k_cache =
-                            unsafe { k_cache.as_mut().map_physical(|u| ctx.sprout(&u.mem[i])) };
+                            unsafe { k_cache.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
                         let mut v_cache =
-                            unsafe { v_cache.as_mut().map_physical(|u| ctx.sprout(&u.mem[i])) };
+                            unsafe { v_cache.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
                         let mut att = unsafe { att.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
 
                         let mut k_cat =
-                            k_cache.as_mut().slice(cat_slice).map_physical(|u| &mut **u);
+                            k_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u);
                         let mut v_cat =
-                            v_cache.as_mut().slice(cat_slice).map_physical(|u| &mut **u);
+                            v_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u);
 
                         kernels.reform(&mut q_att, &q);
                         kernels.reform(&mut k_cat, &k);
                         kernels.reform(&mut v_cat, &v);
 
-                        let q_att = q_att.reshape(&[nkvh / n as udim, head_group * seq_len, dh]);
-                        let k_att = k_cache.slice(att_slice).transpose(&[0, 2, 1]);
-                        let v_att = v_cache.slice(att_slice);
+                        let q_att = q_att.reshape(shape_q1);
+                        let k_att = k_cache.slice(slice_att).transpose(&[0, 2, 1]);
+                        let v_att = v_cache.slice(slice_att);
 
                         kernels.mat_mul(&mut att, 0., &q_att, &k_att, head_div);
                         let mut att = att.reshape(shape_att1);
@@ -244,23 +350,28 @@ impl transformer::Transformer for Transformer {
                        let att = att.reshape(shape_att0);
                         kernels.mat_mul(&mut x2, 0., &att, &v_att, 1.);
-                        kernels.reform(&mut o, &x2.reshape(&[nh / n as udim, seq_len, dh]));
+                        kernels.reform(&mut o, &x2.reshape(shape_q0));
                     });
                 }
             }
 
-            // after attention
+            let (mut x1, gate_up) = state!();
+            let mut gate_up = gate_up.slice(&[slice![=>], slice![=> (di + di) / n]]);
+
             for (i, comm) in self.comms.call().iter().enumerate() {
                 contexts[i].apply(|ctx| {
-                    let stream = unsafe { ctx.sprout(&streams[i]) };
+                    let stream = unsafe { ctx.sprout(&self.streams[i]) };
                     let kernels = self.kernels[i].on(&stream);
                     let params = self.matrix.get(layer, i, ctx);
 
-                    let mut x0 = unsafe { x0.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
+                    let mut x = x
+                        .as_ref()
+                        .map_physical(|u| unsafe { ctx.sprout(&u.mem[i]) });
                     let o = x1.as_ref().slice(&[slice![=>], slice![=> d/n as udim]]);
                     let o = unsafe { o.map_physical(|u| ctx.sprout(&u[i])) };
 
-                    kernels.mat_mul(&mut x0, if i == 0 { 1. } else { 0. }, &o, &params.w_o(), 1.);
+                    kernels.mat_mul(&mut x, if i == 0 { 1. } else { 0. }, &o, &params.w_o(), 1.);
                     comm.all_reduce(
-                        x0.physical_mut(),
+                        x.physical_mut(),
                         None,
                         cast_dt(self.host.data_type()),
                         nccl::ReduceType::ncclSum,
@@ -268,49 +379,48 @@
                     );
                 });
             }
-            // for (i, context) in contexts.iter().enumerate() {
-            //     context.apply(|ctx| {
-            //         let stream = unsafe { ctx.sprout(&streams[i]) };
-            //         stream.synchronize();
-            //     })
-            // }
-            // std::process::exit(0);
             for (i, context) in contexts.iter().enumerate() {
                 context.apply(|ctx| {
-                    let stream = unsafe { ctx.sprout(&streams[i]) };
+                    let stream = unsafe { ctx.sprout(&self.streams[i]) };
                     let kernels = self.kernels[i].on(&stream);
                     let params = self.matrix.get(layer, i, ctx);
 
-                    let x0 = unsafe { x0.as_ref().map_physical(|u| ctx.sprout(&u[i])) };
+                    let x = x
+                        .as_ref()
+                        .map_physical(|u| unsafe { ctx.sprout(&u.mem[i]) });
                     let mut x1 = unsafe { x1.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
                     let mut gate_up =
-                        unsafe { buf.gate_up.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
+                        unsafe { gate_up.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
 
-                    kernels.rms_norm(&mut x1, &x0, &params.post_att_layernorm());
+                    kernels.rms_norm(&mut x1, &x, &params.post_att_layernorm());
                     kernels.mat_mul(&mut gate_up, 0., &x1, &params.mlp_gate_up(), 1.);
                 });
             }
-            let (mut gate, up) = split!(buf.gate_up; [1]: di / n as udim, di / n as udim);
+
+            let (mut gate, up) = split!(gate_up; [1]: di / n, di / n);
+
             for (i, comm) in self.comms.call().iter().enumerate() {
                 contexts[i].apply(|ctx| {
-                    let stream = unsafe { ctx.sprout(&streams[i]) };
+                    let stream = unsafe { ctx.sprout(&self.streams[i]) };
                     let kernels = self.kernels[i].on(&stream);
                     let params = self.matrix.get(layer, i, ctx);
 
                     let mut gate = unsafe { gate.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
                     let up = unsafe { up.as_ref().map_physical(|u| ctx.sprout(&u[i])) };
-                    let mut x0 = unsafe { x0.as_mut().map_physical(|u| ctx.sprout(&u[i])) };
+                    let mut x = x
+                        .as_mut()
+                        .map_physical(|u| unsafe { ctx.sprout(&u.mem[i]) });
 
                     kernels.swiglu(&mut gate, &up);
                     kernels.mat_mul(
-                        &mut x0,
+                        &mut x,
                         if i == 0 { 1. } else { 0. },
                         &gate,
                         &params.mlp_down(),
                         1.,
                     );
                     comm.all_reduce(
-                        x0.physical_mut(),
+                        x.physical_mut(),
                         None,
                         cast_dt(self.host.data_type()),
                         nccl::ReduceType::ncclSum,
@@ -319,135 +429,130 @@
                 });
             }
         }
-        // decode
-        if requests[0].decode() {
-            let context = self.comms.contexts().next().unwrap();
-            let logits = context.apply(|ctx| {
-                let stream = unsafe { ctx.sprout(&streams[0]) };
-                let kernels = self.kernels[0].on(&stream);
-
-                let slice = {
-                    let mut dst = unsafe { ctx.sprout(&x0.physical_mut()[0]) };
-                    let dst = &mut *dst;
-                    let src = unsafe { from_raw_parts(dst.as_ptr(), dst.len()) };
-
-                    let (head, others) = requests.split_first().unwrap();
-                    let begin = head.seq_len() as usize - 1;
-
-                    let mut i_src = begin;
-                    let mut i_dst = begin;
-                    for r in others {
-                        i_src += r.seq_len() as usize;
-                        if r.decode() {
-                            i_dst += 1;
-                            if i_dst < i_src {
-                                stream.memcpy_d2d(dst, src);
-                            }
-                        }
-                    }
-                    slice![begin => i_dst + 1]
-                };
-                let x = x0.as_ref().slice(&[slice, slice![=>]]);
-                let mut x = unsafe { x.map_physical(|u| ctx.sprout(&u[0])) };
-
-                let dt = self.host.data_type();
-                let voc = self.host.vocab_size() as udim;
-
-                let mut logits =
-                    Tensor::alloc(dt, &[x.shape()[0], voc], |len| stream.malloc::<u8>(len));
-                // 复制一个 x 以实现原地归一化
-                let x_ = unsafe {
-                    x.as_ref()
-                        .map_physical(|u| from_raw_parts(u.as_ptr(), u.len()))
-                };
-
-                let model_norm =
-                    unsafe { self.model_norm.as_ref().map_physical(|u| ctx.sprout(u)) };
-                let lm_head = unsafe { self.lm_head.as_ref().map_physical(|u| ctx.sprout(u)) };
-                kernels.rms_norm(&mut x, &x_, &model_norm);
-                kernels.mat_mul(&mut logits, 0., &x, &lm_head, 1.);
-                logits.map_physical(|mem| mem.sporulate())
-            });
-            let logits = logits.map_physical(|mem| Cache {
-                contexts: Arc::new(vec![context]),
-                mem: vec![mem],
+        // kill
+        for (i, context) in contexts.iter().enumerate() {
+            context.apply(|ctx| unsafe {
+                ctx.kill(&mut state_buf.physical_mut()[i]);
+                ctx.kill(&mut q_buf[i]);
+                ctx.kill(&mut att_buf[i]);
+                ctx.kill(&mut pos.physical_mut()[i]);
             });
-            // kill
-            for (i, context) in contexts.iter().enumerate() {
-                context.apply(|ctx| unsafe {
-                    ctx.kill(&mut streams[i]);
-                    ctx.kill(&mut x0.physical_mut()[i]);
-                    ctx.kill(&mut x1.physical_mut()[i]);
-                    ctx.kill(&mut buf.qkv.physical_mut()[i]);
-                    ctx.kill(&mut buf.gate_up.physical_mut()[i]);
-                    ctx.kill(&mut buf.q_buf[i]);
-                    ctx.kill(&mut buf.att_buf[i]);
-                    ctx.kill(&mut pos.physical_mut()[i]);
+        }
+
+        x
+    }
+
+    fn decode(
+        &self,
+        decoding: impl IntoIterator<Item = DecodingMeta>,
+        mut hidden_state: Tensor<Self::Storage>,
+    ) -> Tensor<Self::Storage> {
+        let dt = self.host.data_type();
+        let d = self.host.hidden_size();
+        let voc = self.host.vocab_size() as udim;
+
+        let contexts = Arc::new(vec![self.comms.contexts().next().unwrap()]);
+        contexts[0].apply(|ctx| {
+            let mut x = hidden_state
+                .as_mut()
+                .map_physical(|u| unsafe { u.mem[0].sprout(ctx) });
+            let stream = unsafe { self.streams[0].sprout(ctx) };
+            let kernels = self.kernels[0].on(&stream);
+
+            let len = d * dt.size();
+
+            let mut iter = decoding.into_iter();
+            let mut begin = 0;
+            let mut src = 0;
+            let mut dst = 0;
+            for DecodingMeta {
+                num_query,
+                num_decode,
+            } in iter.by_ref()
+            {
+                begin += num_query;
+                if num_decode > 0 {
+                    src = begin;
+                    dst = begin;
+                    begin -= num_decode;
+                    break;
+                }
+            }
+            let dst_ = &mut **x.physical_mut();
+            let src_ = unsafe { from_raw_parts(dst_.as_ptr(), dst_.len()) };
+            for DecodingMeta {
+                num_query,
+                num_decode,
+            } in iter
+            {
+                src += num_query - num_decode;
+                if src > dst {
+                    for _ in 0..num_decode {
+                        stream.memcpy_d2d(&mut dst_[dst * len..][..len], &src_[src * len..][..len]);
+                        src += 1;
+                        dst += 1;
+                    }
+                } else {
+                    src += num_decode;
+                    dst += num_decode;
+                }
+            }
+
+            if dst == begin {
+                return Tensor::alloc(dt, &[0, d as _], |_| Cache {
+                    contexts: contexts.clone(),
+                    mem: vec![stream.malloc::<u8>(0).sporulate()],
                 });
             }
-            (requests.into_iter().map(Request::id).collect(), logits)
-        } else {
-            todo!()
-        }
+
+            let mut x = x.slice(&[slice![begin => dst], slice![=>]]);
+            let mut logits =
+                Tensor::alloc(dt, &[x.shape()[0], voc], |len| stream.malloc::<u8>(len));
+
+            let model_norm = self
+                .model_norm
+                .as_ref()
+                .map_physical(|u| unsafe { u.sprout(ctx) });
+            let lm_head = self
+                .lm_head
+                .as_ref()
+                .map_physical(|u| unsafe { u.sprout(ctx) });
+            // 复制一个 x 以实现原地归一化
+            let x_ = x
+                .as_ref()
+                .map_physical(|u| unsafe { from_raw_parts(u.as_ptr(), u.len()) });
+            kernels.rms_norm(&mut x, &x_, &model_norm);
+            kernels.mat_mul(&mut logits, 0., &x, &lm_head, 1.);
+
+            logits.map_physical(|u| Cache {
+                contexts: contexts.clone(),
+                mem: vec![u.sporulate()],
+            })
+        })
     }
 
-    fn sample<Id>(
+    fn sample(
         &self,
-        args: &transformer::SampleArgs,
-        requests: Vec<Id>,
-        logits: Tensor<Self::Cache>,
-    ) -> Vec<(Id, utok)> {
+        args: impl IntoIterator<Item = SampleMeta>,
+        logits: Tensor<Self::Storage>,
+    ) -> Vec<utok> {
         assert_eq!(logits.data_type(), DataType::F16);
         let &[_, voc] = logits.shape() else { panic!() };
+        let voc = voc as usize;
 
         let mut host = vec![f16::ZERO; logits.size()];
         let Cache { contexts, mem } = logits.physical();
         contexts[0].apply(|ctx| memcpy_d2h(&mut host, unsafe { &mem[0].sprout(ctx) }));
 
-        requests
-            .into_iter()
+        args.into_iter()
+            .flat_map(|meta| repeat(meta.args).take(meta.num_decode))
             .enumerate()
-            .map(|(i, id)| (id, args.random(&host[i * voc as usize..][..voc as usize])))
+            .map(|(i, args)| args.random(&host[i * voc..][..voc]))
            .collect()
     }
 }
 
-impl Transformer {
-    pub fn new(model_dir: impl AsRef<Path>, dev: &[Device]) -> Self {
-        let time = Instant::now();
-        let host = Memory::load_safetensors(model_dir).unwrap();
-        info!("load host: {:?}", time.elapsed());
-
-        let block_size = dev.iter().map(|dev| dev.max_block_dims().0).min().unwrap();
-        let contexts = dev.iter().map(Device::retain_primary).collect::<Vec<_>>();
-        let kernels = NvidiaKernelsPtx::new(&host, block_size);
-
-        let comms = CommunicatorGroup::new(
-            &dev.iter()
-                .map(|dev| unsafe { dev.as_raw() })
-                .collect::<Vec<_>>(),
-        );
-        let (model_norm, lm_head) = comms.contexts().next().unwrap().apply(|ctx| {
-            (
-                ctx.from_host(host.model_norm().as_slice()).sporulate(),
-                ctx.from_host(host.lm_head().as_slice()).sporulate(),
-            )
-        });
-        Self {
-            comms,
-            kernels: contexts
-                .iter()
-                .map(|context| context.apply(|ctx| kernels.load(ctx)))
-                .collect(),
-            matrix: ParameterMatrix::load(&host, &contexts),
-            model_norm: host.model_norm().map_physical(|_| model_norm),
-            lm_head: host.lm_head().map_physical(|_| lm_head).transpose(&[1, 0]),
-            host,
-        }
-    }
-}
-
 impl Drop for Transformer {
     #[inline]
     fn drop(&mut self) {
@@ -458,8 +563,12 @@
             ctx.kill(self.lm_head.physical_mut());
         });
         self.matrix.kill(&contexts);
-        for (context, kernels) in zip(contexts, &mut self.kernels) {
-            context.apply(|ctx| kernels.kill(ctx));
+        for (context, stream, kernels) in izip!(contexts, &mut self.streams, &mut self.kernels)
+        {
+            context.apply(|ctx| {
+                stream.kill(ctx);
+                kernels.kill(ctx);
+            });
         }
     }
 }
@@ -487,35 +596,59 @@ fn malloc_all(contexts: &[Context], len: usize) -> Vec<DevMemSpore> {
 }
 
 #[test]
-fn test() {
-    use common_nv::cuda::{self, Device};
-    use log::LevelFilter::Trace;
-    use simple_logger::SimpleLogger;
-    use transformer::Transformer as _;
+fn test_infer() {
+    use std::time::Instant;
 
     let Some(model_dir) = common_nv::test_model::find() else {
         return;
     };
     println!("model_dir: {}", model_dir.display());
 
-    const N: usize = 4;
-    cuda::init();
-    if Device::count() < N {
+    if cuda::Device::count() < 2 {
         return;
     }
-    SimpleLogger::new().with_level(Trace).init().unwrap();
-
-    let time = Instant::now();
-    let transformer = Transformer::new(model_dir, &[Device::fetch().unwrap()]);
-    info!("load {:?}", time.elapsed());
-
-    let time = Instant::now();
-    let mut cache = transformer.new_cache();
-    info!("new cache: {:?}", time.elapsed());
-
-    let time = Instant::now();
-    transformer.decode(vec![Request::new(0, &[1, 2, 3], &mut cache, 0, true)]);
-    info!("decode: {:?}", time.elapsed());
+    let t0 = Instant::now();
+    let model = <Transformer as Model>::load(
+        model_dir,
+        [0, 1].map(cuda::Device::new).into_iter().collect(),
+    )
+    .unwrap();
+    let t1 = Instant::now();
+    println!("load {:?}", t1 - t0);
+
+    let mut cache = model.new_cache();
+
+    let mut prompt: Vec<utok> = vec![
+        29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592, 21106,
+        29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
+    ];
+    let mut pos = 0;
+
+    while prompt != &[model.eos_token()] {
+        let token_embedded = CausalLM::token_embed(&model, prompt.iter().copied());
+
+        let queries = [QueryContext {
+            cache: Some(&mut cache),
+            range: pos..pos + prompt.len() as upos,
+        }];
+        let hidden_state = CausalLM::forward(&model, queries, token_embedded);
+
+        let decoding = [DecodingMeta {
+            num_query: prompt.len(),
+            num_decode: 1,
+        }];
+        let logits = CausalLM::decode(&model, decoding, hidden_state);
+
+        let args = [SampleMeta {
+            num_decode: 1,
+            args: causal_lm::SampleArgs::default(),
+        }];
+        let tokens = CausalLM::sample(&model, args, logits);
+
+        println!("{:?}", tokens);
+        pos += prompt.len() as upos;
+        prompt = tokens;
+    }
 }
diff --git a/nvidia/transformer/Cargo.toml b/nvidia/transformer/Cargo.toml
index 1d38dffa..6dcba087 100644
--- a/nvidia/transformer/Cargo.toml
+++ b/nvidia/transformer/Cargo.toml
@@ -9,12 +9,10 @@ authors = ["YdrMaster "]
 [dependencies]
 causal-lm = { path = "../../causal-lm" }
 transformer = { path = "../../transformer" }
-itertools = "0.12"
 common-nv = { path = "../common" }
+log.workspace = true
 half.workspace = true
-
-[dev-dependencies]
-tokenizer = { path = "../../tokenizer" }
+itertools.workspace = true
 
 [build-dependencies]
 search-cuda-tools.workspace = true
diff --git a/nvidia/transformer/src/lib.rs b/nvidia/transformer/src/lib.rs
index 8b734107..6754b102 100644
--- a/nvidia/transformer/src/lib.rs
+++ b/nvidia/transformer/src/lib.rs
@@ -2,6 +2,9 @@
 
 mod parameters;
 
+#[macro_use]
+extern crate log;
+
 use ::half::f16;
 use causal_lm::{CausalLM, DecodingMeta, Model, QueryContext, SampleMeta};
 use common_nv::{
@@ -17,6 +20,7 @@ use std::{
     path::Path,
     slice::from_raw_parts,
     sync::{Arc, Mutex},
+    time::Instant,
 };
 use transformer::{Kernels, Llama2, Memory};
 
@@ -39,10 +43,12 @@ impl Model for Transformer {
     #[inline]
     fn load(model_dir: impl AsRef<Path>, meta: Self::Meta) -> Result<Self, Self::Error> {
         let context = Arc::new(meta.retain_primary());
+        let time = Instant::now();
         let host = Memory::load_safetensors_realloc(
             model_dir,
             Some(|l| context.apply(|ctx| ctx.malloc_host::<u8>(l).sporulate())),
         )?;
+        info!("load host: {:?}", time.elapsed());
         let load_layers = host.num_hidden_layers();
 
         let (model, layers, kernels, transfer, compute) = context.apply(|ctx| {
diff --git a/transformer-cpu/Cargo.toml b/transformer-cpu/Cargo.toml
index d085d77a..04e91ec8 100644
--- a/transformer-cpu/Cargo.toml
+++ b/transformer-cpu/Cargo.toml
@@ -11,7 +11,7 @@
 common = { path = "../common" }
 tensor = { path = "../tensor" }
 causal-lm = { path = "../causal-lm" }
 transformer = { path = "../transformer" }
-itertools = "0.12"
+itertools.workspace = true
 gemm = "0.17"
 intel-mkl-src = { version = "0.8", features = ["mkl-dynamic-lp64-iomp"] }