diff --git a/model-parameters/src/memory/inside_memory.rs b/model-parameters/src/memory/inside_memory.rs
index 83420d6c..bc98991d 100644
--- a/model-parameters/src/memory/inside_memory.rs
+++ b/model-parameters/src/memory/inside_memory.rs
@@ -2,7 +2,7 @@ use half::{bf16, f16};
 
 impl Memory<Vec<u8>> {
-    pub fn cast(src: &impl Llama2, new_dtype: DataType) -> Self {
+    pub fn cast(src: &dyn Llama2, new_dtype: DataType) -> Self {
         let from = src.data_type();
         let mut blob = Vec::with_capacity(src.size() * new_dtype.size() / from.size());
         let mut append = |src: &[u8]| {
diff --git a/transformer-cpu/src/cache.rs b/transformer-cpu/src/cache.rs
new file mode 100644
index 00000000..c90a5374
--- /dev/null
+++ b/transformer-cpu/src/cache.rs
@@ -0,0 +1,19 @@
+use model_parameters::Llama2;
+
+pub(super) struct LayerCache(Vec<u8>);
+
+impl LayerCache {
+    pub fn new(model: &dyn Llama2, batch: usize) -> Self {
+        let n = batch;
+        let dkv = model.num_key_value_heads();
+        let ds = model.max_position_embeddings();
+        let dh = model.hidden_size() / model.num_attention_heads();
+        Self(vec![0; 2 * n * dkv * ds * dh])
+    }
+
+    #[inline]
+    pub fn get(&mut self) -> (&mut [u8], &mut [u8]) {
+        let mid = self.0.len() / 2;
+        self.0.split_at_mut(mid)
+    }
+}
diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs
index b02cad0c..698dfd65 100644
--- a/transformer-cpu/src/lib.rs
+++ b/transformer-cpu/src/lib.rs
@@ -1,19 +1,44 @@
+mod cache;
+
+use cache::LayerCache;
 use model_parameters::{DataType, Llama2, Memory};
 
 pub struct Transformer {
     model: Box<dyn Llama2>,
+    cache: Vec<LayerCache>,
 }
 
-impl<T> From<T> for Transformer
-where
-    T: 'static + Llama2,
-{
-    fn from(value: T) -> Self {
-        let model: Box<dyn Llama2> = if value.data_type() == DataType::BF16 {
-            Box::new(Memory::cast(&value, DataType::F32))
-        } else {
-            Box::new(value)
+impl Transformer {
+    pub fn new(model: Box<dyn Llama2>, batch: usize) -> Self {
+        let model = match model.data_type() {
+            DataType::BF16 => Box::new(Memory::cast(&*model, DataType::F32)),
+            _ => model,
         };
-        Self { model }
+        let cache = (0..model.num_hidden_layers())
+            .map(|_| LayerCache::new(&*model, batch))
+            .collect();
+        Self { model, cache }
     }
 }
+
+#[test]
+fn test_build() {
+    use model_parameters::SafeTensorError;
+    use std::time::Instant;
+
+    let t0 = Instant::now();
+    let safetensors = Memory::load_safetensors("../../TinyLlama-1.1B-Chat-v1.0");
+    let t1 = Instant::now();
+    println!("mmap {:?}", t1 - t0);
+
+    let safetensors = match safetensors {
+        Ok(m) => m,
+        Err(SafeTensorError::Io(e)) if e.kind() == std::io::ErrorKind::NotFound => return,
+        Err(e) => panic!("{e:?}"),
+    };
+
+    let t0 = Instant::now();
+    let _transformer = Transformer::new(Box::new(safetensors), 1);
+    let t1 = Instant::now();
+    println!("build transformer {:?}", t1 - t0);
+}