refactor(transformer-cpu): Transformer::update supports computing multiple Requests simultaneously
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 4, 2024
1 parent 66ec32e commit cb3c6ff
Showing 7 changed files with 150 additions and 93 deletions.
7 changes: 7 additions & 0 deletions tensor/src/slice.rs
@@ -27,6 +27,13 @@ macro_rules! slice {
len: udim::MAX,
}
};
[from $start:expr, take $len:expr] => {
$crate::SliceDim {
start: $start,
step: 1,
len: $len,
}
};
[$start:expr; $step:expr; $len:expr] => {
$crate::SliceDim {
start: $start,
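
Note: the new from/take arm is sugar for a unit-stride slice, and the lib.rs hunk below uses it as slice![from pos, take seq_len]. A minimal usage sketch (not part of the commit; the import path is an assumption, and SliceDim's fields are assumed to be publicly readable):

    // Hypothetical example: `from X, take N` expands to a step-1 SliceDim.
    use tensor::{slice, SliceDim};

    fn slice_demo() {
        let a: SliceDim = slice![from 4, take 3]; // start: 4, step: 1, len: 3
        let b: SliceDim = slice![4; 1; 3];        // the pre-existing explicit form
        assert_eq!((a.start, a.step, a.len), (b.start, b.step, b.len));
    }
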
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/fused_softmax.rs
@@ -4,7 +4,7 @@ use std::ops::DerefMut;
use tensor::{expand_indices, idx_strides, Tensor};

/// - x: [N0, N1, ... , N_, seq_len, att_len]
pub fn softmax<T>(mut x: Tensor<T>)
pub fn softmax<T>(x: &mut Tensor<T>)
where
T: DerefMut<Target = [u8]>,
{
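
Since softmax now borrows the tensor mutably instead of consuming it, the caller keeps the attention buffer alive for the reshape that follows; the matching call-site change appears in the lib.rs hunk below:

    // before: softmax(att.access_mut());
    // after:  softmax(&mut att);
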
12 changes: 6 additions & 6 deletions transformer-cpu/src/kernel/gather.rs
@@ -1,9 +1,9 @@
use super::slice;
use common::utok;
use crate::Request;
use std::ops::{Deref, DerefMut};
use tensor::Tensor;
use tensor::{udim, Tensor};

pub fn gather<T, U>(mut x: Tensor<T>, table: &Tensor<U>, tokens: &[&[utok]])
pub fn gather<T, U>(mut x: Tensor<T>, table: &Tensor<U>, requests: &[Request])
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
@@ -14,16 +14,16 @@ where
debug_assert_eq!(table.shape().len(), 2);
debug_assert_eq!(table.shape()[1], d);
debug_assert_eq!(
tokens.iter().map(|s| s.len()).sum::<usize>(),
num_token as usize
requests.iter().map(Request::seq_len).sum::<udim>(),
num_token
);
debug_assert!(x.is_contiguous());
debug_assert!(table.is_contiguous());
let d = d as usize * x.data_type().size();

let x = x.as_mut_slice();
let table = table.as_slice();
for (i, &t) in tokens.iter().flat_map(|s| s.iter()).enumerate() {
for (i, &t) in requests.iter().flat_map(|s| s.tokens.iter()).enumerate() {
slice!(x; d; [i]).copy_from_slice(&slice!(table; d; [t]))
}
}
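
gather now walks the requests in order and flattens their token lists, so row i of x receives the embedding of the i-th token counted across all requests. A small standalone sketch of that ordering (illustrative only; plain u32 slices stand in for &[utok]):

    fn ordering_demo() {
        // Two requests with 3 and 2 tokens: their tokens are laid out back to back,
        // which is the row order in which embedding rows are copied into x.
        let requests: [&[u32]; 2] = [&[10, 11, 12], &[20, 21]];
        let flat: Vec<u32> = requests.iter().flat_map(|s| s.iter().copied()).collect();
        assert_eq!(flat, vec![10, 11, 12, 20, 21]); // num_token = 3 + 2 = 5 rows of x
    }
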
2 changes: 1 addition & 1 deletion transformer-cpu/src/kernel/mat_mul.rs
@@ -7,7 +7,7 @@ use tensor::{expand_indices, idx_strides, DataType, Tensor};
/// - c: [N0, N1, ... , N_, m, n]
/// - a: [N0, N1, ... , N_, m, k]
/// - b: [N0, N1, ... , N_, k, n]
pub fn mat_mul<T, U, V>(mut c: Tensor<T>, beta: f32, a: &Tensor<U>, b: &Tensor<V>, alpha: f32)
pub fn mat_mul<T, U, V>(c: &mut Tensor<T>, beta: f32, a: &Tensor<U>, b: &Tensor<V>, alpha: f32)
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
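
Only the signature changes here: the output tensor is now passed as &mut, matching the softmax change above. Judging from the call sites in the lib.rs hunk (beta = 0 when writing a fresh result, beta = 1 when accumulating into the residual stream), the behaviour appears to be the usual GEMM update c = beta·c + alpha·(a·b); that reading is inferred from the diff, not stated by the commit.

    // Call-site pattern after this change (taken from the lib.rs hunk below):
    //   mat_mul(&mut qkv.access_mut(), 0., &x1.access_mut(), &w_qkv, 1.); // qkv  = x1 · w_qkv
    //   mat_mul(&mut x0.access_mut(),  1., &x1.access(),     &wo,    1.); // x0  += x1 · wo   (residual add)
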
204 changes: 123 additions & 81 deletions transformer-cpu/src/lib.rs
@@ -17,6 +17,24 @@ pub struct Transformer {
logits: Vec<f32>,
}

pub struct Request<'a> {
pub tokens: &'a [utok],
pub cache: &'a mut [LayerCache],
pub pos: upos,
}

impl Request<'_> {
#[inline]
pub const fn seq_len(&self) -> udim {
self.tokens.len() as udim
}

#[inline]
pub const fn att_len(&self) -> udim {
self.pos + self.seq_len()
}
}

impl Transformer {
#[inline]
pub fn new(model: Box<dyn Llama2>) -> Self {
Expand All @@ -39,13 +57,30 @@ impl Transformer {
self.model.max_position_embeddings()
}

pub fn update(
&self,
tokens: &[&[utok]],
cache: &mut [LayerCache],
pos: upos,
) -> Tensor<Storage> {
let seq_len = tokens.iter().map(|s| s.len()).sum::<usize>() as udim;
pub fn update(&self, requests: &mut [Request]) -> Tensor<Storage> {
// println!("tokens:");
// for request in requests.iter() {
// println!(
// "{:?}: {:?}",
// request.tokens,
// request.pos..request.pos + request.tokens.len() as upos
// );
// }

let mut nt = 0 as udim;
let mut max_seq_len = 0 as udim;
let mut max_att_len = 0 as udim;
for request in requests.iter() {
let seq_len = request.seq_len();
let att_len = request.att_len();
nt += seq_len;
max_seq_len = max_seq_len.max(seq_len);
max_att_len = max_att_len.max(att_len);
}
let nt = nt;
let max_seq_len = max_seq_len;
let max_att_len = max_att_len;

let d = self.model.hidden_size() as udim;
let nh = self.model.num_attention_heads() as udim;
let nkvh = self.model.num_key_value_heads() as udim;
@@ -57,42 +57,39 @@ impl Transformer {
let dt = self.model.data_type();
let epsilon = self.model.rms_norm_eps();
let theta = self.model.rope_theta();
let att_len = pos + seq_len;
let cat_slice = &[slice![all], slice![pos; 1; seq_len], slice![all]];
let att_slice = &[slice![all], slice![ 0; 1; att_len], slice![all]];
let pos = (pos..pos + seq_len).collect::<Vec<udim>>();
let pos = Tensor::new(DataType::U32, &[seq_len], reslice::<udim, u8>(&pos));
// println!("tokens: {tokens:?}");

let mut x0 = tensor(dt, &[seq_len, d]);
let mut x1 = tensor(dt, &[seq_len, d]);
let mut qkv = tensor(dt, &[seq_len, d + dkv + dkv]);
let mut att = tensor(dt, &[nkvh, head_group * seq_len, att_len]);
let mut gate_up = tensor(dt, &[seq_len, di + di]);

let (mut x2, mut q_att) = if seq_len > 1 {
(
// `seq_len x hidden_size` -reshape-> `seq_len x (num_kv_head x head_group x head_dim)` -transpose(1,2,0,3)-> `num_kv_head x head_group x seq_len x head_dim` -reshape-> `num_kv_head x (head_group x seq_len) x head_dim`
Some(tensor(dt, &[nkvh, head_group * seq_len, dh])),
Some(tensor(dt, &[nh, seq_len, dh])),
)
} else {
(None, None)
};

gather(x0.access_mut(), &self.model.embed_tokens(), tokens);
let mut pos = Vec::<u32>::with_capacity(nt as usize);
for request in requests.iter() {
let len = request.tokens.len() as u32;
pos.extend(request.pos..(request.pos + len));
}
let pos = Tensor::new(DataType::U32, &[nt], reslice::<udim, u8>(&pos));

let mut x0 = tensor(dt, &[nt, d]);
let mut x1 = tensor(dt, &[nt, d]);
let mut qkv = tensor(dt, &[nt, d + dkv + dkv]);
let mut q_buf = vec![0u8; (nh * max_seq_len * dh) as usize * dt.size()];
let mut att_buf =
vec![0u8; (nkvh * head_group * max_seq_len * max_att_len) as usize * dt.size()];
// `num_token x hidden_size`
// -|reshape|------------> `num_token x (num_kv_head x head_group x head_dim)`
// -|transpose(1,2,0,3)|-> `num_kv_head x head_group x num_token x head_dim`
// -|reshape|------------> `num_kv_head x (head_group x num_token) x head_dim`
let mut x2 = tensor(dt, &[nkvh, head_group * nt, dh]);
let mut gate_up = tensor(dt, &[nt, di + di]);

gather(x0.access_mut(), &self.model.embed_tokens(), requests);
// println!("gather:\n{}", x0.access());

for (layer, cache) in cache.iter_mut().enumerate() {
for layer in 0..self.model.num_hidden_layers() {
let input_layernorm = self.model.input_layernorm(layer);
rms_norm(x1.access_mut(), &x0.access(), &input_layernorm, epsilon);
// println!("layer {layer} input norm:\n{}", x1.access());
let w_qkv = self.model.w_qkv(layer).transpose(&[1, 0]);
mat_mul(qkv.access_mut(), 0., &x1.access_mut(), &w_qkv, 1.);
mat_mul(&mut qkv.access_mut(), 0., &x1.access_mut(), &w_qkv, 1.);
let mut qkv = qkv.split(1, &[d as _, dkv as _, dkv as _]);
let v = qkv.pop().unwrap().reshape(&[seq_len, nkvh, dh]);
let mut k = qkv.pop().unwrap().reshape(&[seq_len, nkvh, dh]);
let mut q = qkv.pop().unwrap().reshape(&[seq_len, nh, dh]);
let v = qkv.pop().unwrap().reshape(&[nt, nkvh, dh]);
let mut k = qkv.pop().unwrap().reshape(&[nt, nkvh, dh]);
let mut q = qkv.pop().unwrap().reshape(&[nt, nh, dh]);
// println!("layer {layer} q:\n{}", q.access());
// println!("layer {layer} k:\n{}", k.access());
// println!("layer {layer} v:\n{}", v.access());
@@ -104,62 +104,69 @@ impl Transformer {
let k = k.transpose(&[1, 0, 2]);
let v = v.transpose(&[1, 0, 2]);

let (k_cache, v_cache) = cache.get();
let mut k_cat = k_cache.clone().slice(cat_slice);
let mut v_cat = v_cache.clone().slice(cat_slice);
let q_att = if let Some(q_att) = q_att.as_mut() {
q.access().reform_to(&mut q_att.access_mut());
q_att.clone()
} else {
q.reshape(&[nh, seq_len, dh])
};
k.access().reform_to(&mut k_cat.access_mut());
v.access().reform_to(&mut v_cat.access_mut());

let q_att = q_att.clone().reshape(&[nkvh, head_group * seq_len, dh]);
let k_att = k_cache.clone().slice(att_slice);
let v_att = v_cache.clone().slice(att_slice);
// println!("layer {layer} q attention:\n{}", q_att.access());
// println!("layer {layer} k attention:\n{}", k_att.access());
// println!("layer {layer} v attention:\n{}", v_att.access());

{
let k_att = k_att.transpose(&[0, 2, 1]);
mat_mul(
att.access_mut(),
0.,
&q_att.access(),
&k_att.access(),
head_div,
let mut req = 0;
for r in requests.iter_mut() {
let pos = r.pos;
let seq_len = r.seq_len();
let att_len = r.att_len();

let req_slice = &[slice![all], slice![from req, take seq_len], slice![all]];
let cat_slice = &[slice![all], slice![from pos, take seq_len], slice![all]];
let att_slice = &[slice![all], slice![from 0, take att_len], slice![all]];
req += seq_len;

let q = q.clone().slice(req_slice);
let k = k.clone().slice(req_slice);
let v = v.clone().slice(req_slice);

let (k_cache, v_cache) = r.cache[layer].get();
let mut q_att = Tensor::new(dt, &[nh, seq_len, dh], q_buf.as_mut_slice());
let mut k_cat = k_cache.clone().slice(cat_slice);
let mut v_cat = v_cache.clone().slice(cat_slice);
q.access().reform_to(&mut q_att);
k.access().reform_to(&mut k_cat.access_mut());
v.access().reform_to(&mut v_cat.access_mut());

let q_att = q_att.reshape(&[nkvh, head_group * seq_len, dh]);
let k_att = k_cache.clone().slice(att_slice);
let v_att = v_cache.clone().slice(att_slice);
// println!("layer {layer} q attention:\n{}", q_att.access());
// println!("layer {layer} k attention:\n{}", k_att.access());
// println!("layer {layer} v attention:\n{}", v_att.access());

let mut att = Tensor::new(
dt,
&[nkvh, head_group * seq_len, att_len],
att_buf.as_mut_slice(),
);
{
let mut att = att.clone().reshape(&[nh, seq_len, att_len]);
let k_att = k_att.transpose(&[0, 2, 1]);
mat_mul(&mut att, 0., &q_att, &k_att.access(), head_div);
// println!("layer {layer} before softmax:\n{}", att.access());
softmax(att.access_mut());
att = att.reshape(&[nh, seq_len, att_len]);
softmax(&mut att);
// println!("layer {layer} after softmax:\n{}", att.access());
att = att.reshape(&[nkvh, head_group * seq_len, att_len]);
{
mat_mul(&mut x2.access_mut(), 0., &att, &v_att.access(), 1.);
let x2 = x2.clone().reshape(&[nh, seq_len, dh]).transpose(&[1, 0, 2]);
let mut x1 = x1.clone().reshape(&[seq_len, nh, dh]);
x2.access().reform_to(&mut x1.access_mut());
}
// println!("layer {layer} after attention:\n{}", x1.access());
}
if let Some(x2) = x2.as_mut() {
mat_mul(x2.access_mut(), 0., &att.access(), &v_att.access(), 1.);
let x2 = x2.clone().reshape(&[nh, seq_len, dh]).transpose(&[1, 0, 2]);
let mut x1 = x1.clone().reshape(&[seq_len, nh, dh]);
x2.access().reform_to(&mut x1.access_mut());
} else {
let mut x2 = x1.clone().reshape(&[nkvh, head_group * seq_len, dh]);
mat_mul(x2.access_mut(), 0., &att.access(), &v_att.access(), 1.);
}
// println!("layer {layer} after attention:\n{}", x1.access());
}

let wo = self.model.self_attn_o_proj(layer).transpose(&[1, 0]);
mat_mul(x0.access_mut(), 1., &x1.access(), &wo, 1.);
mat_mul(&mut x0.access_mut(), 1., &x1.access(), &wo, 1.);
// println!("layer {layer} o_proj:\n{}", x0.access());

let post_layernorm = self.model.post_attention_layernorm(layer);
rms_norm(x1.access_mut(), &x0.access(), &post_layernorm, epsilon);
// println!("layer {layer} post norm:\n{}", x1.access());

let w_gate_up = self.model.mlp_gate_up(layer).transpose(&[1, 0]);
mat_mul(gate_up.access_mut(), 0., &x1.access(), &w_gate_up, 1.);
mat_mul(&mut gate_up.access_mut(), 0., &x1.access(), &w_gate_up, 1.);
let mut gate_up = gate_up.split(1, &[di as _, di as _]);
let up = gate_up.pop().unwrap();
let mut gate = gate_up.pop().unwrap();
@@ -170,15 +170,19 @@
// println!("layer {layer} swiglu:\n{}", gate.access());

let mlp_down = self.model.mlp_down(layer).transpose(&[1, 0]);
mat_mul(x0.access_mut(), 1., &gate.access(), &mlp_down, 1.);
mat_mul(&mut x0.access_mut(), 1., &gate.access(), &mlp_down, 1.);
// println!("layer {layer} down:\n{}", x0.access());
}

x0
}

pub fn forward(&mut self, token: utok, cache: &mut [LayerCache], pos: upos) -> &[f32] {
let mut x = self.update(&[&[token]], cache, pos);
let mut x = self.update(&mut [Request {
tokens: &[token],
cache,
pos,
}]);

let model_norm = self.model.model_norm();
rms_norm_inplace(&mut x.access_mut(), &model_norm, self.model.rms_norm_eps());
@@ -187,7 +187,7 @@ impl Transformer {
let dt = self.model.data_type();
let voc = self.model.vocab_size() as udim;
mat_mul(
Tensor::new(dt, &[1, voc], reslice_mut(&mut self.logits)),
&mut Tensor::new(dt, &[1, voc], reslice_mut(&mut self.logits)),
0.,
&x.access(),
&self.model.lm_head().transpose(&[1, 0]),
@@ -211,7 +211,6 @@ impl Transformer {
}
}

#[inline]
fn tensor(dt: DataType, shape: &[udim]) -> Tensor<Storage> {
Tensor::new(
dt,
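
Putting the new public API together: the caller builds one Request per sequence (the tokens to process, that sequence's KV cache, and the position the tokens start at) and hands the whole batch to update; seq_len is the number of new tokens and att_len = pos + seq_len is how far into the cache attention reads. A minimal driver sketch, assuming the model and per-request caches are already constructed (the surrounding function is a placeholder, not from the commit):

    use transformer_cpu::{LayerCache, Request, Transformer};

    fn step(transformer: &Transformer, cache_a: &mut [LayerCache], cache_b: &mut [LayerCache]) {
        let x = transformer.update(&mut [
            Request { tokens: &[1, 2, 3], cache: cache_a, pos: 0 }, // prefill: seq_len 3, att_len 3
            Request { tokens: &[42],      cache: cache_b, pos: 7 }, // decode:  seq_len 1, att_len 8
        ]);
        // x holds the hidden states for all 4 tokens, concatenated in request order.
        let _ = x;
    }
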
8 changes: 6 additions & 2 deletions xtask/src/generate.rs
@@ -15,7 +15,7 @@
use tokenizer::Tokenizer;
use transformer_cpu::{
model_parameters::{Allocator, Llama2, Memory},
Transformer,
Request, Transformer,
};

#[derive(Args, Default)]
@@ -136,7 +136,11 @@ fn on_host(
let time = Instant::now();
let (last, tokens) = prompt_tokens.split_last().expect("prompt is empty");
if !tokens.is_empty() {
transformer.update(&[tokens], &mut kv_cache, 0);
transformer.update(&mut [Request {
tokens,
cache: &mut kv_cache,
pos: 0,
}]);
}
info!("prefill transformer ... {:?}", time.elapsed());

8 changes: 6 additions & 2 deletions xtask/src/service/cpu.rs
@@ -7,7 +7,7 @@ use common::upos;
use std::{collections::HashMap, time::Instant};
use transformer_cpu::{
model_parameters::{Llama2, Memory},
LayerCache, Transformer,
LayerCache, Request, Transformer,
};

pub(super) fn run(
@@ -52,7 +52,11 @@ pub(super) fn run(
});

if !tokens.is_empty() {
transformer.update(&[tokens], &mut session.kv_cache, session.pos as _);
transformer.update(&mut [Request {
tokens,
cache: &mut session.kv_cache,
pos: session.pos as _,
}]);
session.pos += tokens.len() as upos;
}

