Skip to content

Commit

Permalink
refactor(transformer-cpu): 基于张量实现 kernel
Browse files: browse the repository at this point in the history.
Signed-off-by: YdrMaster <ydrml@hotmail.com>
YdrMaster committed Feb 20, 2024
1 parent ff9360b commit f4ce744
Showing 3 changed files with 35 additions and 12 deletions.
4 changes: 3 additions & 1 deletion tensor/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@ pub struct Tensor<Physical> {
physical: Physical,
}

impl<Physical: Clone> Tensor<Physical> {
impl<Physical> Tensor<Physical> {
pub fn new(data_type: DataType, shape: &[usize], physical: Physical) -> Self {
let shape = Shape::from_iter(shape.iter().map(|&d| d as udim));
Self {
@@ -62,7 +62,9 @@ impl<Physical: Clone> Tensor<Physical> {
physical,
}
}
}

impl<Physical: Clone> Tensor<Physical> {
pub fn reshape(&self, shape: Shape) -> Self {
debug_assert!(self.is_contiguous());
debug_assert_eq!(
30 changes: 26 additions & 4 deletions transformer-cpu/src/kernel.rs
Original file line number Diff line number Diff line change
@@ -4,9 +4,18 @@ use std::{
iter::zip,
ops::{Mul, MulAssign},
};
use tensor::DataType;
use tensor::{DataType, Tensor};

pub(super) fn gather(x: &mut [u8], table: &[u8], tokens: &[utok]) {
pub(super) fn gather<T, U>(x: &mut Tensor<T>, table: &Tensor<U>, tokens: &[utok])
where
T: AsMut<[u8]>,
U: AsRef<[u8]>,
{
debug_assert_eq!(x.data_type(), table.data_type());
debug_assert_eq!(x.shape().last(), table.shape().last());

let x = x.as_mut_slice();
let table = table.as_slice();
debug_assert_eq!(x.len() % tokens.len(), 0);

let d = x.len() / tokens.len();
@@ -15,8 +24,21 @@ pub(super) fn gather(x: &mut [u8], table: &[u8], tokens: &[utok]) {
}
}

pub(super) fn rms_norm(o: &mut [u8], x: &[u8], w: &[u8], theta: f32, dt: DataType) {
debug_assert_eq!(o.len(), x.len());
pub(super) fn rms_norm<T, U, V>(o: &mut Tensor<T>, x: &Tensor<U>, w: &Tensor<V>, theta: f32)
where
T: AsMut<[u8]>,
U: AsRef<[u8]>,
V: AsRef<[u8]>,
{
let dt = o.data_type();
debug_assert_eq!(x.data_type(), dt);
debug_assert_eq!(w.data_type(), dt);
debug_assert_eq!(o.shape(), x.shape());
debug_assert_eq!(&[*o.shape().last().unwrap()], w.shape());

let o = o.as_mut_slice();
let x = x.as_slice();
let w = w.as_slice();

fn op<T: Copy + Mul<Output = T> + MulAssign>(
o: &mut [u8],
13 changes: 6 additions & 7 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ use cache::LayerCache;
use common::{upos, utok};
use kernel::{gather, rms_norm};
use model_parameters::{Llama2, Memory};
use tensor::DataType;
use tensor::{DataType, Tensor};

pub extern crate model_parameters;

@@ -37,19 +37,18 @@ impl Transformer {
let d = self.model.hidden_size();
let dt = self.model.data_type();

let mut a = vec![0u8; seq_len * d * dt.size()];
gather(&mut a, self.model.embed_tokens().as_slice(), tokens);
let mut a = Tensor::new(dt, &[seq_len, d], vec![0u8; seq_len * d * dt.size()]);
gather(&mut a, &self.model.embed_tokens(), tokens);

let mut b = vec![0u8; seq_len * d * dt.size()];
let mut b = Tensor::new(dt, &[seq_len, d], vec![0u8; seq_len * d * dt.size()]);
for l in 0..self.model.num_hidden_layers() {
{
// b <- rms-norm(a)
let o = &mut b;
let x = &a;
let w = self.model.input_layernorm(l);
let w = w.as_slice();
let w = &self.model.input_layernorm(l);
let theta = self.model.rope_theta();
rms_norm(o, x, w, theta, dt);
rms_norm(o, x, w, theta);
}
}

0 comments on commit f4ce744

Please sign in to comment.