feat(transformer-cpu): implement construction of the CPU transformer
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 16, 2024
1 parent e49606e commit 0158392
Showing 3 changed files with 55 additions and 11 deletions.
model-parameters/src/memory/inside_memory.rs (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 use half::{bf16, f16};
 
 impl Memory<Vec<u8>> {
-    pub fn cast(src: &impl Llama2, new_dtype: DataType) -> Self {
+    pub fn cast(src: &dyn Llama2, new_dtype: DataType) -> Self {
         let from = src.data_type();
         let mut blob = Vec::with_capacity(src.size() * new_dtype.size() / from.size());
         let mut append = |src: &[u8]| {
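A side note, not stated in the commit itself: taking &dyn Llama2 instead of &impl Llama2 lets cast accept an unsized trait object, which is presumably what allows the new Transformer::new in transformer-cpu/src/lib.rs below to call Memory::cast(&*model, DataType::F32) on a Box<dyn Llama2>. A minimal sketch of that call, assuming a loaded Memory value (e.g. the safetensors value from the test below) is in scope:

    // Sketch only: with `&dyn Llama2`, the boxed trait object can be passed directly.
    let model: Box<dyn Llama2> = Box::new(safetensors);
    let f32_model = Memory::cast(&*model, DataType::F32);
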
transformer-cpu/src/cache.rs (19 changes: 19 additions & 0 deletions)
@@ -0,0 +1,19 @@
+use model_parameters::Llama2;
+
+pub(super) struct LayerCache(Vec<u8>);
+
+impl LayerCache {
+    pub fn new(model: &dyn Llama2, batch: usize) -> Self {
+        let n = batch;
+        let dkv = model.num_key_value_heads();
+        let ds = model.max_position_embeddings();
+        let dh = model.hidden_size() / model.num_attention_heads();
+        Self(vec![0; 2 * n * dkv * ds * dh])
+    }
+
+    #[inline]
+    pub fn get(&mut self) -> (&mut [u8], &mut [u8]) {
+        let mid = self.0.len() / 2;
+        self.0.split_at_mut(mid)
+    }
+}
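For orientation (again not part of the commit): each LayerCache packs one layer's key cache and value cache into a single contiguous buffer, and get() splits it into the two halves. A minimal usage sketch from inside the crate, assuming a model: Box<dyn Llama2> is already in scope:

    // Sketch only: one layer's cache split into its K and V halves.
    let mut layer = LayerCache::new(&*model, 1);
    let (k, v) = layer.get();
    // Each half spans batch * num_kv_heads * max_positions * head_dim entries.
    assert_eq!(k.len(), v.len());
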
transformer-cpu/src/lib.rs (45 changes: 35 additions & 10 deletions)
@@ -1,19 +1,44 @@
+mod cache;
+
+use cache::LayerCache;
 use model_parameters::{DataType, Llama2, Memory};
 
 pub struct Transformer {
     model: Box<dyn Llama2>,
+    cache: Vec<LayerCache>,
 }
 
-impl<T> From<T> for Transformer
-where
-    T: 'static + Llama2,
-{
-    fn from(value: T) -> Self {
-        let model: Box<dyn Llama2> = if value.data_type() == DataType::BF16 {
-            Box::new(Memory::cast(&value, DataType::F32))
-        } else {
-            Box::new(value)
+impl Transformer {
+    pub fn new(model: Box<dyn Llama2>, batch: usize) -> Self {
+        let model = match model.data_type() {
+            DataType::BF16 => Box::new(Memory::cast(&*model, DataType::F32)),
+            _ => model,
         };
-        Self { model }
+        let cache = (0..model.num_hidden_layers())
+            .map(|_| LayerCache::new(&*model, batch))
+            .collect();
+        Self { model, cache }
     }
 }
+
+#[test]
+fn test_build() {
+    use model_parameters::SafeTensorError;
+    use std::time::Instant;
+
+    let t0 = Instant::now();
+    let safetensors = Memory::load_safetensors("../../TinyLlama-1.1B-Chat-v1.0");
+    let t1 = Instant::now();
+    println!("mmap {:?}", t1 - t0);
+
+    let safetensors = match safetensors {
+        Ok(m) => m,
+        Err(SafeTensorError::Io(e)) if e.kind() == std::io::ErrorKind::NotFound => return,
+        Err(e) => panic!("{e:?}"),
+    };
+
+    let t0 = Instant::now();
+    let _transformer = Transformer::new(Box::new(safetensors), 1);
+    let t1 = Instant::now();
+    println!("build transformer {:?}", t1 - t0);
+}
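As a rough sanity check on the sizes involved, here is a worked calculation under assumed TinyLlama-1.1B dimensions (hidden_size 2048, 32 attention heads, 4 key-value heads, 2048 max positions, 22 layers); these figures are assumptions, not something stated in the commit:

    fn main() {
        // Assumed TinyLlama-1.1B dimensions; not read from the commit itself.
        let (n, dkv, ds, dh) = (1usize, 4, 2048, 2048 / 32);
        let per_layer = 2 * n * dkv * ds * dh; // 1_048_576 entries for the K and V halves
        // ~22 MiB across 22 layers at one byte per entry, as allocated by LayerCache::new.
        println!("total cache: {} bytes for 22 layers", per_layer * 22);
    }
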
