From f4ce744ca68f37c75a81a18c2aa22d471732bfae Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Tue, 20 Feb 2024 17:27:58 +0800
Subject: [PATCH] refactor(transformer-cpu): implement kernels on top of
 tensors
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 tensor/src/tensor.rs          |  4 +++-
 transformer-cpu/src/kernel.rs | 30 ++++++++++++++++++++++++++----
 transformer-cpu/src/lib.rs    | 13 ++++++-------
 3 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/tensor/src/tensor.rs b/tensor/src/tensor.rs
index 5d95f83e..2533b425 100644
--- a/tensor/src/tensor.rs
+++ b/tensor/src/tensor.rs
@@ -10,7 +10,7 @@ pub struct Tensor<Physical> {
     physical: Physical,
 }
 
-impl<Physical: Clone> Tensor<Physical> {
+impl<Physical> Tensor<Physical> {
     pub fn new(data_type: DataType, shape: &[usize], physical: Physical) -> Self {
         let shape = Shape::from_iter(shape.iter().map(|&d| d as udim));
         Self {
@@ -62,7 +62,9 @@
             physical,
         }
     }
+}
 
+impl<Physical: Clone> Tensor<Physical> {
     pub fn reshape(&self, shape: Shape) -> Self {
         debug_assert!(self.is_contiguous());
         debug_assert_eq!(
diff --git a/transformer-cpu/src/kernel.rs b/transformer-cpu/src/kernel.rs
index ac2d9b1c..1ed0b62d 100644
--- a/transformer-cpu/src/kernel.rs
+++ b/transformer-cpu/src/kernel.rs
@@ -4,9 +4,18 @@ use std::{
     iter::zip,
     ops::{Mul, MulAssign},
 };
-use tensor::DataType;
+use tensor::{DataType, Tensor};
 
-pub(super) fn gather(x: &mut [u8], table: &[u8], tokens: &[utok]) {
+pub(super) fn gather<T, U>(x: &mut Tensor<T>, table: &Tensor<U>, tokens: &[utok])
+where
+    T: AsMut<[u8]>,
+    U: AsRef<[u8]>,
+{
+    debug_assert_eq!(x.data_type(), table.data_type());
+    debug_assert_eq!(x.shape().last(), table.shape().last());
+
+    let x = x.as_mut_slice();
+    let table = table.as_slice();
     debug_assert_eq!(x.len() % tokens.len(), 0);
 
     let d = x.len() / tokens.len();
@@ -15,8 +24,21 @@
     }
 }
 
-pub(super) fn rms_norm(o: &mut [u8], x: &[u8], w: &[u8], theta: f32, dt: DataType) {
-    debug_assert_eq!(o.len(), x.len());
+pub(super) fn rms_norm<T, U, V>(o: &mut Tensor<T>, x: &Tensor<U>, w: &Tensor<V>, theta: f32)
+where
+    T: AsMut<[u8]>,
+    U: AsRef<[u8]>,
+    V: AsRef<[u8]>,
+{
+    let dt = o.data_type();
+    debug_assert_eq!(x.data_type(), dt);
+    debug_assert_eq!(w.data_type(), dt);
+    debug_assert_eq!(o.shape(), x.shape());
+    debug_assert_eq!(&[*o.shape().last().unwrap()], w.shape());
+
+    let o = o.as_mut_slice();
+    let x = x.as_slice();
+    let w = w.as_slice();
 
     fn op<T: Copy + Mul<Output = T> + MulAssign>(
         o: &mut [u8],
diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs
index 08f50791..af4a4223 100644
--- a/transformer-cpu/src/lib.rs
+++ b/transformer-cpu/src/lib.rs
@@ -5,7 +5,7 @@ use cache::LayerCache;
 use common::{upos, utok};
 use kernel::{gather, rms_norm};
 use model_parameters::{Llama2, Memory};
-use tensor::DataType;
+use tensor::{DataType, Tensor};
 
 pub extern crate model_parameters;
 
@@ -37,19 +37,18 @@ impl Transformer {
         let d = self.model.hidden_size();
         let dt = self.model.data_type();
 
-        let mut a = vec![0u8; seq_len * d * dt.size()];
-        gather(&mut a, self.model.embed_tokens().as_slice(), tokens);
+        let mut a = Tensor::new(dt, &[seq_len, d], vec![0u8; seq_len * d * dt.size()]);
+        gather(&mut a, &self.model.embed_tokens(), tokens);
 
-        let mut b = vec![0u8; seq_len * d * dt.size()];
+        let mut b = Tensor::new(dt, &[seq_len, d], vec![0u8; seq_len * d * dt.size()]);
         for l in 0..self.model.num_hidden_layers() {
             {
                 // b <- rms-norm(a)
                 let o = &mut b;
                 let x = &a;
-                let w = self.model.input_layernorm(l);
-                let w = w.as_slice();
+                let w = &self.model.input_layernorm(l);
                 let theta = self.model.rope_theta();
-                rms_norm(o, x, w, theta, dt);
+                rms_norm(o, x, w, theta);
             }
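
Editor's sketch (not part of the patch): the pattern this commit adopts, kernels that are generic over a tensor's backing storage through AsRef<[u8]>/AsMut<[u8]> bounds, can be shown self-contained. Everything below is a simplified stand-in rather than the repo's tensor crate: this Tensor drops DataType and strides, elements are one byte wide, and u32 stands in for utok.

// Self-contained sketch of the storage-generic kernel pattern.
// `Tensor` here is a hypothetical stand-in, not `tensor::Tensor`.
struct Tensor<Physical> {
    shape: Vec<usize>,
    physical: Physical,
}

impl<Physical> Tensor<Physical> {
    fn new(shape: &[usize], physical: Physical) -> Self {
        Self { shape: shape.to_vec(), physical }
    }
    fn shape(&self) -> &[usize] {
        &self.shape
    }
    // Byte views exist only when the storage can provide them, mirroring
    // the `T: AsMut<[u8]>` / `U: AsRef<[u8]>` bounds in the patch.
    fn as_slice(&self) -> &[u8]
    where
        Physical: AsRef<[u8]>,
    {
        self.physical.as_ref()
    }
    fn as_mut_slice(&mut self) -> &mut [u8]
    where
        Physical: AsMut<[u8]>,
    {
        self.physical.as_mut()
    }
}

// Copy row `tokens[i]` of `table` into row `i` of `x`, byte for byte,
// after checking the rows have the same width.
fn gather<T, U>(x: &mut Tensor<T>, table: &Tensor<U>, tokens: &[u32])
where
    T: AsMut<[u8]>,
    U: AsRef<[u8]>,
{
    debug_assert_eq!(x.shape().last(), table.shape().last());
    let d = *x.shape().last().unwrap(); // row width in bytes (1-byte elements)
    let x = x.as_mut_slice();
    let table = table.as_slice();
    for (i, &t) in tokens.iter().enumerate() {
        x[i * d..][..d].copy_from_slice(&table[t as usize * d..][..d]);
    }
}

fn main() {
    // Owned storage (Vec<u8>) satisfies both bounds; a read-only memory
    // map satisfying AsRef<[u8]> would work as `table` unchanged.
    let table = Tensor::new(&[4, 2], vec![0, 0, 1, 1, 2, 2, 3, 3]);
    let mut x = Tensor::new(&[3, 2], vec![0u8; 6]);
    gather(&mut x, &table, &[2, 0, 3]);
    assert_eq!(x.as_slice(), &[2, 2, 0, 0, 3, 3][..]);
}

The payoff is visible in the lib.rs hunk above: callers stop flattening everything to raw byte vectors, the shape and data-type checks move from the call site into the kernels, and the same gather/rms_norm accept owned, borrowed, or memory-mapped weights.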