refactor(transformer-cpu): clean up and optimize the CPU transformer
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 14, 2024
1 parent 6292545 commit 60fcc41
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions transformer-cpu/src/lib.rs
@@ -6,8 +6,8 @@ use kernel::{gather, mat_mul, rms_norm, rotary_embedding, softmax, swiglu};
use storage::{Cell, Unique};
use tensor::{reslice, slice, udim, DataType, Tensor};

-pub type Request<'a, Id> = transformer::Request<'a, Id, Cell>;
-pub type LayerCache = transformer::LayerCache<Cell>;
+pub type Request<'a, Id> = transformer::Request<'a, Id, Unique>;
+pub type LayerCache = transformer::LayerCache<Unique>;
pub use transformer::{save, Llama2, Memory};

pub struct Transformer(Box<dyn Llama2>);
@@ -20,7 +20,7 @@ impl Transformer {

#[inline]
pub fn new_cache(&self) -> Vec<LayerCache> {
-LayerCache::new_layers(&*self.0, |dt, shape| tensor(dt, shape, Cell::new))
+LayerCache::new_layers(&*self.0, |dt, shape| tensor(dt, shape, Unique::new))
}

#[inline]
@@ -124,17 +124,21 @@ impl Transformer {
let o = o.as_mut().slice(req_slice);
let mut o = unsafe { o.map_physical(|u| &mut ***u) };

-let (k_cache, v_cache) = r.cache[layer].get();
let mut q_att = Tensor::new(dt, &[nh, seq_len, dh], &mut q_buf[..]);
-let mut k_cat = k_cache.clone().slice(cat_slice);
-let mut v_cat = v_cache.clone().slice(cat_slice);
+let (k_cache, v_cache) = r.cache[layer].get();
+let k_cat = k_cache.as_mut().slice(cat_slice);
+let v_cat = v_cache.as_mut().slice(cat_slice);
+let mut k_cat = unsafe { k_cat.map_physical(|u| &mut **u) };
+let mut v_cat = unsafe { v_cat.map_physical(|u| &mut **u) };
q.access().reform_to(&mut q_att);
-k.access().reform_to(&mut k_cat.access_mut());
-v.access().reform_to(&mut v_cat.access_mut());
+k.access().reform_to(&mut k_cat);
+v.access().reform_to(&mut v_cat);

let q_att = q_att.reshape(&[nkvh, head_group * seq_len, dh]);
-let k_att = k_cache.clone().slice(att_slice).transpose(&[0, 2, 1]);
-let v_att = v_cache.clone().slice(att_slice);
+let k_att = k_cache.as_ref().slice(att_slice).transpose(&[0, 2, 1]);
+let v_att = v_cache.as_ref().slice(att_slice);
+let k_att = unsafe { k_att.map_physical(|u| &**u) };
+let v_att = unsafe { v_att.map_physical(|u| &**u) };
// println!("layer {layer} q attention:\n{}", q_att);
// println!("layer {layer} k attention:\n{}", k_att.access());
// println!("layer {layer} v attention:\n{}", v_att.access());
@@ -143,11 +147,11 @@ impl Transformer {
let shape_att1 = &[nkvh * head_group, seq_len, att_len];

let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]);
-mat_mul(&mut att, 0., &q_att, &k_att.access(), head_div);
+mat_mul(&mut att, 0., &q_att, &k_att, head_div);
let mut att = att.reshape(shape_att1);
softmax(&mut att);
let mut x2 = q_att;
-mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att.access(), 1.);
+mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att, 1.);

x2.reshape(&[nh, seq_len, dh]).reform_to(&mut o);
// println!("layer {layer} after attention:\n{}", o);
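What the hunks above change, in short: the CPU transformer's KV cache now uses the crate's Unique storage instead of Cell, so the attention code borrows the cache through as_ref()/as_mut() plus an unsafe map_physical to reach the raw buffer, instead of cloning the shared cell and going through access()/access_mut() on every use. As a rough, self-contained analogy (the Shared/Owned types below are stand-ins invented for illustration, not the repository's storage or tensor APIs), the two ownership styles differ like this:

use std::{cell::RefCell, rc::Rc};

// Old style: shared, interior-mutable storage (Cell-like). Every access
// clones the handle and borrows through the RefCell at runtime.
struct Shared(Rc<RefCell<Vec<u8>>>);

// New style: uniquely owned storage (Unique-like). Callers take plain
// &[u8] / &mut [u8] borrows that the compiler checks statically.
struct Owned(Vec<u8>);

impl Owned {
    fn as_ref(&self) -> &[u8] { &self.0 }
    fn as_mut(&mut self) -> &mut [u8] { &mut self.0 }
}

fn main() {
    // Shared path: an extra handle clone plus a runtime borrow per access.
    let shared = Shared(Rc::new(RefCell::new(vec![0u8; 4])));
    shared.0.clone().borrow_mut()[0] = 1;

    // Owned path: a direct mutable borrow, no clone, no runtime check.
    let mut owned = Owned(vec![0u8; 4]);
    owned.as_mut()[0] = 1;

    assert_eq!(shared.0.borrow()[0], owned.as_ref()[0]);
}

The unsafe map_physical calls in the new code appear to serve only to hand the kernels a plain slice view of the uniquely owned buffer; the aliasing guarantees themselves come from Unique replacing Cell.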

