Commit a6cdb82
refactor(transformer-cpu): clean up and optimize transformer cpu
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 14, 2024
1 parent 60fcc41 commit a6cdb82
Showing 6 changed files with 83 additions and 102 deletions.
1 change: 1 addition & 0 deletions tensor/src/lib.rs
@@ -22,6 +22,7 @@ pub use nalgebra::DVector;
pub use pattern::{expand_indices, idx_strides, Affine, Shape};
pub use physical_cell::PhysicalCell;
pub use slice::SliceDim;
pub use split::Splitable;
pub use tensor::Tensor;

use std::mem::{align_of, size_of, size_of_val};
15 changes: 13 additions & 2 deletions tensor/src/split.rs
@@ -1,14 +1,25 @@
use crate::{idim, pattern::Pattern, udim, Affine, Shape, Tensor};

impl<Physical: Clone> Tensor<Physical> {
pub trait Splitable {
fn split(&self) -> Self;
}

impl<T: Clone> Splitable for T {
#[inline]
fn split(&self) -> Self {
self.clone()
}
}

impl<Physical: Splitable> Tensor<Physical> {
pub fn split(&self, axis: usize, segments: &[udim]) -> Vec<Self> {
build(axis, segments, &self.shape)
.into_iter()
.map(|(shape, affine)| Self {
data_type: self.data_type,
shape,
pattern: Pattern(affine * &self.pattern.0),
physical: self.physical.clone(),
physical: self.physical.split(),
})
.collect()
}
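The new `Splitable` trait decouples `Tensor::split` from `Clone`: splitting a tensor now means "give me another handle to the same physical buffer", which reference-counted storage can satisfy with a cheap refcount bump, while plain buffers fall back to the blanket `Clone`-based impl. A minimal, self-contained sketch of the idea (the `Shared` type here is hypothetical; the commit's real implementor is `transformer-cpu`'s `Storage`):

use std::rc::Rc;

// Same contract as the trait added in this commit.
pub trait Splitable {
    fn split(&self) -> Self;
}

// Hypothetical Rc-backed storage: `split` duplicates the handle, not the
// bytes, so every segment of a split tensor aliases one allocation.
struct Shared(Rc<Vec<u8>>);

impl Splitable for Shared {
    fn split(&self) -> Self {
        Shared(Rc::clone(&self.0))
    }
}

fn main() {
    let a = Shared(Rc::new(vec![0u8; 16]));
    let b = a.split();
    assert!(Rc::ptr_eq(&a.0, &b.0)); // two handles, one allocation
}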
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/rotary_embedding.rs
@@ -5,7 +5,7 @@ use tensor::{expand_indices, idx_strides, udim, Tensor};

/// - t: [num_token, num_head, head_dim]
/// - pos: [num_token]
pub fn rotary_embedding<T, U>(mut t: Tensor<T>, pos: &Tensor<U>, theta: f32)
pub fn rotary_embedding<T, U>(t: &mut Tensor<T>, pos: &Tensor<U>, theta: f32)
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/swiglu.rs
@@ -2,7 +2,7 @@
use std::ops::{Deref, DerefMut};
use tensor::{idim, DVector, Tensor};

pub fn swiglu<T, U>(mut gate: Tensor<T>, up: &Tensor<U>)
pub fn swiglu<T, U>(gate: &mut Tensor<T>, up: &Tensor<U>)
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
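Both kernel signature changes above follow the same pattern: `rotary_embedding` and `swiglu` used to consume their output tensor by value (`mut t: Tensor<T>`), which forced call sites to pass freshly created access guards; they now borrow it as `&mut Tensor<T>`, so the caller keeps ownership across successive kernel calls. A runnable sketch of the convention, with a hypothetical stand-in kernel:

use std::ops::DerefMut;

// Stand-in for a kernel with the new signature: it mutably borrows the
// output and only requires "something that derefs to bytes", like the
// `DerefMut<Target = [u8]>` bound in the diff.
fn scale_in_place<T>(out: &mut T, factor: u8)
where
    T: DerefMut<Target = [u8]>,
{
    for b in out.iter_mut() {
        *b = b.wrapping_mul(factor);
    }
}

fn main() {
    let mut buf: Vec<u8> = vec![1, 2, 3];
    scale_in_place(&mut buf, 3); // `buf` stays usable afterwards
    assert_eq!(buf, [3, 6, 9]);
}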
72 changes: 39 additions & 33 deletions transformer-cpu/src/lib.rs
@@ -3,11 +3,11 @@ mod storage;

use common::utok;
use kernel::{gather, mat_mul, rms_norm, rotary_embedding, softmax, swiglu};
use storage::{Cell, Unique};
use storage::Storage;
use tensor::{reslice, slice, udim, DataType, Tensor};

pub type Request<'a, Id> = transformer::Request<'a, Id, Unique>;
pub type LayerCache = transformer::LayerCache<Unique>;
pub type Request<'a, Id> = transformer::Request<'a, Id, Storage>;
pub type LayerCache = transformer::LayerCache<Storage>;
pub use transformer::{save, Llama2, Memory};

pub struct Transformer(Box<dyn Llama2>);
@@ -20,7 +20,7 @@ impl Transformer {

#[inline]
pub fn new_cache(&self) -> Vec<LayerCache> {
LayerCache::new_layers(&*self.0, |dt, shape| tensor(dt, shape, Unique::new))
LayerCache::new_layers(&*self.0, tensor)
}

#[inline]
@@ -73,41 +73,47 @@ impl Transformer {
}
let pos = Tensor::new(DataType::U32, &[nt], reslice(&pos));

let mut x0 = tensor(dt, &[nt, d], Unique::new);
let mut x1 = tensor(dt, &[nt, d], Unique::new);
let mut qkv = tensor(dt, &[nt, d + dkv + dkv], Cell::new);
let mut q_buf = Unique::new((nh * max_seq_len * dh) as usize * dt.size());
let mut att_buf =
Unique::new((nkvh * head_group * max_seq_len * max_att_len) as usize * dt.size());
let mut gate_up = tensor(dt, &[nt, di + di], Cell::new);
let mut x0 = tensor(dt, &[nt, d]);
let mut x1 = tensor(dt, &[nt, d]);
let mut qkv = tensor(dt, &[nt, d + dkv + dkv]);
let mut q_buf = Storage::new((nh * max_seq_len * dh) as usize * dt.size());
let mut att_buf = Storage::new((nh * max_seq_len * max_att_len) as usize * dt.size());
let mut gate_up = tensor(dt, &[nt, di + di]);

let tokens = requests.iter().map(|r| r.tokens).flatten().copied();
gather(&mut x0, &self.0.embed_tokens(), tokens);
// println!("gather:\n{}", x0.access());
// println!("gather:\n{x0}");

for layer in 0..self.0.num_hidden_layers() {
let input_layernorm = self.0.input_layernorm(layer);
rms_norm(&mut x1, &x0, &input_layernorm, epsilon);
// println!("layer {layer} input norm:\n{}", x1.access());
// println!("layer {layer} input norm:\n{x1}");

let w_qkv = self.0.w_qkv(layer).transpose(&[1, 0]);
mat_mul(&mut qkv.access_mut(), 0., &x1, &w_qkv, 1.);
mat_mul(&mut qkv, 0., &x1, &w_qkv, 1.);
let mut qkv = qkv.split(1, &[d as _, dkv as _, dkv as _]);
let v = qkv.pop().unwrap().reshape(&[nt, nkvh, dh]);
let mut k = qkv.pop().unwrap().reshape(&[nt, nkvh, dh]);
let mut q = qkv.pop().unwrap().reshape(&[nt, nh, dh]);
// println!("layer {layer} q:\n{}", q.access());
// println!("layer {layer} k:\n{}", k.access());
// println!("layer {layer} v:\n{}", v.access());
rotary_embedding(q.access_mut(), &pos, theta);
rotary_embedding(k.access_mut(), &pos, theta);
// println!("layer {layer} rot q:\n{}", q.access());
// println!("layer {layer} rot k:\n{}", k.access());
let q = q.transpose(&[1, 0, 2]);
let k = k.transpose(&[1, 0, 2]);
let v = v.transpose(&[1, 0, 2]);
// println!("layer {layer} q:\n{q}");
// println!("layer {layer} k:\n{k}");
// println!("layer {layer} v:\n{v}");

rotary_embedding(&mut q, &pos, theta);
rotary_embedding(&mut k, &pos, theta);
// println!("layer {layer} rot q:\n{q}");
// println!("layer {layer} rot k:\n{k}");

let q = q.as_ref().transpose(&[1, 0, 2]);
let k = k.as_ref().transpose(&[1, 0, 2]);
let v = v.as_ref().transpose(&[1, 0, 2]);
let mut o = x1.as_mut().reshape(&[nt, nh, dh]).transpose(&[1, 0, 2]);

let q = unsafe { q.map_physical(|u| &**u) };
let k = unsafe { k.map_physical(|u| &**u) };
let v = unsafe { v.map_physical(|u| &**u) };

let mut req = 0;
let mut o = x1.as_mut().reshape(&[nt, nh, dh]).transpose(&[1, 0, 2]);
for r in requests.iter_mut() {
let pos = r.pos;
let seq_len = r.seq_len();
@@ -130,9 +136,9 @@
let v_cat = v_cache.as_mut().slice(cat_slice);
let mut k_cat = unsafe { k_cat.map_physical(|u| &mut **u) };
let mut v_cat = unsafe { v_cat.map_physical(|u| &mut **u) };
q.access().reform_to(&mut q_att);
k.access().reform_to(&mut k_cat);
v.access().reform_to(&mut v_cat);
q.reform_to(&mut q_att);
k.reform_to(&mut k_cat);
v.reform_to(&mut v_cat);

let q_att = q_att.reshape(&[nkvh, head_group * seq_len, dh]);
let k_att = k_cache.as_ref().slice(att_slice).transpose(&[0, 2, 1]);
@@ -166,18 +172,18 @@
// println!("layer {layer} post norm:\n{}", x1.access());

let w_gate_up = self.0.mlp_gate_up(layer).transpose(&[1, 0]);
mat_mul(&mut gate_up.access_mut(), 0., &x1, &w_gate_up, 1.);
mat_mul(&mut gate_up, 0., &x1, &w_gate_up, 1.);
let mut gate_up = gate_up.split(1, &[di as _, di as _]);
let up = gate_up.pop().unwrap();
let mut gate = gate_up.pop().unwrap();
// println!("layer {layer} gate:\n{}", gate.access());
// println!("layer {layer} up:\n{}", up.access());

swiglu(gate.access_mut(), unsafe { &up.access_unchecked() });
swiglu(&mut gate, &up);
// println!("layer {layer} swiglu:\n{}", gate.access());

let mlp_down = self.0.mlp_down(layer).transpose(&[1, 0]);
mat_mul(&mut x0, 1., &gate.access(), &mlp_down, 1.);
mat_mul(&mut x0, 1., &gate, &mlp_down, 1.);
// println!("layer {layer} down:\n{}", x0.access());
}

@@ -225,9 +231,9 @@ impl Transformer {
}

#[inline]
fn tensor<T>(dt: DataType, shape: &[udim], f: impl FnOnce(usize) -> T) -> Tensor<T> {
fn tensor(dt: DataType, shape: &[udim]) -> Tensor<Storage> {
let size = shape.iter().product::<udim>() as usize * dt.size();
Tensor::new(dt, shape, f(size))
Tensor::new(dt, shape, Storage::new(size))
}

#[test]
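One detail in the buffer allocations above: `att_buf` shrinks textually but not numerically. Assuming `nh = nkvh * head_group`, which is how grouped-query attention is usually parameterized (and matches the `[nkvh, head_group * seq_len, dh]` reshape in this hunk), the old and new size expressions are equal; the commit just writes the product directly. A quick check with hypothetical sizes:

fn main() {
    // Hypothetical dimensions: 32 query heads, 8 KV heads, so each KV
    // head serves a group of 4 query heads.
    let (nh, nkvh) = (32u32, 8u32);
    let head_group = nh / nkvh;
    let (max_seq_len, max_att_len) = (2048u32, 2048u32);
    let dt_size = 2usize; // e.g. f16

    let old = (nkvh * head_group * max_seq_len * max_att_len) as usize * dt_size;
    let new = (nh * max_seq_len * max_att_len) as usize * dt_size;
    assert_eq!(old, new); // same workspace, simpler expression
}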
93 changes: 28 additions & 65 deletions transformer-cpu/src/storage.rs
@@ -1,95 +1,58 @@
use std::{
cell::{Ref, RefCell, RefMut},
alloc::{alloc, Layout},
mem::align_of,
ops::{Deref, DerefMut},
rc::Rc,
};
use tensor::PhysicalCell;
use tensor::Splitable;

#[derive(Debug)]
#[repr(transparent)]
pub struct Unique(Vec<u8>);
pub struct Storage(Rc<Internal>);

impl Deref for Unique {
type Target = [u8];

#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl DerefMut for Unique {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl Unique {
impl Storage {
#[inline]
pub fn new(size: usize) -> Self {
Self(vec![0u8; size])
const ALIGN: usize = align_of::<usize>();
let layout = Layout::from_size_align(size, ALIGN).unwrap();
Self(Rc::new(Internal {
ptr: unsafe { alloc(layout) },
len: size,
}))
}
}

#[derive(Clone, Debug)]
pub struct Cell(Rc<RefCell<Vec<u8>>>);
pub struct VecRef<'a>(Ref<'a, Vec<u8>>);
pub struct VecRefMut<'a>(RefMut<'a, Vec<u8>>);

impl PhysicalCell for Cell {
type Raw = [u8];
type Access<'a> = VecRef<'a>;
type AccessMut<'a> = VecRefMut<'a>;

#[inline]
unsafe fn get_unchecked(&self) -> &Self::Raw {
&*self.0.as_ptr()
}

#[inline]
unsafe fn get_unchecked_mut(&mut self) -> &mut Self::Raw {
&mut *self.0.as_ptr()
}

impl Splitable for Storage {
#[inline]
fn access(&self) -> Self::Access<'_> {
VecRef(self.0.borrow())
}

#[inline]
fn access_mut(&mut self) -> Self::AccessMut<'_> {
VecRefMut(self.0.borrow_mut())
fn split(&self) -> Self {
Self(self.0.clone())
}
}

impl Deref for VecRef<'_> {
impl Deref for Storage {
type Target = [u8];

#[inline]
fn deref(&self) -> &Self::Target {
&self.0
unsafe { std::slice::from_raw_parts(self.0.ptr, self.0.len) }
}
}

impl Deref for VecRefMut<'_> {
type Target = [u8];

impl DerefMut for Storage {
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { std::slice::from_raw_parts_mut(self.0.ptr, self.0.len) }
}
}

impl DerefMut for VecRefMut<'_> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
struct Internal {
ptr: *mut u8,
len: usize,
}

impl Cell {
pub fn new(size: usize) -> Self {
Self(Rc::new(RefCell::new(vec![0; size])))
impl Drop for Internal {
#[inline]
fn drop(&mut self) {
const ALIGN: usize = align_of::<usize>();
let layout = Layout::from_size_align(self.len, ALIGN).unwrap();
unsafe { std::alloc::dealloc(self.ptr, layout) }
}
}
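The storage rewrite collapses the two previous types into one. `Unique` (an owned `Vec<u8>`) and `Cell` (`Rc<RefCell<Vec<u8>>>` behind the `PhysicalCell` guards `VecRef`/`VecRefMut`) are replaced by a single `Storage`: one raw allocation behind `Rc<Internal>`, deallocated exactly once when the last handle drops, with `Splitable::split` cloning the `Rc`. Runtime borrow tracking via `RefCell` disappears; aliasing becomes the call sites' responsibility (hence the `unsafe map_physical` reborrows in `lib.rs`). A simplified, safe sketch of the ownership semantics, using `Rc<[u8]>` as a stand-in for `Rc<Internal>`:

use std::rc::Rc;

// Stand-in for the new Storage: one allocation, many handles.
struct Storage(Rc<[u8]>);

impl Storage {
    fn new(size: usize) -> Self {
        Storage(vec![0u8; size].into())
    }
    // Mirrors `Splitable::split`: duplicate the handle, share the bytes.
    fn split(&self) -> Self {
        Storage(Rc::clone(&self.0))
    }
}

fn main() {
    let a = Storage::new(64);
    let b = a.split();
    let c = b.split();
    assert!(Rc::ptr_eq(&a.0, &c.0)); // all three alias one buffer
    drop(a);
    drop(b);
    assert_eq!(c.0.len(), 64); // freed only when the last handle drops
}

Unlike this zero-initialized sketch, the real `Storage::new` uses `std::alloc::alloc`, so the buffer starts uninitialized and callers are expected to write before reading.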
