diff --git a/Cargo.toml b/Cargo.toml index ac5b288..292b573 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ members = [ "models/llama/common", "models/llama/common-cpu", - "models/llama/opencl", + # "models/llama/opencl", "models/llama/infini", "models/llama/cuda", @@ -34,7 +34,7 @@ itertools = "0.13" env_logger = "0.11" build-script-cfg = "0.0" -operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "df027a4", default-features = false } +operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "fd8f972", default-features = false } search-cl-tools = { git = "https://github.com/InfiniTensor/clrt", rev = "9b6289d" } search-infini-tools = { git = "https://github.com/InfiniTensor/infini-rt", rev = "f40bcb5" } diff --git a/common/src/lib.rs b/common/src/lib.rs index f370f31..edf0adf 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,4 +1,9 @@ -use std::{borrow::Borrow, collections::HashMap, hash::Hash, ops::Deref}; +use std::{ + borrow::Borrow, + collections::HashMap, + hash::Hash, + ops::{Deref, Range}, +}; pub enum Contiguous<'a, T> { Borrowed(&'a [u8]), @@ -52,3 +57,42 @@ impl Slab { self.0.entry(key).or_default().push(value); } } + +#[derive(Clone, Copy, Debug)] +pub struct Distribution { + pub start: usize, + pub len: usize, + pub total: usize, +} + +impl Distribution { + pub const MONO: Self = Self { + start: 0, + len: 1, + total: 1, + }; +} + +pub struct WeightMemCalculator { + align: usize, + size: usize, +} + +impl WeightMemCalculator { + #[inline] + pub const fn new(align: usize) -> Self { + Self { align, size: 0 } + } + + #[inline] + pub const fn size(&self) -> usize { + self.size + } + + #[inline] + pub fn push(&mut self, size: usize) -> Range<usize> { + let start = self.size.div_ceil(self.align) * self.align; + self.size = start + size; + start..self.size + } +} diff --git a/models/llama/common-cpu/src/infer.rs b/models/llama/common-cpu/src/infer.rs index 444b0ad..2bd9a76 100644 --- a/models/llama/common-cpu/src/infer.rs +++ b/models/llama/common-cpu/src/infer.rs @@ -1,4 +1,5 @@ use crate::{Operators, RandomSample, Weights}; +use common::Distribution; use gguf::GGufModel; use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor}; use operators::{ @@ -8,7 +9,13 @@ use operators::{ Blob, }; use regex::Regex; -use std::{iter::zip, ptr::copy_nonoverlapping, slice::from_raw_parts_mut, thread}; +use std::{ + iter::zip, + ptr::copy_nonoverlapping, + slice::from_raw_parts_mut, + sync::{Arc, Barrier}, + thread, +}; use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed}; type Worker<'w> = LlamaWorker, AllReduce>, Weights<'w>>; @@ -51,24 +58,28 @@ fn test_infer() { .collect() }) .unwrap_or_else(|| vec![1]); - let count = lens.iter().sum(); + let dist = lens.iter().sum(); println!("distribution: {lens:?}"); let (seeds, senders) = WorkerSeed::new(InprocNode::new(lens.len())); + let barrier = Arc::new(Barrier::new(dist + 1)); thread::scope(|s| { let _workers = zip(lens, seeds) .enumerate() .scan(0, |start, (id, (len, seed))| { - let range = *start..*start + len; - *start = range.end; - - let mut meta = model.meta.clone(); - meta.distribute(range.clone(), count); + let dist = Distribution { + start: *start, + len, + total: dist, + }; + *start += len; + let meta = model.meta.distribute(dist); let model = &model; + let barrier = barrier.clone(); Some(s.spawn(move || { let WorkerSeed { node, tasks } = seed; - let weights = Weights::new(model, range, count); + let weights = Weights::new(model, dist);
let mut worker = Worker::new(id, &node, meta.clone(), weights); let mut cache = meta.kv_cache(meta.nctx).map(Blob::new); let sin_cos = ::build_sin_cos( @@ -85,6 +96,7 @@ fn test_infer() { from_raw_parts_mut(&mut pair as *mut _ as *mut u8, size_of_val(&pair)) }); + barrier.wait(); for task in tasks { let Task { nt, @@ -137,6 +149,7 @@ fn test_infer() { .collect::>(); let senders = senders.into_boxed_slice(); + barrier.wait(); test_infer_paralle(&model, senders, eos, tokenizer, &prompt, max_steps) }) } diff --git a/models/llama/common-cpu/src/lib.rs b/models/llama/common-cpu/src/lib.rs index f879d53..6001771 100644 --- a/models/llama/common-cpu/src/lib.rs +++ b/models/llama/common-cpu/src/lib.rs @@ -1,7 +1,7 @@ -use common::Contiguous; +use common::{Contiguous, Distribution}; use llama::{ ext::ggml_quants::{self, digit_layout::DigitLayout, f16, DataBlock, QuantExt}, - BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor, + LlamaBlkStorage, LlamaBlkWeight, LlamaStorage, Tensor, TensorUsage::Computation, WeightLoader, }; @@ -16,7 +16,7 @@ use std::{ cell::{Ref, RefCell}, marker::PhantomData, mem::size_of, - ops::{Deref, Range, RangeBounds}, + ops::{Deref, Range}, ptr::copy_nonoverlapping, slice::{from_raw_parts, from_raw_parts_mut}, }; @@ -41,7 +41,7 @@ pub struct Weights<'w> { pub struct WeightCache { cache: Blob, - cached_weight: BlkWeight, + cached_weight: LlamaBlkWeight, cached_weight_iblk: usize, } @@ -85,11 +85,7 @@ where } impl<'w> Weights<'w> { - pub fn new( - model: &'w LlamaStorage<&'w [u8]>, - range: impl RangeBounds + Clone, - count: usize, - ) -> Self { + pub fn new(model: &'w LlamaStorage<&'w [u8]>, dist: Distribution) -> Self { let LlamaStorage { meta, output_norm, @@ -100,11 +96,18 @@ impl<'w> Weights<'w> { let blks = blocks .iter() - .map(|blk| blk.distribute(meta, range.clone(), count, Blob::new)) + .map(|blk| { + blk.clone() + .into_vec() + .into_iter() + .map(|(which, data)| { + (which, meta.distribute_data(which, data, dist, Blob::new)) + }) + .collect::>() + }) .collect::>(); - let mut meta = meta.clone(); - meta.distribute(range.clone(), count); + let meta = meta.distribute(dist); let size_qkv = meta.attn_qkv(Computation).take(); let size_o = meta.attn_o(Computation).take(); let size_gate_up = meta.ffn_gate_up(Computation).take(); @@ -113,7 +116,7 @@ impl<'w> Weights<'w> { let weight_cache = if meta.dt_embd == meta.dt_linear { RefCell::new(WeightCache { cache: Blob::new(0), - cached_weight: BlkWeight::AttnQKV, + cached_weight: LlamaBlkWeight::AttnQKV, cached_weight_iblk: 0, }) } else { @@ -131,7 +134,7 @@ impl<'w> Weights<'w> { RefCell::new(WeightCache { cache, - cached_weight: BlkWeight::AttnQKV, + cached_weight: LlamaBlkWeight::AttnQKV, cached_weight_iblk: 0, }) }; @@ -207,7 +210,7 @@ impl WeightLoader for Weights<'_> { #[inline] fn load_blk( &self, - which: BlkWeight, + which: LlamaBlkWeight, iblk: usize, _queue: &QueueOf, ) -> Self::Weight<'_> { @@ -233,10 +236,10 @@ impl WeightLoader for Weights<'_> { ffn_down, } = &blks[iblk]; - use BlkWeight::{ + use Dequant::{Borrowed, Cached}; + use LlamaBlkWeight::{ AttnNorm, AttnO, AttnQKV, AttnQKVBias, FfnDown, FfnGateInp, FfnGateUp, FfnNorm, }; - use Dequant::{Borrowed, Cached}; #[rustfmt::skip] match which { @@ -301,7 +304,7 @@ impl WeightLoader for Weights<'_> { fn load_moe<'a>( &'a self, - which: BlkWeight, + which: LlamaBlkWeight, iblk: usize, iexp: usize, _queue: &'a QueueOf, @@ -315,8 +318,8 @@ impl WeightLoader for Weights<'_> { } = self; assert_eq!(dt_embd, dt_mat); let w = match which { - BlkWeight::FfnGateUp => 
&*blks[iblk].ffn_gate_up, - BlkWeight::FfnDown => &*blks[iblk].ffn_down, + LlamaBlkWeight::FfnGateUp => &*blks[iblk].ffn_gate_up, + LlamaBlkWeight::FfnDown => &*blks[iblk].ffn_down, _ => unreachable!(), }; let one = w.len() / nexp; diff --git a/models/llama/common/src/compute.rs b/models/llama/common/src/compute.rs index ef405e1..095a89e 100644 --- a/models/llama/common/src/compute.rs +++ b/models/llama/common/src/compute.rs @@ -1,4 +1,4 @@ -use super::{args::Args, LlamaMeta}; +use super::{args::Args, LlamaBlkWeight, LlamaMeta}; use gguf::ggml_quants::{ digit_layout::{types as ty, DigitLayout}, f16, @@ -53,18 +53,6 @@ pub trait Operators { } } -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub enum BlkWeight { - AttnNorm, - AttnQKV, - AttnQKVBias, - AttnO, - FfnNorm, - FfnGateInp, - FfnGateUp, - FfnDown, -} - pub trait WeightLoader { type Hardware: Hardware; type Weight<'s>: Deref]> + 's @@ -73,14 +61,14 @@ pub trait WeightLoader { fn load_blk<'a>( &'a self, - which: BlkWeight, + which: LlamaBlkWeight, iblk: usize, queue: &'a QueueOf, ) -> Self::Weight<'a>; fn load_moe<'a>( &'a self, - which: BlkWeight, + which: LlamaBlkWeight, iblk: usize, iexp: usize, queue: &'a QueueOf, @@ -638,7 +626,7 @@ impl WeightDecorator { iblk: usize, queue: &'a QueueOf, ) -> Tensor> { - let w = self.weights.load_blk(BlkWeight::AttnNorm, iblk, queue); + let w = self.weights.load_blk(LlamaBlkWeight::AttnNorm, iblk, queue); self.norm.clone().map(|_| w) } @@ -648,7 +636,7 @@ impl WeightDecorator { iblk: usize, queue: &'a QueueOf, ) -> Tensor> { - let w = self.weights.load_blk(BlkWeight::AttnQKV, iblk, queue); + let w = self.weights.load_blk(LlamaBlkWeight::AttnQKV, iblk, queue); self.attn_qkv.clone().map(|_| w) } @@ -658,7 +646,9 @@ impl WeightDecorator { iblk: usize, queue: &'a QueueOf, ) -> Tensor> { - let w = self.weights.load_blk(BlkWeight::AttnQKVBias, iblk, queue); + let w = self + .weights + .load_blk(LlamaBlkWeight::AttnQKVBias, iblk, queue); self.attn_qkv_bias.clone().map(|_| w) } @@ -668,7 +658,7 @@ impl WeightDecorator { iblk: usize, queue: &'a QueueOf, ) -> Tensor> { - let w = self.weights.load_blk(BlkWeight::AttnO, iblk, queue); + let w = self.weights.load_blk(LlamaBlkWeight::AttnO, iblk, queue); self.attn_o.clone().map(|_| w) } @@ -678,7 +668,7 @@ impl WeightDecorator { iblk: usize, queue: &'a QueueOf, ) -> Tensor> { - let w = self.weights.load_blk(BlkWeight::FfnNorm, iblk, queue); + let w = self.weights.load_blk(LlamaBlkWeight::FfnNorm, iblk, queue); self.norm.clone().map(|_| w) } @@ -688,7 +678,9 @@ impl WeightDecorator { iblk: usize, queue: &'a QueueOf, ) -> Tensor> { - let w = self.weights.load_blk(BlkWeight::FfnGateInp, iblk, queue); + let w = self + .weights + .load_blk(LlamaBlkWeight::FfnGateInp, iblk, queue); self.ffn_gate_inp.clone().map(|_| w) } @@ -699,7 +691,7 @@ impl WeightDecorator { iexp: usize, queue: &'a QueueOf, ) -> Tensor> { - const WHICH: BlkWeight = BlkWeight::FfnGateUp; + const WHICH: LlamaBlkWeight = LlamaBlkWeight::FfnGateUp; let w = if self.is_moe { self.weights.load_moe(WHICH, iblk, iexp, queue) } else { @@ -715,7 +707,7 @@ impl WeightDecorator { iexp: usize, queue: &'a QueueOf, ) -> Tensor> { - const WHICH: BlkWeight = BlkWeight::FfnDown; + const WHICH: LlamaBlkWeight = LlamaBlkWeight::FfnDown; let w = if self.is_moe { self.weights.load_moe(WHICH, iblk, iexp, queue) } else { diff --git a/models/llama/common/src/lib.rs b/models/llama/common/src/lib.rs index d9c8a0d..f05be00 100644 --- a/models/llama/common/src/lib.rs +++ b/models/llama/common/src/lib.rs @@ -2,11 +2,11 @@ mod 
args; mod compute; mod storage; +use common::Distribution; use gguf::ggml_quants::digit_layout::DigitLayout; -use std::ops::{Range, RangeBounds}; pub use args::{Args as LlamaArgs, Request as LlamaRequest}; -pub use compute::{BlkWeight, LlamaWorker, Operators, WeightLoader}; +pub use compute::{LlamaWorker, Operators, WeightLoader}; pub use storage::{BlkStorage as LlamaBlkStorage, Storage as LlamaStorage}; pub use tensor::{RandomSample, Tensor}; pub mod ext { @@ -45,16 +45,30 @@ pub enum TensorUsage { Computation, } -impl LlamaMeta { - pub fn distribute(&mut self, range: impl RangeBounds<usize>, count: usize) { - let len = normalize(range, count).len(); - assert!(0 < len && len <= count); - assert_eq!(self.nkvh % count, 0); - assert_eq!(self.di % count, 0); +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum LlamaBlkWeight { + AttnNorm, + AttnQKV, + AttnQKVBias, + AttnO, + FfnNorm, + FfnGateInp, + FfnGateUp, + FfnDown, +} - self.nh = self.nh / count * len; - self.nkvh = self.nkvh / count * len; - self.di = self.di / count * len; +impl LlamaMeta { + pub fn distribute(&self, Distribution { len, total, .. }: Distribution) -> Self { + assert!(0 < len && len <= total); + assert_eq!(self.nkvh % total, 0); + assert_eq!(self.di % total, 0); + + Self { + nh: self.nh / total * len, + nkvh: self.nkvh / total * len, + di: self.di / total * len, + ..self.clone() + } } #[inline] @@ -62,6 +76,25 @@ impl LlamaMeta { self.nexp > 0 } + pub fn blk(&self) -> LlamaBlkStorage<usize> { + use TensorUsage::Storage as TensorMem; + let norm = self.norm().take(); + LlamaBlkStorage { + attn_norm: norm, + attn_qkv: self.attn_qkv(TensorMem).take(), + attn_qkv_bias: if self.attn_qkv_bias { + self.attn_qkv_bias(TensorMem).take() + } else { + 0 + }, + attn_o: self.attn_o(TensorMem).take(), + ffn_norm: norm, + ffn_gate_inp: self.ffn_gate_inp(TensorMem).take(), + ffn_gate_up: self.ffn_gate_up(TensorMem).take(), + ffn_down: self.ffn_down(TensorMem).take(), + } + } + pub fn kv_cache(&self, buf: usize) -> Tensor<usize> { let &Self { dt_embd, @@ -162,19 +195,3 @@ impl LlamaMeta { } } } - -fn normalize(range: impl RangeBounds<usize>, count: usize) -> Range<usize> { - use std::ops::Bound::{Excluded, Included, Unbounded}; - let start = match range.start_bound() { - Included(&i) => i, - Excluded(&i) => i + 1, - Unbounded => 0, - }; - let end = match range.end_bound() { - Included(&i) => i + 1, - Excluded(&i) => i, - Unbounded => count, - }; - assert!(start < end && end <= count); - start..end -} diff --git a/models/llama/common/src/storage.rs b/models/llama/common/src/storage.rs index b803624..bc647de 100644 --- a/models/llama/common/src/storage.rs +++ b/models/llama/common/src/storage.rs @@ -1,7 +1,7 @@ -use crate::{normalize, LlamaMeta}; -use common::{borrow, own, Contiguous}; +use crate::{LlamaBlkWeight, LlamaMeta}; +use common::{borrow, own, Contiguous, Distribution}; use gguf::{GGufMetaMapExt, GGufModel}; -use std::ops::{DerefMut, RangeBounds}; +use std::ops::DerefMut; use tensor::{rearrange, split, Tensor}; #[derive(Clone)] pub struct Storage<T> { pub blocks: Box<[BlkStorage<T>]>, } -#[derive(Clone, Copy)] +#[derive(Clone)] pub struct BlkStorage<T> { pub attn_norm: T, pub attn_qkv: T, @@ -88,79 +88,100 @@ impl<'a> Storage<&'a [u8]> { } impl<T> BlkStorage<T> { - pub fn map<U>(self, mut f: impl FnMut(T) -> U) -> BlkStorage<U> { - BlkStorage { - attn_norm: f(self.attn_norm), - attn_qkv: f(self.attn_qkv), - attn_qkv_bias: f(self.attn_qkv_bias), - attn_o: f(self.attn_o), - ffn_norm: f(self.ffn_norm), - ffn_gate_inp: f(self.ffn_gate_inp), - ffn_gate_up:
f(self.ffn_gate_up), - ffn_down: f(self.ffn_down), - } + #[rustfmt::skip] + pub fn into_vec(self) -> Vec<(LlamaBlkWeight, T)> { + use LlamaBlkWeight as W; + vec![ + (W::AttnNorm , self.attn_norm ), + (W::AttnQKV , self.attn_qkv ), + (W::AttnQKVBias, self.attn_qkv_bias), + (W::AttnO , self.attn_o ), + (W::FfnNorm , self.ffn_norm ), + (W::FfnGateInp , self.ffn_gate_inp ), + (W::FfnGateUp , self.ffn_gate_up ), + (W::FfnDown , self.ffn_down ), + ] } +} - pub fn as_ref(&self) -> BlkStorage<&T> { +impl FromIterator<(LlamaBlkWeight, T)> for BlkStorage { + #[rustfmt::skip] + fn from_iter(iter: U) -> Self + where + U: IntoIterator, + { + let mut collector: BlkStorage> = BlkStorage { + attn_norm : None, + attn_qkv : None, + attn_qkv_bias: None, + attn_o : None, + ffn_norm : None, + ffn_gate_inp : None, + ffn_gate_up : None, + ffn_down : None, + }; + for (which, data) in iter { + use LlamaBlkWeight as W; + match which { + W::AttnNorm => collector.attn_norm = Some(data), + W::AttnQKV => collector.attn_qkv = Some(data), + W::AttnQKVBias => collector.attn_qkv_bias = Some(data), + W::AttnO => collector.attn_o = Some(data), + W::FfnNorm => collector.ffn_norm = Some(data), + W::FfnGateInp => collector.ffn_gate_inp = Some(data), + W::FfnGateUp => collector.ffn_gate_up = Some(data), + W::FfnDown => collector.ffn_down = Some(data), + }; + } BlkStorage { - attn_norm: &self.attn_norm, - attn_qkv: &self.attn_qkv, - attn_qkv_bias: &self.attn_qkv_bias, - attn_o: &self.attn_o, - ffn_norm: &self.ffn_norm, - ffn_gate_inp: &self.ffn_gate_inp, - ffn_gate_up: &self.ffn_gate_up, - ffn_down: &self.ffn_down, + attn_norm : collector.attn_norm .unwrap(), + attn_qkv : collector.attn_qkv .unwrap(), + attn_qkv_bias: collector.attn_qkv_bias.unwrap(), + attn_o : collector.attn_o .unwrap(), + ffn_norm : collector.ffn_norm .unwrap(), + ffn_gate_inp : collector.ffn_gate_inp .unwrap(), + ffn_gate_up : collector.ffn_gate_up .unwrap(), + ffn_down : collector.ffn_down .unwrap(), } } } -impl<'w> BlkStorage<&'w [u8]> { - pub fn distribute( +impl LlamaMeta { + pub fn distribute_data<'w, U>( &self, - meta: &LlamaMeta, - range: impl RangeBounds, - count: usize, + which: LlamaBlkWeight, + data: &'w [u8], + Distribution { start, len, total }: Distribution, mut f: impl FnMut(usize) -> U, - ) -> BlkStorage> + ) -> Contiguous<'w, U> where U: DerefMut, { - let range = normalize(range, count); - let start = range.start; - let len = range.len(); - assert!(0 < len && range.end <= count); - - let &LlamaMeta { - nh, nkvh, dh, di, .. - } = meta; - assert_eq!(nkvh % count, 0); - assert_eq!(di % count, 0); - - let mut dis = meta.clone(); - dis.distribute(range.clone(), count); - + assert!(0 < len && start + len <= total); use crate::TensorUsage::Storage as TensorMem; - BlkStorage { - attn_norm: borrow(self.attn_norm), - attn_qkv: if len == count { - borrow(self.attn_qkv) - } else { + use LlamaBlkWeight as W; + match which { + W::AttnNorm | W::FfnNorm | W::FfnGateInp => borrow(data), + _ if len == total || data.is_empty() => borrow(data), + W::AttnQKV => { + let &LlamaMeta { nh, nkvh, dh, .. 
} = self; + let dist = self.distribute(Distribution { start, len, total }); + let dq = nh * dh; let dkv = nkvh * dh; - let qkv = meta.attn_qkv(TensorMem).map(|_| self.attn_qkv); + let qkv = self.attn_qkv(TensorMem).map(|_| data); split!(qkv => q, k, v; [dq, dkv, dkv] @ 0); - let dq = dq / count; - let dkv = dkv / count; + let dq = dq / total; + let dkv = dkv / total; let q = q.slice(0, dq * start, 1, dq * len); let k = k.slice(0, dkv * start, 1, dkv * len); let v = v.slice(0, dkv * start, 1, dkv * len); debug_assert!(q.is_contiguous() && k.is_contiguous() && v.is_contiguous()); - let mut ans = dis.attn_qkv(TensorMem).map(&mut f); + let mut ans = dist.attn_qkv(TensorMem).map(&mut f); { let ans = ans.map_slice_mut(); split!(ans => q_, k_, v_; [dq * len , dkv * len, dkv * len] @ 0); @@ -172,22 +193,22 @@ impl<'w> BlkStorage<&'w [u8]> { rearrange(&mut v_, &v); } own(ans.take()) - }, - attn_qkv_bias: if len == count { - borrow(self.attn_qkv_bias) - } else { + } + W::AttnQKVBias => { + let &LlamaMeta { nh, nkvh, dh, .. } = self; + let dq = nh * dh; let dkv = nkvh * dh; let d = dq + 2 * dkv; - let total = self.attn_qkv_bias.len(); - assert_eq!(total % d, 0); - let (q, kv) = self.attn_qkv_bias.split_at(dq * total / d); + let size = data.len(); + assert_eq!(size % d, 0); + let (q, kv) = data.split_at(dq * size / d); let (k, v) = kv.split_at(kv.len() / 2); - let q = &q[q.len() * start / count..][..q.len() * len / count]; - let k = &k[k.len() * start / count..][..k.len() * len / count]; - let v = &v[v.len() * start / count..][..v.len() * len / count]; + let q = &q[q.len() * start / total..][..q.len() * len / total]; + let k = &k[k.len() * start / total..][..k.len() * len / total]; + let v = &v[v.len() * start / total..][..v.len() * len / total]; let mut ans = f(q.len() + k.len() + v.len()); let (q_, kv_) = ans.split_at_mut(q.len()); @@ -197,33 +218,30 @@ impl<'w> BlkStorage<&'w [u8]> { v_.copy_from_slice(v); own(ans) - }, - attn_o: if len == count { - borrow(self.attn_o) - } else { - let o = meta.attn_o(TensorMem).map(|_| self.attn_o); + } + W::AttnO => { + let o = self.attn_o(TensorMem).map(|_| data); - let d = o.shape()[1] / count; + let d = o.shape()[1] / total; let o = o.slice(1, d * start, 1, d * len); let mut o_ = Tensor::new(o.dt(), o.shape()).map(&mut f); rearrange(&mut o_, &o); own(o_.take()) - }, - ffn_norm: borrow(self.ffn_norm), - ffn_gate_inp: borrow(self.ffn_gate_inp), - ffn_gate_up: if len == count { - borrow(self.ffn_gate_up) - } else { - let gu = meta.ffn_gate_up(TensorMem).map(|_| self.ffn_gate_up); + } + W::FfnGateUp => { + let &LlamaMeta { di, ..
} = self; + let dist = self.distribute(Distribution { start, len, total }); + + let gu = self.ffn_gate_up(TensorMem).map(|_| data); split!(gu => g, u; [di, di] @ 1); - let di = di / count; + let di = di / total; let g = g.slice(1, di * start, 1, di * len); let u = u.slice(1, di * start, 1, di * len); - let mut ans = dis.ffn_gate_up(TensorMem).map(&mut f); + let mut ans = dist.ffn_gate_up(TensorMem).map(&mut f); { let ans = ans.map_slice_mut(); split!(ans => g_, u_; [di * len , di * len] @ 1); @@ -233,19 +251,17 @@ impl<'w> BlkStorage<&'w [u8]> { rearrange(&mut u_, &u); } own(ans.take()) - }, - ffn_down: if len == count { - borrow(self.ffn_down) - } else { - let down = meta.ffn_down(TensorMem).map(|_| self.ffn_down); + } + W::FfnDown => { + let down = self.ffn_down(TensorMem).map(|_| data); - let d = down.shape()[2] / count; + let d = down.shape()[2] / total; let down = down.slice(2, d * start, 1, d * len); let mut down_ = Tensor::new(down.dt(), down.shape()).map(&mut f); rearrange(&mut down_, &down); own(down_.take()) - }, + } } } } diff --git a/models/llama/cuda/src/infer.rs b/models/llama/cuda/src/infer.rs index ae96df0..ea0ac71 100644 --- a/models/llama/cuda/src/infer.rs +++ b/models/llama/cuda/src/infer.rs @@ -1,4 +1,5 @@ use crate::{Operators, RandomSample, Weights}; +use common::Distribution; use gguf::GGufModel; use llama::{ ext::ggml_quants::f16, LlamaArgs, LlamaMeta, LlamaRequest, LlamaStorage, LlamaWorker, Tensor, @@ -9,7 +10,7 @@ use operators::{ Alloc, QueueAlloc, }; use std::{slice::from_raw_parts_mut, time::Instant, usize}; -use test_utils::{load_roll_cache_size, Inference, TokenizerAndPrompt}; +use test_utils::{Inference, TokenizerAndPrompt}; type Worker<'w> = LlamaWorker>; @@ -46,9 +47,6 @@ fn test_infer() { let device = devices.map_or(0, |devices| devices.parse().unwrap()); println!("using gpu{device}"); - let roll_cache_size = load_roll_cache_size(); - println!("roll_cache_size: {roll_cache_size}"); - let gpu = match cuda::init() { Ok(()) => Device::new(device), Err(NoDevice) => return, @@ -70,7 +68,7 @@ fn test_infer() { let time = Instant::now(); let token_embd = stream.ctx().from_host(model.token_embd); - let weights = Weights::new(&model, .., 1, roll_cache_size, ctx); + let weights = Weights::new(&model, Distribution::MONO, ctx); println!("load weights: {:?}", time.elapsed()); let (free, _) = ctx.mem_info(); diff --git a/models/llama/cuda/src/lib.rs b/models/llama/cuda/src/lib.rs index 286fe8f..7ebd3fd 100644 --- a/models/llama/cuda/src/lib.rs +++ b/models/llama/cuda/src/lib.rs @@ -1,21 +1,20 @@ #![cfg(any(use_nvidia, use_iluvatar))] -use common::{Contiguous, Slab}; -use llama::{BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor, WeightLoader}; +use common::{Contiguous, Distribution, Slab, WeightMemCalculator}; +use llama::{LlamaBlkStorage, LlamaBlkWeight, LlamaStorage, Tensor, WeightLoader}; use log::trace; use operators::{ all_reduce::{AllReduce, NonAllReduce}, - cuda::{memcpy_d2h, AsRaw, CurrentCtx, DevByte, DevMem, Event, Gpu, HostMem, Stream}, + cuda::{self, memcpy_d2h, AsRaw, CurrentCtx, DevByte, DevMem, Gpu}, random_sample::cuda::Operator as RandomSampleGpu, rearrange::cuda::Operator as Rearrange, ByteOf, QueueOf, TopoNode, }; use std::{ - cell::{RefCell, RefMut}, + collections::VecDeque, + iter::zip, marker::PhantomData, - mem::replace, - ops::{Deref, RangeBounds}, - rc::Rc, + ops::{Deref, Range}, time::Instant, }; @@ -23,99 +22,6 @@ pub struct Operators>(PhantomData<(N, pub type RandomSample = llama::RandomSample; -pub struct Weights<'ctx> { - nexp: usize, - 
blks: LlamaBlkStorage>, - output_norm: DevMem<'ctx>, - output: DevMem<'ctx>, -} - -pub enum Cache<'ctx> { - Static(Box<[DevMem<'ctx>]>), - Rolling { - stream: Rc>, - host: Box<[HostMem<'ctx>]>, - dev: RefCell>, - }, -} - -pub struct RollCache<'ctx> { - global_idx: usize, - local_idx: usize, - nblk: usize, - cache: Box<[(DevMem<'ctx>, Event<'ctx>)]>, -} - -impl<'ctx> RollCache<'ctx> { - pub fn new(nblk: usize, cache: Box<[(DevMem<'ctx>, Event<'ctx>)]>) -> Self { - Self { - global_idx: 0, - local_idx: 0, - nblk, - cache, - } - } - - pub fn first_event(&self) -> &Event<'ctx> { - let (_, ref event) = self.cache[self.local_idx]; - event - } -} - -pub enum WeightResult<'s, 'ctx> { - RollCached { - roll_cache: RefMut<'s, RollCache<'ctx>>, - load_stream: &'s Stream<'ctx>, - host: &'s [HostMem<'ctx>], - compute_stream: &'s Stream<'s>, - }, - Borrowed(&'s [DevByte]), -} - -impl Deref for WeightResult<'_, '_> { - type Target = [DevByte]; - - fn deref(&self) -> &Self::Target { - match self { - WeightResult::RollCached { roll_cache, .. } => { - &roll_cache.cache[roll_cache.local_idx].0 - } - WeightResult::Borrowed(dev_mem) => dev_mem, - } - } -} - -impl Drop for WeightResult<'_, '_> { - fn drop(&mut self) { - match self { - WeightResult::RollCached { - roll_cache, - load_stream, - host, - compute_stream, - } => { - // wait for the compute to finish - load_stream.wait_for(&compute_stream.record()); - - let next_load_idx = - (roll_cache.global_idx + roll_cache.cache.len()) % roll_cache.nblk; - let host = &host[next_load_idx]; - - roll_cache.global_idx = (roll_cache.global_idx + 1) % roll_cache.nblk; - - let start_idx = roll_cache.local_idx; - let (dev_mem, event) = &mut roll_cache.cache[start_idx]; - assert!(dev_mem.len() == host.len()); - load_stream.memcpy_h2d(dev_mem, host); - *event = load_stream.record(); - - roll_cache.local_idx = (roll_cache.local_idx + 1) % roll_cache.cache.len(); - } - WeightResult::Borrowed(_) => {} - } - } -} - macro_rules! op { ($name:ident) => { operators::$name::cuda::Operator @@ -158,221 +64,156 @@ where } } -impl<'blk> Weights<'blk> { - pub fn new( - model: &LlamaStorage<&'_ [u8]>, - range: impl RangeBounds + Clone, - count: usize, - pool_size: usize, - ctx: &'blk CurrentCtx, - ) -> Self { - assert!(pool_size > 0); - let stream = Rc::new(ctx.stream()); - let igpu = unsafe { ctx.dev().as_raw() }; - let mut slab = Slab::new(); - let blks = if pool_size < model.meta.nblk { - let mut blks_host = model.blocks[0] - .as_ref() - .map(|_| Vec::with_capacity(model.meta.nblk)); - for (iblk, blk) in model.blocks.iter().enumerate() { - let time = Instant::now(); - let blk = blk - .distribute(&model.meta, range.clone(), count, |len| { - ctx.malloc_host::(len) - }) - .map(|host| match host { - Contiguous::Borrowed(host) => { - let mut ans = ctx.malloc_host::(host.len()); - ans.copy_from_slice(host); - ans - } - Contiguous::Owned(host) => host, - }); +pub struct Weights<'ctx> { + nexp: usize, + mem: DevMem<'ctx>, + blks: Box<[LlamaBlkStorage>]>, + output_norm: Range, + output: Range, +} - macro_rules! push { - ($( $ident:ident )+ ) => { - $({ blks_host.$ident.push(blk.$ident); })+ - }; - } - push! 
{ - attn_norm - attn_qkv - attn_qkv_bias - attn_o - ffn_norm - ffn_gate_up - ffn_down - } - trace!("blk{iblk} loaded to gpu{igpu} in {:?}", time.elapsed()) - } - blks_host.map(|vec| { - let roll_cache = vec - .iter() - .take(pool_size) - .map(|host| (ctx.from_host(host), stream.record())) - .collect::>(); - Cache::Rolling { - stream: stream.clone(), - host: vec.into_boxed_slice(), - dev: RefCell::new(RollCache::new(model.meta.nblk, roll_cache)), - } +impl<'ctx> Weights<'ctx> { + pub fn new(model: &LlamaStorage<&[u8]>, dist: Distribution, ctx: &'ctx CurrentCtx) -> Self { + let LlamaStorage { + meta, + output_norm, + output, + blocks, + .. + } = model; + + let mut calculator = WeightMemCalculator::new(ctx.dev().alignment()); + let meta_dist = meta.distribute(dist); + let blk_size = meta_dist.blk(); + let off_blks = (0..meta_dist.nblk) + .map(|_| { + blk_size + .clone() + .into_vec() + .into_iter() + .map(|(which, size)| (which, calculator.push(size))) + .collect::>() }) - } else { - let mut loader = None; - let mut blks_dev = model.blocks[0] - .as_ref() - .map(|_| Vec::with_capacity(model.meta.nblk)); - for (iblk, blk) in model.blocks.iter().enumerate() { - let blk = blk.distribute(&model.meta, range.clone(), count, |size| { - slab.take(&size) - .unwrap_or_else(|| ctx.malloc_host::(size)) - }); - let loader = loader - .get_or_insert_with(|| blk.as_ref().map(|s| H2DLoader::new(s.len(), &stream))); + .collect::>(); + let off_output_norm = calculator.push(output_norm.len()); + let off_output = calculator.push(output.len()); + + let mut mem = ctx.malloc::(calculator.size()); + let mut slab = Slab::::new(); + let mut queue = VecDeque::new(); + let stream = ctx.stream(); + + macro_rules! host { + ($l:expr) => { + slab.take(&$l).unwrap_or_else(|| ctx.malloc_host::($l)) + }; + } - macro_rules! load { - ($( $ident:ident )+ ) => { - $( - let (dev, host) = loader.$ident.load(blk.$ident, &stream); - if let Some(host) = host { - slab.put(host.len(), host) - } - blks_dev.$ident.push(dev); - )+ - }; + for (blk, off) in zip(blocks, off_blks.clone()) { + let blk = blk.clone().into_vec(); + let off = off.into_vec(); + for ((which, data), (which_, off)) in zip(blk, off) { + assert_eq!(which, which_); + if off.is_empty() { + continue; } - let time = Instant::now(); - load! 
{ - attn_norm - attn_qkv - attn_o - ffn_norm - ffn_gate_inp - ffn_gate_up - ffn_down - } - trace!("blk{iblk} loaded to gpu{igpu} in {:?}", time.elapsed()) + let data = meta.distribute_data(which, data, dist, |l| host!(l)); + let data = match data { + Contiguous::Borrowed(data) => { + let mut mem = host!(data.len()); + mem.copy_from_slice(data); + mem + } + Contiguous::Owned(data) => data, + }; + stream.memcpy_h2d(&mut mem[off], &data); + queue.push_back((stream.record(), Instant::now(), data)) } - blks_dev.map(|vec| Cache::Static(vec.into_boxed_slice())) - }; - Self { - nexp: model.meta.nexp, - blks, - output_norm: ctx.from_host(model.output_norm), - output: ctx.from_host(model.output), + while let Some((event, _, _)) = queue.front() { + unsafe { + use cuda::bindings::{cuEventQuery, cudaError_enum as Res}; + match cuEventQuery(event.as_raw()) { + Res::CUDA_SUCCESS => { + let (_, time, data) = queue.pop_front().unwrap(); + trace!("{:>16}bytes copied in {:?}", data.len(), time.elapsed()); + slab.put(data.len(), data) + } + Res::CUDA_ERROR_NOT_READY => break, + err => panic!("Unexpected CUDA error: {err:?}"), + } + } + } } - } -} + stream.memcpy_h2d(&mut mem[off_output_norm.clone()], output_norm); + stream.memcpy_h2d(&mut mem[off_output.clone()], output); -struct H2DLoader<'ctx> { - event: Event<'ctx>, - host: HostMem<'ctx>, - dev: DevMem<'ctx>, -} - -impl<'ctx> H2DLoader<'ctx> { - fn new(size: usize, stream: &Stream<'ctx>) -> Self { Self { - event: stream.record(), - host: stream.ctx().malloc_host::(size), - dev: stream.ctx().malloc::(size), + nexp: meta.nexp, + mem, + blks: off_blks.into_boxed_slice(), + output_norm: off_output_norm, + output: off_output, } } - - fn load( - &mut self, - host: Contiguous>, - stream: &Stream<'ctx>, - ) -> (DevMem<'ctx>, Option>) { - self.event.synchronize(); - let cache = match host { - Contiguous::Borrowed(host) => { - self.host.copy_from_slice(host); - None - } - Contiguous::Owned(host) => Some(replace(&mut self.host, host)), - }; - stream.memcpy_h2d(&mut self.dev, &self.host); - self.event = stream.record(); - ( - replace(&mut self.dev, stream.ctx().malloc::(self.host.len())), - cache, - ) - } } -impl<'ctx> WeightLoader for Weights<'ctx> { +impl WeightLoader for Weights<'_> { type Hardware = Gpu; + type Weight<'s> - = WeightResult<'s, 'ctx> + = &'s [DevByte] where Self: 's; - #[inline] - fn load_blk<'s>( - &'s self, - which: BlkWeight, + fn load_blk<'a>( + &'a self, + which: LlamaBlkWeight, iblk: usize, - queue: &'s QueueOf, - ) -> Self::Weight<'s> { - let cache = match which { - BlkWeight::AttnNorm => &self.blks.attn_norm, - BlkWeight::AttnQKV => &self.blks.attn_qkv, - BlkWeight::AttnQKVBias => &self.blks.attn_qkv_bias, - BlkWeight::AttnO => &self.blks.attn_o, - BlkWeight::FfnNorm => &self.blks.ffn_norm, - BlkWeight::FfnGateInp => &self.blks.ffn_gate_inp, - BlkWeight::FfnGateUp => &self.blks.ffn_gate_up, - BlkWeight::FfnDown => &self.blks.ffn_down, + _queue: &'a QueueOf, + ) -> Self::Weight<'a> { + let off = &self.blks[iblk]; + use LlamaBlkWeight as W; + #[rustfmt::skip] + let off = match which { + W::AttnNorm => &off.attn_norm , + W::AttnQKV => &off.attn_qkv , + W::AttnQKVBias => &off.attn_qkv_bias, + W::AttnO => &off.attn_o , + W::FfnNorm => &off.ffn_norm , + W::FfnGateInp => &off.ffn_gate_inp , + W::FfnGateUp => &off.ffn_gate_up , + W::FfnDown => &off.ffn_down , }; - - match cache { - Cache::Static(dev) => WeightResult::Borrowed(&dev[iblk]), - Cache::Rolling { stream, host, dev } => { - let roll_cache = dev.borrow_mut(); - 
queue.wait_for(roll_cache.first_event()); - assert!(iblk == roll_cache.global_idx); - WeightResult::RollCached { - roll_cache, - load_stream: stream, - host, - compute_stream: queue, - } - } - } + &self.mem[off.clone()] } fn load_moe<'a>( &'a self, - which: BlkWeight, + which: LlamaBlkWeight, iblk: usize, iexp: usize, _queue: &'a QueueOf, ) -> Self::Weight<'a> { - let cache = match which { - BlkWeight::FfnGateUp => &self.blks.ffn_gate_up, - BlkWeight::FfnDown => &self.blks.ffn_down, - _ => unreachable!(), + let off = &self.blks[iblk]; + use LlamaBlkWeight as W; + #[rustfmt::skip] + let off = match which { + W::FfnGateUp => &off.ffn_gate_up, + W::FfnDown => &off.ffn_down , + _ => unreachable!() , }; - match cache { - Cache::Static(dev) => { - let w = &dev[iblk]; - let one = w.len() / self.nexp; - WeightResult::Borrowed(&w[iexp * one..][..one]) - } - Cache::Rolling { .. } => todo!(), - } + let w = &self.mem[off.clone()]; + let one = w.len() / self.nexp; + &w[iexp * one..][..one] } - #[inline] - fn output_norm(&self, _queue: &QueueOf) -> Self::Weight<'_> { - WeightResult::Borrowed(&self.output_norm) + fn output_norm<'a>(&'a self, _queue: &'a QueueOf) -> Self::Weight<'a> { + &self.mem[self.output_norm.clone()] } - #[inline] - fn output(&self, _queue: &QueueOf) -> Self::Weight<'_> { - WeightResult::Borrowed(&self.output) + fn output<'a>(&'a self, _queue: &'a QueueOf) -> Self::Weight<'a> { + &self.mem[self.output.clone()] } } diff --git a/models/llama/cuda/src/nccl_parallel.rs b/models/llama/cuda/src/nccl_parallel.rs index 95c21b7..4aa8b3c 100644 --- a/models/llama/cuda/src/nccl_parallel.rs +++ b/models/llama/cuda/src/nccl_parallel.rs @@ -1,4 +1,5 @@ use crate::{Operators, RandomSample, Weights}; +use common::Distribution; use gguf::GGufModel; use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor}; use log::info; @@ -13,6 +14,7 @@ use regex::Regex; use std::{ iter::zip, slice::{from_raw_parts, from_raw_parts_mut}, + sync::{Arc, Barrier}, thread, u64, }; use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed}; @@ -58,7 +60,7 @@ fn test_infer() { }) .unwrap_or_else(|| vec![0]); let lens = vec![1; devices.len()]; - let count = devices.len(); + let dist = devices.len(); println!("distribution: {devices:?}"); let (seeds, senders) = match cuda::init() { @@ -71,17 +73,21 @@ fn test_infer() { ), Err(NoDevice) => return, }; + let barrier = Arc::new(Barrier::new(dist + 1)); thread::scope(|s| { let _workers = zip(lens, seeds) .enumerate() .scan(0, |start, (id, (len, seed))| { - let range = *start..*start + len; - *start = range.end; - - let mut meta = model.meta.clone(); - meta.distribute(range.clone(), count); - + let dist = Distribution { + start: *start, + len, + total: dist, + }; + *start += len; + + let meta = model.meta.distribute(dist); let model = &model; + let barrier = barrier.clone(); Some(s.spawn(move || { info!("worker[{id}] started"); let WorkerSeed { node, tasks } = seed; @@ -93,7 +99,7 @@ fn test_infer() { let _ = stream.malloc::((free.0 >> 30).saturating_sub(1) << 30); info!("worker[{id}] loading weights..."); - let weights = Weights::new(model, range, count, usize::MAX, ctx); + let weights = Weights::new(model, dist, ctx); let mut worker = Worker::new(id, &node, meta.clone(), weights); info!("worker[{id}] created"); let mut cache = meta .kv_cache(meta.nctx) .map(|size| stream.malloc::(size)); @@ -111,6 +117,7 @@ fn test_infer() { let mut pair = KVPair::new(0, f16::ZERO); let mut pairs = Tensor::kv_pair_vec(1, |size| stream.malloc::(size)); + barrier.wait(); for task in tasks { let Task {
nt, @@ -177,6 +184,7 @@ fn test_infer() { .collect::>(); let senders = senders.into_boxed_slice(); + barrier.wait(); test_infer_paralle(&model, senders, eos, tokenizer, &prompt, max_steps) }) } diff --git a/models/llama/infini/Cargo.toml b/models/llama/infini/Cargo.toml index 6a1b2d9..6035406 100644 --- a/models/llama/infini/Cargo.toml +++ b/models/llama/infini/Cargo.toml @@ -9,6 +9,7 @@ authors = ["YdrMaster "] [dependencies] llama.path = "../common" common.workspace = true +log.workspace = true operators = { workspace = true, features = ["infini"] } [build-dependencies] diff --git a/models/llama/infini/src/infer.rs b/models/llama/infini/src/infer.rs index ff3882f..b185518 100644 --- a/models/llama/infini/src/infer.rs +++ b/models/llama/infini/src/infer.rs @@ -1,4 +1,5 @@ use crate::{Operators, RandomSample, Weights}; +use common::Distribution; use gguf::{ggml_quants::digit_layout::types, GGufModel}; use llama::{ext::ggml_quants::f16, LlamaRequest, LlamaStorage, LlamaWorker, Tensor}; use operators::{ @@ -11,6 +12,7 @@ use regex::Regex; use std::{ iter::zip, slice::{from_raw_parts, from_raw_parts_mut}, + sync::{Arc, Barrier}, thread, }; use test_utils::{test_infer_paralle, Inference, Task, TokenizerAndPrompt, WorkerSeed}; @@ -60,7 +62,7 @@ fn test_infer() { }) .unwrap_or_else(|| ("cpu".into(), vec![0])); let lens = vec![1; indices.len()]; - let count = indices.len(); + let dist = indices.len(); println!("{ty}; distribution: {indices:?}"); let (seeds, senders) = match &*ty { @@ -82,23 +84,28 @@ fn test_infer() { } _ => todo!(), }; + let barrier = Arc::new(Barrier::new(dist + 1)); thread::scope(|s| { let _workers = zip(lens, seeds) .enumerate() .scan(0, |start, (id, (len, seed))| { - let range = *start..*start + len; - *start = range.end; - - let mut meta = model.meta.clone(); - meta.distribute(range.clone(), count); + let dist = Distribution { + start: *start, + len, + total: dist, + }; + *start += len; + let meta = model.meta.distribute(dist); let model = &model; + let barrier = barrier.clone(); Some(s.spawn(move || { let WorkerSeed { node, tasks } = seed; let device = node.processor(); - let stream = device.stream(); - let weights = Weights::new(model, range, count, &stream); + let weights = Weights::new(model, dist, device); let mut worker = Worker::new(id, &node, meta.clone(), weights); + + let stream = device.stream(); let mut cache = meta .kv_cache(meta.nctx) .map(|size| stream.malloc::(size)); @@ -111,7 +118,7 @@ fn test_infer() { let sample = RandomSample::new(&node); let indices = RandomSample::build_indices(model.meta.nvoc, &stream); - + barrier.wait(); for task in tasks { let Task { nt, @@ -170,6 +177,7 @@ fn test_infer() { .collect::>(); let senders = senders.into_boxed_slice(); + barrier.wait(); test_infer_paralle(&model, senders, eos, tokenizer, &prompt, max_steps) }) } diff --git a/models/llama/infini/src/lib.rs b/models/llama/infini/src/lib.rs index ca3ff9d..2fe8c72 100644 --- a/models/llama/infini/src/lib.rs +++ b/models/llama/infini/src/lib.rs @@ -1,31 +1,27 @@ #![cfg(detected)] -use common::Contiguous; -use llama::{BlkWeight, LlamaBlkStorage, LlamaStorage, Tensor, WeightLoader}; +use common::{Contiguous, Distribution, Slab, WeightMemCalculator}; +use llama::{LlamaBlkStorage, LlamaBlkWeight, LlamaStorage, Tensor, WeightLoader}; +use log::trace; use operators::{ all_reduce::{infini::Operator as InfiniAllReduce, AllReduce}, infini::{Device, InfiniNode}, - infini_rt::{DevBlob, DevByte, Event, HostBlob, Stream}, + infini_rt::{DevBlob, DevByte}, random_sample::infini::Operator 
as RandomSampleNpu, ByteOf, QueueOf, TopoNode, }; use std::{ + collections::VecDeque, + iter::zip, marker::PhantomData, - mem::replace, - ops::{Deref, RangeBounds}, + ops::{Deref, Range}, + time::Instant, }; pub struct Operators(PhantomData<(N, R)>); pub type RandomSample = llama::RandomSample; -pub struct Weights { - nexp: usize, - blks: Box<[LlamaBlkStorage]>, - output_norm: DevBlob, - output: DevBlob, -} - macro_rules! op { ($name:ident) => { operators::$name::infini::Operator @@ -69,134 +65,153 @@ where } } +pub struct Weights { + nexp: usize, + mem: DevBlob, + blks: Box<[LlamaBlkStorage>]>, + output_norm: Range, + output: Range, +} + impl Weights { - pub fn new( - model: &LlamaStorage<&'_ [u8]>, - range: impl RangeBounds + Clone, - count: usize, - stream: &Stream, - ) -> Self { - let device = stream.get_device(); - let mut loader = None; - Self { - nexp: model.meta.nexp, - blks: model - .blocks - .iter() - .map(|blk| { - let blk = blk.distribute(&model.meta, range.clone(), count, |len| { - device.malloc_host::(len) - }); - let loader = loader.get_or_insert_with(|| { - blk.as_ref().map(|s| H2DLoader::new(s.len(), stream)) - }); - macro_rules! load { - ($( $ident:ident )+ ) => { - LlamaBlkStorage{ - $( $ident: loader.$ident.load(blk.$ident, stream) ),+ - } - }; - } - load! { - attn_norm - attn_qkv - attn_qkv_bias - attn_o - ffn_norm - ffn_gate_inp - ffn_gate_up - ffn_down - } - }) - .collect(), - output_norm: device.from_host(model.output_norm), - output: device.from_host(model.output), + pub fn new(model: &LlamaStorage<&[u8]>, dist: Distribution, dev: &Device) -> Self { + let LlamaStorage { + meta, + output_norm, + output, + blocks, + .. + } = model; + + let mut calculator = WeightMemCalculator::new(size_of::()); + let meta_dist = meta.distribute(dist); + let blk_size = meta_dist.blk(); + let off_blks = (0..meta_dist.nblk) + .map(|_| { + blk_size + .clone() + .into_vec() + .into_iter() + .map(|(which, size)| (which, calculator.push(size))) + .collect::>() + }) + .collect::>(); + let off_output_norm = calculator.push(output_norm.len()); + let off_output = calculator.push(output.len()); + + let mut mem = dev.malloc::(calculator.size()); + let mut slab = Slab::::new(); + let mut queue = VecDeque::new(); + let stream = dev.stream(); + + macro_rules! 
host { + ($l:expr) => { + slab.take(&$l).unwrap_or_else(|| dev.malloc_host::($l)) + }; } - } -} -struct H2DLoader { - event: Event, - host: HostBlob, - dev: DevBlob, -} + for (blk, off) in zip(blocks, off_blks.clone()) { + let blk = blk.clone().into_vec(); + let off = off.into_vec(); + for ((which, data), (which_, off)) in zip(blk, off) { + assert_eq!(which, which_); + if off.is_empty() { + continue; + } + let data = meta.distribute_data(which, data, dist, |l| host!(l)); + let data = match data { + Contiguous::Borrowed(data) => { + let mut mem = host!(data.len()); + mem.copy_from_slice(data); + mem + } + Contiguous::Owned(data) => data, + }; + stream.memcpy_h2d(&mut mem[off], &data); + let mut event = dev.event(); + stream.record(&mut event); + queue.push_back((event, Instant::now(), data)) + } + + while let Some((event, _, _)) = queue.front() { + if event.is_complete() { + let (_, time, data) = queue.pop_front().unwrap(); + trace!("{:>16}bytes copied in {:?}", data.len(), time.elapsed()); + slab.put(data.len(), data) + } else { + break; + } + } + } + stream.memcpy_h2d(&mut mem[off_output_norm.clone()], output_norm); + stream.memcpy_h2d(&mut mem[off_output.clone()], output); -impl H2DLoader { - fn new(size: usize, stream: &Stream) -> Self { - let device = stream.get_device(); - let mut event = device.event(); - stream.record(&mut event); Self { - event, - host: device.malloc_host::(size), - dev: device.malloc::(size), + nexp: meta.nexp, + mem, + blks: off_blks.into_boxed_slice(), + output_norm: off_output_norm, + output: off_output, } } - - fn load(&mut self, host: Contiguous, stream: &Stream) -> DevBlob { - self.event.synchronize(); - match host { - Contiguous::Borrowed(host) => self.host.copy_from_slice(host), - Contiguous::Owned(host) => self.host = host, - }; - stream.memcpy_h2d(&mut self.dev, &self.host); - stream.record(&mut self.event); - replace(&mut self.dev, stream.malloc::(self.host.len())) - } } impl WeightLoader for Weights { type Hardware = Device; + type Weight<'s> = &'s [DevByte] where Self: 's; - #[inline] - fn load_blk( - &self, - which: BlkWeight, + fn load_blk<'a>( + &'a self, + which: LlamaBlkWeight, iblk: usize, - _queue: &QueueOf, - ) -> Self::Weight<'_> { - let blk = &self.blks[iblk]; - match which { - BlkWeight::AttnNorm => &blk.attn_norm, - BlkWeight::AttnQKV => &blk.attn_qkv, - BlkWeight::AttnQKVBias => &blk.attn_qkv_bias, - BlkWeight::AttnO => &blk.attn_o, - BlkWeight::FfnNorm => &blk.ffn_norm, - BlkWeight::FfnGateInp => &blk.ffn_gate_inp, - BlkWeight::FfnGateUp => &blk.ffn_gate_up, - BlkWeight::FfnDown => &blk.ffn_down, - } + _queue: &'a QueueOf, + ) -> Self::Weight<'a> { + let off = &self.blks[iblk]; + use LlamaBlkWeight as W; + #[rustfmt::skip] + let off = match which { + W::AttnNorm => &off.attn_norm , + W::AttnQKV => &off.attn_qkv , + W::AttnQKVBias => &off.attn_qkv_bias, + W::AttnO => &off.attn_o , + W::FfnNorm => &off.ffn_norm , + W::FfnGateInp => &off.ffn_gate_inp , + W::FfnGateUp => &off.ffn_gate_up , + W::FfnDown => &off.ffn_down , + }; + &self.mem[off.clone()] } fn load_moe<'a>( &'a self, - which: BlkWeight, + which: LlamaBlkWeight, iblk: usize, iexp: usize, _queue: &'a QueueOf, ) -> Self::Weight<'a> { - let blk = &self.blks[iblk]; - let w = match which { - BlkWeight::FfnGateUp => &blk.ffn_gate_up, - BlkWeight::FfnDown => &blk.ffn_down, - _ => unreachable!(), + let off = &self.blks[iblk]; + use LlamaBlkWeight as W; + #[rustfmt::skip] + let off = match which { + W::FfnGateUp => &off.ffn_gate_up, + W::FfnDown => &off.ffn_down , + _ => unreachable!() , 
}; + let w = &self.mem[off.clone()]; let one = w.len() / self.nexp; &w[iexp * one..][..one] } - #[inline] - fn output_norm(&self, _queue: &QueueOf) -> Self::Weight<'_> { - &self.output_norm + fn output_norm<'a>(&'a self, _queue: &'a QueueOf) -> Self::Weight<'a> { + &self.mem[self.output_norm.clone()] } - #[inline] - fn output(&self, _queue: &QueueOf) -> Self::Weight<'_> { - &self.output + fn output<'a>(&'a self, _queue: &'a QueueOf) -> Self::Weight<'a> { + &self.mem[self.output.clone()] } }
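Notes (reviewer sketches, not part of the patch):

The new `common::Distribution` replaces the `(range, count)` pair that was previously threaded through `distribute`. A minimal sketch of how the `scan` in infer.rs turns per-worker shard counts into `Distribution` values; the `lens` here are made up:

use common::Distribution;

fn main() {
    // Same bookkeeping as the `scan` over `lens` in infer.rs:
    // worker 0 takes 1 of 4 shards, worker 1 the remaining 3.
    let lens = [1usize, 3];
    let total = lens.iter().sum();
    let mut start = 0;
    for (id, len) in lens.into_iter().enumerate() {
        let dist = Distribution { start, len, total };
        start += len;
        println!("worker {id}: {dist:?}");
    }
    // Single-device paths (cuda/src/infer.rs) pass Distribution::MONO instead.
    assert_eq!(Distribution::MONO.total, 1);
}

`meta.distribute(dist)` then scales `nh`, `nkvh`, and `di` by `len / total`, so each worker sees a consistently shrunken meta.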
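`BlkStorage::into_vec` plus the new `FromIterator` impl give a lossless round trip between the struct form and `(LlamaBlkWeight, T)` pairs, which is what lets the rewritten loaders iterate over block weights generically. A quick check with arbitrary dummy payloads:

use llama::LlamaBlkStorage;

fn main() {
    let blk = LlamaBlkStorage {
        attn_norm: 0u32,
        attn_qkv: 1,
        attn_qkv_bias: 2,
        attn_o: 3,
        ffn_norm: 4,
        ffn_gate_inp: 5,
        ffn_gate_up: 6,
        ffn_down: 7,
    };
    // Struct -> pairs -> struct; the FromIterator impl unwraps every field,
    // so the iterator must yield all eight weights exactly once.
    let back: LlamaBlkStorage<u32> = blk.clone().into_vec().into_iter().collect();
    assert_eq!(back.ffn_down, blk.ffn_down);
}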
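Both rewritten loaders (cuda and infini) follow the same plan: `meta.distribute(dist).blk()` yields per-weight sizes, `WeightMemCalculator` turns those into aligned `Range<usize>` offsets, and a single device allocation of `calculator.size()` bytes replaces the per-weight `DevMem`/`DevBlob` objects. A host-only analogue of the layout step, with `Vec<u8>` standing in for device memory and hypothetical sizes and alignment:

use common::WeightMemCalculator;

fn main() {
    let mut calc = WeightMemCalculator::new(256);
    // Each push starts at the current size rounded up to the alignment.
    let off_qkv = calc.push(1000);    // 0..1000
    let off_o = calc.push(500);       // 1024..1524
    let off_output = calc.push(2000); // 1536..3536

    // One allocation for everything; in the patch this is the single
    // ctx.malloc / dev.malloc, and memcpy_h2d writes into each range.
    let mut mem = vec![0u8; calc.size()];
    mem[off_qkv.clone()].fill(1);
    mem[off_o.clone()].fill(2);
    mem[off_output.clone()].fill(3);
    assert_eq!(off_o.start, 1024);
    assert_eq!(mem.len(), 3536);
}

`load_blk` then becomes a pure range lookup into that one buffer, which is why the roll-cache machinery and `H2DLoader` double-buffering could be deleted.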