From 60fcc411e7a2d162432a0c931879bd6fc456c0ff Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Thu, 14 Mar 2024 15:22:14 +0800
Subject: [PATCH] refactor(transformer-cpu): tidy up and optimize transformer cpu
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 transformer-cpu/src/lib.rs | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs
index c0f8aa69..8a7328d9 100644
--- a/transformer-cpu/src/lib.rs
+++ b/transformer-cpu/src/lib.rs
@@ -6,8 +6,8 @@
 use kernel::{gather, mat_mul, rms_norm, rotary_embedding, softmax, swiglu};
 use storage::{Cell, Unique};
 use tensor::{reslice, slice, udim, DataType, Tensor};
 
-pub type Request<'a, Id> = transformer::Request<'a, Id, Cell>;
-pub type LayerCache = transformer::LayerCache<Cell>;
+pub type Request<'a, Id> = transformer::Request<'a, Id, Unique>;
+pub type LayerCache = transformer::LayerCache<Unique>;
 pub use transformer::{save, Llama2, Memory};
 pub struct Transformer(Box<dyn Llama2>);
@@ -20,7 +20,7 @@ impl Transformer {
 
     #[inline]
     pub fn new_cache(&self) -> Vec<LayerCache> {
-        LayerCache::new_layers(&*self.0, |dt, shape| tensor(dt, shape, Cell::new))
+        LayerCache::new_layers(&*self.0, |dt, shape| tensor(dt, shape, Unique::new))
     }
 
     #[inline]
@@ -124,17 +124,21 @@
                 let o = o.as_mut().slice(req_slice);
                 let mut o = unsafe { o.map_physical(|u| &mut ***u) };
 
-                let (k_cache, v_cache) = r.cache[layer].get();
                 let mut q_att = Tensor::new(dt, &[nh, seq_len, dh], &mut q_buf[..]);
-                let mut k_cat = k_cache.clone().slice(cat_slice);
-                let mut v_cat = v_cache.clone().slice(cat_slice);
+                let (k_cache, v_cache) = r.cache[layer].get();
+                let k_cat = k_cache.as_mut().slice(cat_slice);
+                let v_cat = v_cache.as_mut().slice(cat_slice);
+                let mut k_cat = unsafe { k_cat.map_physical(|u| &mut **u) };
+                let mut v_cat = unsafe { v_cat.map_physical(|u| &mut **u) };
                 q.access().reform_to(&mut q_att);
-                k.access().reform_to(&mut k_cat.access_mut());
-                v.access().reform_to(&mut v_cat.access_mut());
+                k.access().reform_to(&mut k_cat);
+                v.access().reform_to(&mut v_cat);
 
                 let q_att = q_att.reshape(&[nkvh, head_group * seq_len, dh]);
-                let k_att = k_cache.clone().slice(att_slice).transpose(&[0, 2, 1]);
-                let v_att = v_cache.clone().slice(att_slice);
+                let k_att = k_cache.as_ref().slice(att_slice).transpose(&[0, 2, 1]);
+                let v_att = v_cache.as_ref().slice(att_slice);
+                let k_att = unsafe { k_att.map_physical(|u| &**u) };
+                let v_att = unsafe { v_att.map_physical(|u| &**u) };
                 // println!("layer {layer} q attention:\n{}", q_att);
                 // println!("layer {layer} k attention:\n{}", k_att.access());
                 // println!("layer {layer} v attention:\n{}", v_att.access());
@@ -143,11 +147,11 @@
                 let shape_att1 = &[nkvh * head_group, seq_len, att_len];
 
                 let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]);
-                mat_mul(&mut att, 0., &q_att, &k_att.access(), head_div);
+                mat_mul(&mut att, 0., &q_att, &k_att, head_div);
                 let mut att = att.reshape(shape_att1);
                 softmax(&mut att);
                 let mut x2 = q_att;
-                mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att.access(), 1.);
+                mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att, 1.);
                 x2.reshape(&[nh, seq_len, dh]).reform_to(&mut o);
                 // println!("layer {layer} after attention:\n{}", o);
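
Note (not part of the patch): the sketch below only illustrates the ownership pattern
this diff moves to; it is not the real tensor/storage API. ToyTensor, SharedBlob,
UniqueBlob and their methods are invented stand-ins, and the actual signatures of
Cell, Unique, as_ref/as_mut and map_physical in the repository may differ. The idea:
the KV cache used to sit behind a shared, runtime-checked handle that every use had to
clone() and access()/access_mut(); with Unique storage the attention step borrows the
cache directly and only re-views the backing bytes with map_physical, so exclusivity
is checked at compile time instead of at run time.

// Toy Rust sketch. All names here are hypothetical stand-ins for the patch's
// Cell/Unique storage and Tensor::map_physical; only the ownership pattern matches.
use std::{cell::RefCell, rc::Rc};

/// A minimal tensor that only tracks its backing ("physical") storage.
struct ToyTensor<P> {
    shape: Vec<usize>,
    physical: P,
}

impl<P> ToyTensor<P> {
    fn new(shape: &[usize], physical: P) -> Self {
        Self { shape: shape.to_vec(), physical }
    }

    /// Re-wrap the same logical tensor around another view of its storage,
    /// mirroring the map_physical calls in the diff.
    fn map_physical<Q>(self, f: impl FnOnce(P) -> Q) -> ToyTensor<Q> {
        ToyTensor { shape: self.shape, physical: f(self.physical) }
    }

    /// Shared view of the storage (the diff's k_cache.as_ref() path).
    fn as_ref(&self) -> ToyTensor<&P> {
        ToyTensor { shape: self.shape.clone(), physical: &self.physical }
    }

    /// Exclusive view of the storage (the diff's k_cache.as_mut() path).
    fn as_mut(&mut self) -> ToyTensor<&mut P> {
        ToyTensor { shape: self.shape.clone(), physical: &mut self.physical }
    }
}

/// Before: storage shared behind a runtime-checked cell, cloned per use.
type SharedBlob = Rc<RefCell<Vec<u8>>>;
/// After: storage owned uniquely by the cache tensor.
type UniqueBlob = Vec<u8>;

fn write_before(cache: &ToyTensor<SharedBlob>) {
    // Clone the handle, then borrow mutably at runtime (may panic if already borrowed).
    let handle = cache.physical.clone();
    let mut blob = handle.borrow_mut();
    blob.fill(0);
}

fn write_after(cache: &mut ToyTensor<UniqueBlob>) {
    // Borrow the unique storage directly; exclusivity is checked by the compiler.
    let view = cache.as_mut().map_physical(|v| v.as_mut_slice());
    view.physical.fill(0);
}

fn read_after(cache: &ToyTensor<UniqueBlob>) -> u8 {
    // Read-only view, as the attention inputs take as_ref() + map_physical.
    let view = cache.as_ref().map_physical(|v| v.as_slice());
    view.physical[0]
}

fn main() {
    let shared = ToyTensor::new(&[2, 4], Rc::new(RefCell::new(vec![1u8; 8])));
    write_before(&shared);
    println!("shared cache after write: {:?}", shared.physical.borrow());

    let mut unique = ToyTensor::new(&[2, 4], vec![1u8; 8]);
    write_after(&mut unique);
    println!("unique cache after write: first byte = {}", read_after(&unique));
}

The real map_physical calls in the diff sit in unsafe blocks, presumably because the
tensor cannot prove that the new physical view still matches its layout, so the caller
asserts it. The toy above stays in safe code because its view is just a reborrow of
the same buffer.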