refactor(transformer-cpu): clean up kernels, simplify call sites
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 4, 2024
1 parent 8f50ee1 commit 66ec32e
Showing 7 changed files with 39 additions and 35 deletions.
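
The change running through all seven files: each CPU kernel now takes its output tensor by value (mut Tensor<T>) instead of through &mut Tensor<T>, so the call sites in lib.rs can hand over the temporary view returned by access_mut() directly, without wrapping it in &mut. A minimal sketch of the pattern, using a hypothetical stand-in for the crate's Tensor type rather than its real API:

    use std::ops::DerefMut;

    // Hypothetical stand-in for the crate's view-like tensor type.
    struct Tensor<T>(T);

    impl<T: DerefMut<Target = [u8]>> Tensor<T> {
        fn as_mut_slice(&mut self) -> &mut [u8] {
            &mut self.0
        }
    }

    // Old style: the kernel borrows the output tensor mutably.
    fn fill_old<T: DerefMut<Target = [u8]>>(x: &mut Tensor<T>, v: u8) {
        x.as_mut_slice().fill(v);
    }

    // New style: the kernel takes the tensor by value and marks it mut locally.
    fn fill_new<T: DerefMut<Target = [u8]>>(mut x: Tensor<T>, v: u8) {
        x.as_mut_slice().fill(v);
    }

    fn main() {
        let mut buf = vec![0u8; 4];
        fill_old(&mut Tensor(&mut buf[..]), 1); // caller must wrap the temporary in &mut
        fill_new(Tensor(&mut buf[..]), 2);      // caller just hands the temporary over
        assert_eq!(buf, [2, 2, 2, 2]);
    }

Assuming the tensors passed in are cheap, view-like handles over a byte buffer, the by-value signature costs little and reads more cleanly at each call site.
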
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/fused_softmax.rs
@@ -4,7 +4,7 @@ use std::ops::DerefMut;
use tensor::{expand_indices, idx_strides, Tensor};

/// - x: [N0, N1, ... , N_, seq_len, att_len]
-pub fn softmax<T>(x: &mut Tensor<T>)
+pub fn softmax<T>(mut x: Tensor<T>)
where
T: DerefMut<Target = [u8]>,
{
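
For reference, a bare numerically stable row-wise softmax over f32 data. This is only a sketch under my own assumptions: the kernel works on a tensor's raw bytes, and with a trailing [seq_len, att_len] shape the fused kernel presumably also handles masking over the attention length, which is not shown here.

    // Numerically stable softmax over each row of a (rows x cols) f32 buffer.
    fn softmax_rows(data: &mut [f32], cols: usize) {
        for row in data.chunks_mut(cols) {
            let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0;
            for v in row.iter_mut() {
                *v = (*v - max).exp();
                sum += *v;
            }
            for v in row.iter_mut() {
                *v /= sum;
            }
        }
    }

    fn main() {
        let mut x = vec![1.0_f32, 2.0, 3.0, 0.0, 0.0, 0.0];
        softmax_rows(&mut x, 3);
        assert!((x[..3].iter().sum::<f32>() - 1.0).abs() < 1e-6);
    }
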
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/gather.rs
@@ -3,7 +3,7 @@ use common::utok;
use std::ops::{Deref, DerefMut};
use tensor::Tensor;

-pub fn gather<T, U>(x: &mut Tensor<T>, table: &Tensor<U>, tokens: &[&[utok]])
+pub fn gather<T, U>(mut x: Tensor<T>, table: &Tensor<U>, tokens: &[&[utok]])
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
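
gather copies one embedding-table row per token id into x. A reference sketch over f32 rows; the assumption is that the real kernel performs the equivalent row copy on raw bytes, driven by the nested &[&[utok]] token slices.

    // Copy table row `tok` into the output for each token id.
    fn gather_rows(out: &mut [f32], table: &[f32], hidden: usize, tokens: &[u32]) {
        assert_eq!(out.len(), tokens.len() * hidden);
        for (dst, &tok) in out.chunks_mut(hidden).zip(tokens) {
            let row = &table[tok as usize * hidden..][..hidden];
            dst.copy_from_slice(row);
        }
    }

    fn main() {
        let table = vec![0.0_f32, 0.0, 1.0, 1.0, 2.0, 2.0]; // 3 rows, hidden = 2
        let mut x = vec![0.0_f32; 4];
        gather_rows(&mut x, &table, 2, &[2, 1]);
        assert_eq!(x, [2.0, 2.0, 1.0, 1.0]);
    }
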
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/mat_mul.rs
@@ -7,7 +7,7 @@ use tensor::{expand_indices, idx_strides, DataType, Tensor};
/// - c: [N0, N1, ... , N_, m, n]
/// - a: [N0, N1, ... , N_, m, k]
/// - b: [N0, N1, ... , N_, k, n]
-pub fn mat_mul<T, U, V>(c: &mut Tensor<T>, beta: f32, a: &Tensor<U>, b: &Tensor<V>, alpha: f32)
+pub fn mat_mul<T, U, V>(mut c: Tensor<T>, beta: f32, a: &Tensor<U>, b: &Tensor<V>, alpha: f32)
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
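
The doc comment gives the batched shape contract, and the (c, beta, a, b, alpha) parameter order suggests the usual GEMM convention c = beta * c + alpha * (a @ b). A single-matrix f32 sketch of that convention; the convention itself is an assumption about the kernel's semantics, and the real kernel dispatches on DataType and handles the leading batch dimensions.

    // c = beta * c + alpha * (a @ b); c is m x n, a is m x k, b is k x n.
    fn mat_mul_f32(c: &mut [f32], beta: f32, a: &[f32], b: &[f32], alpha: f32,
                   m: usize, k: usize, n: usize) {
        assert_eq!(c.len(), m * n);
        assert_eq!(a.len(), m * k);
        assert_eq!(b.len(), k * n);
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0;
                for p in 0..k {
                    acc += a[i * k + p] * b[p * n + j];
                }
                c[i * n + j] = beta * c[i * n + j] + alpha * acc;
            }
        }
    }

    fn main() {
        let a = [1.0_f32, 2.0, 3.0, 4.0]; // 2 x 2
        let b = [1.0_f32, 0.0, 0.0, 1.0]; // identity
        let mut c = [0.0_f32; 4];
        mat_mul_f32(&mut c, 0.0, &a, &b, 1.0, 2, 2, 2);
        assert_eq!(c, a);
    }
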
15 changes: 12 additions & 3 deletions transformer-cpu/src/kernel/rms_norm.rs
@@ -6,17 +6,22 @@ use std::{
};
use tensor::{reslice, reslice_mut, DataType, Tensor};

-pub fn rms_norm<T, U, V>(o: &mut Tensor<T>, x: &Tensor<U>, w: &Tensor<V>, epsilon: f32)
+pub fn rms_norm<T, U, V>(mut o: Tensor<T>, x: &Tensor<U>, w: &Tensor<V>, epsilon: f32)
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
V: Deref<Target = [u8]>,
{
+let &[.., d] = o.shape() else { panic!() };
let dt = o.data_type();

debug_assert_eq!(x.data_type(), dt);
debug_assert_eq!(w.data_type(), dt);
debug_assert_eq!(o.shape(), x.shape());
-debug_assert_eq!(&[*o.shape().last().unwrap()], w.shape());
+debug_assert_eq!(w.shape(), &[d]);
+debug_assert!(o.is_contiguous());
+debug_assert!(x.is_contiguous());
+debug_assert!(w.is_contiguous());

let o = o.as_mut_slice();
let x = x.as_slice();
@@ -36,9 +41,13 @@ where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
{
+let &[.., d] = o.shape() else { panic!() };
let dt = o.data_type();

debug_assert_eq!(w.data_type(), dt);
-debug_assert_eq!(&[*o.shape().last().unwrap()], w.shape());
+debug_assert_eq!(w.shape(), &[d]);
+debug_assert!(o.is_contiguous());
+debug_assert!(w.is_contiguous());

let o: &mut [u8] = o.as_mut_slice();
let x = unsafe { std::slice::from_raw_parts(o.as_ptr(), o.len()) };
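
The new debug_asserts pin the contract down: o and x share a shape ending in d, w has shape [d], and all three tensors are contiguous. An f32 reference sketch of RMS normalization under that contract; the real kernel operates on raw bytes and dispatches on the tensor's DataType.

    // o[i] = x[i] * w[i % d] / sqrt(mean(x_row^2) + epsilon), row by row.
    fn rms_norm_f32(o: &mut [f32], x: &[f32], w: &[f32], epsilon: f32) {
        let d = w.len();
        assert_eq!(o.len(), x.len());
        assert_eq!(o.len() % d, 0);
        for (o_row, x_row) in o.chunks_mut(d).zip(x.chunks(d)) {
            let mean_sq = x_row.iter().map(|v| v * v).sum::<f32>() / d as f32;
            let k = (mean_sq + epsilon).sqrt().recip();
            for ((o_i, x_i), w_i) in o_row.iter_mut().zip(x_row).zip(w) {
                *o_i = k * x_i * w_i;
            }
        }
    }

    fn main() {
        let x = vec![3.0_f32, 4.0];
        let w = vec![1.0_f32, 1.0];
        let mut o = vec![0.0_f32; 2];
        rms_norm_f32(&mut o, &x, &w, 1e-5);
        assert!((o[0] / o[1] - 0.75).abs() < 1e-6);
    }
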
16 changes: 8 additions & 8 deletions transformer-cpu/src/kernel/rotary_embedding.rs
@@ -3,22 +3,22 @@ use gemm::f16;
use std::ops::{Deref, DerefMut};
use tensor::{expand_indices, idx_strides, udim, Tensor};

-/// - t: [N0, N1, ... , N_, num_head, head_dim]
-/// - pos: [N0, N1, ... , N_]
-pub fn rotary_embedding<T, U>(t: &mut Tensor<T>, pos: &Tensor<U>, theta: f32)
+/// - t: [num_token, num_head, head_dim]
+/// - pos: [num_token]
+pub fn rotary_embedding<T, U>(mut t: Tensor<T>, pos: &Tensor<U>, theta: f32)
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
{
assert!(t.contiguous_len() >= 2);
-let [batch @ .., nh, dh] = t.shape() else {
+let &[num_tokens, nh, dh] = t.shape() else {
panic!()
};
-assert_eq!(pos.shape(), batch);
-let nh = *nh as usize;
-let dh = *dh as usize / 2;
+assert_eq!(pos.shape(), &[num_tokens]);
+let nh = nh as usize;
+let dh = dh as usize / 2;

-let (n, idx_strides) = idx_strides(batch);
+let (n, idx_strides) = idx_strides(&[num_tokens]);
for i in 0..n {
let pos = pos
.locate(&expand_indices(i, &idx_strides, &[1]).as_view())
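
The shape contract here tightens from a general batched layout to a flat t: [num_token, num_head, head_dim] with pos: [num_token]. For reference, an f32 sketch of rotary position embedding over that layout. The pairing of consecutive elements and the theta^(-2i/head_dim) frequency schedule are the common convention and an assumption on my part; the kernel's dh / 2 only hints at pairwise rotation, and its exact layout may differ.

    // Rotate each (2i, 2i+1) pair of every head by angle pos * theta^(-2i/head_dim).
    fn rotary_embedding_f32(t: &mut [f32], pos: &[u32],
                            num_head: usize, head_dim: usize, theta: f32) {
        assert_eq!(t.len(), pos.len() * num_head * head_dim);
        for (tok, chunk) in t.chunks_mut(num_head * head_dim).enumerate() {
            let p = pos[tok] as f32;
            for head in chunk.chunks_mut(head_dim) {
                for i in 0..head_dim / 2 {
                    let freq = theta.powf(-2.0 * i as f32 / head_dim as f32);
                    let (sin, cos) = (p * freq).sin_cos();
                    let (a, b) = (head[2 * i], head[2 * i + 1]);
                    head[2 * i] = a * cos - b * sin;
                    head[2 * i + 1] = a * sin + b * cos;
                }
            }
        }
    }
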
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/swiglu.rs
@@ -2,7 +2,7 @@
use std::ops::{Deref, DerefMut};
use tensor::{idim, DVector, Tensor};

-pub fn swiglu<T, U>(gate: &mut Tensor<T>, up: &Tensor<U>)
+pub fn swiglu<T, U>(mut gate: Tensor<T>, up: &Tensor<U>)
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
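
swiglu combines the gate and up projections in place. The standard definition is gate = silu(gate) * up with silu(x) = x * sigmoid(x); the sketch below assumes the kernel computes this, and, like the kernel, writes the result back into gate.

    // gate[i] = silu(gate[i]) * up[i], with silu(x) = x / (1 + e^(-x)).
    fn swiglu_f32(gate: &mut [f32], up: &[f32]) {
        assert_eq!(gate.len(), up.len());
        for (g, u) in gate.iter_mut().zip(up) {
            let silu = *g / (1.0 + (-*g).exp());
            *g = silu * u;
        }
    }

    fn main() {
        let mut gate = vec![0.0_f32, 1.0];
        let up = vec![2.0_f32, 3.0];
        swiglu_f32(&mut gate, &up);
        assert_eq!(gate[0], 0.0); // silu(0) = 0
    }
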
35 changes: 15 additions & 20 deletions transformer-cpu/src/lib.rs
@@ -80,29 +80,24 @@ impl Transformer {
(None, None)
};

-gather(&mut x0.access_mut(), &self.model.embed_tokens(), tokens);
+gather(x0.access_mut(), &self.model.embed_tokens(), tokens);
// println!("gather:\n{}", x0.access());

for (layer, cache) in cache.iter_mut().enumerate() {
let input_layernorm = self.model.input_layernorm(layer);
-rms_norm(
-&mut x1.access_mut(),
-&x0.access(),
-&input_layernorm,
-epsilon,
-);
+rms_norm(x1.access_mut(), &x0.access(), &input_layernorm, epsilon);
// println!("layer {layer} input norm:\n{}", x1.access());
let w_qkv = self.model.w_qkv(layer).transpose(&[1, 0]);
-mat_mul(&mut qkv.access_mut(), 0., &x1.access_mut(), &w_qkv, 1.);
+mat_mul(qkv.access_mut(), 0., &x1.access_mut(), &w_qkv, 1.);
let mut qkv = qkv.split(1, &[d as _, dkv as _, dkv as _]);
let v = qkv.pop().unwrap().reshape(&[seq_len, nkvh, dh]);
let mut k = qkv.pop().unwrap().reshape(&[seq_len, nkvh, dh]);
let mut q = qkv.pop().unwrap().reshape(&[seq_len, nh, dh]);
// println!("layer {layer} q:\n{}", q.access());
// println!("layer {layer} k:\n{}", k.access());
// println!("layer {layer} v:\n{}", v.access());
-rotary_embedding(&mut q.access_mut(), &pos, theta);
-rotary_embedding(&mut k.access_mut(), &pos, theta);
+rotary_embedding(q.access_mut(), &pos, theta);
+rotary_embedding(k.access_mut(), &pos, theta);
// println!("layer {layer} rot q:\n{}", q.access());
// println!("layer {layer} rot k:\n{}", k.access());
let q = q.transpose(&[1, 0, 2]);
@@ -131,7 +126,7 @@ impl Transformer {
{
let k_att = k_att.transpose(&[0, 2, 1]);
mat_mul(
-&mut att.access_mut(),
+att.access_mut(),
0.,
&q_att.access(),
&k_att.access(),
@@ -140,42 +135,42 @@ impl Transformer {
{
let mut att = att.clone().reshape(&[nh, seq_len, att_len]);
// println!("layer {layer} before softmax:\n{}", att.access());
-softmax(&mut att.access_mut());
+softmax(att.access_mut());
// println!("layer {layer} after softmax:\n{}", att.access());
}
if let Some(x2) = x2.as_mut() {
-mat_mul(&mut x2.access_mut(), 0., &att.access(), &v_att.access(), 1.);
+mat_mul(x2.access_mut(), 0., &att.access(), &v_att.access(), 1.);
let x2 = x2.clone().reshape(&[nh, seq_len, dh]).transpose(&[1, 0, 2]);
let mut x1 = x1.clone().reshape(&[seq_len, nh, dh]);
x2.access().reform_to(&mut x1.access_mut());
} else {
let mut x2 = x1.clone().reshape(&[nkvh, head_group * seq_len, dh]);
-mat_mul(&mut x2.access_mut(), 0., &att.access(), &v_att.access(), 1.);
+mat_mul(x2.access_mut(), 0., &att.access(), &v_att.access(), 1.);
}
// println!("layer {layer} after attention:\n{}", x1.access());
}

let wo = self.model.self_attn_o_proj(layer).transpose(&[1, 0]);
-mat_mul(&mut x0.access_mut(), 1., &x1.access(), &wo, 1.);
+mat_mul(x0.access_mut(), 1., &x1.access(), &wo, 1.);
// println!("layer {layer} o_proj:\n{}", x0.access());

let post_layernorm = self.model.post_attention_layernorm(layer);
-rms_norm(&mut x1.access_mut(), &x0.access(), &post_layernorm, epsilon);
+rms_norm(x1.access_mut(), &x0.access(), &post_layernorm, epsilon);
// println!("layer {layer} post norm:\n{}", x1.access());

let w_gate_up = self.model.mlp_gate_up(layer).transpose(&[1, 0]);
-mat_mul(&mut gate_up.access_mut(), 0., &x1.access(), &w_gate_up, 1.);
+mat_mul(gate_up.access_mut(), 0., &x1.access(), &w_gate_up, 1.);
let mut gate_up = gate_up.split(1, &[di as _, di as _]);
let up = gate_up.pop().unwrap();
let mut gate = gate_up.pop().unwrap();
// println!("layer {layer} gate:\n{}", gate.access());
// println!("layer {layer} up:\n{}", up.access());

-swiglu(&mut gate.access_mut(), unsafe { &up.access_unchecked() });
+swiglu(gate.access_mut(), unsafe { &up.access_unchecked() });
// println!("layer {layer} swiglu:\n{}", gate.access());

let mlp_down = self.model.mlp_down(layer).transpose(&[1, 0]);
-mat_mul(&mut x0.access_mut(), 1., &gate.access(), &mlp_down, 1.);
+mat_mul(x0.access_mut(), 1., &gate.access(), &mlp_down, 1.);
// println!("layer {layer} down:\n{}", x0.access());
}

@@ -192,7 +187,7 @@ impl Transformer {
let dt = self.model.data_type();
let voc = self.model.vocab_size() as udim;
mat_mul(
-&mut Tensor::new(dt, &[1, voc], reslice_mut(&mut self.logits)),
+Tensor::new(dt, &[1, voc], reslice_mut(&mut self.logits)),
0.,
&x.access(),
&self.model.lm_head().transpose(&[1, 0]),
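
The call sites above chain the kernels into the usual attention core per head: scores = q * k^T (scaled), softmax over the attention length, then scores * v. A self-contained single-head f32 sketch of that composition; the 1/sqrt(dh) scale and the flat single-head layout are assumptions, and the real code runs this on multi-head tensor views with grouped KV heads, transposes, and a cache.

    // q: [seq_len, dh], k and v: [att_len, dh], out: [seq_len, dh].
    fn attention_head(q: &[f32], k: &[f32], v: &[f32], out: &mut [f32],
                      seq_len: usize, att_len: usize, dh: usize) {
        let scale = 1.0 / (dh as f32).sqrt();
        let mut scores = vec![0.0_f32; seq_len * att_len];
        for i in 0..seq_len {
            // scores[i][j] = scale * <q[i], k[j]>
            for j in 0..att_len {
                let dot: f32 = (0..dh).map(|p| q[i * dh + p] * k[j * dh + p]).sum();
                scores[i * att_len + j] = scale * dot;
            }
            // numerically stable softmax over row i
            let row = &mut scores[i * att_len..][..att_len];
            let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0;
            for s in row.iter_mut() {
                *s = (*s - max).exp();
                sum += *s;
            }
            for s in row.iter_mut() {
                *s /= sum;
            }
            // out[i] = scores[i] @ v
            for p in 0..dh {
                out[i * dh + p] = (0..att_len).map(|j| row[j] * v[j * dh + p]).sum::<f32>();
            }
        }
    }
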
