Skip to content

Commit

Permalink
refactor(transformer-cpu): 基于 causal-lm 接口重写 transformer 推理
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 23, 2024
1 parent 5435e08 commit a65e109
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 8 deletions.
14 changes: 8 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion causal-lm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub struct DecodingMeta {
/// 生成位置张量。
#[inline]
pub fn pos<'a, S: 'a>(
queries: impl IntoIterator<Item = QueryContext<'a, S>>,
queries: impl IntoIterator<Item = &'a QueryContext<'a, S>>,
nt_hint: udim,
) -> Tensor<Vec<upos>> {
let mut ans = Vec::with_capacity(nt_hint as usize);
Expand Down
2 changes: 2 additions & 0 deletions transformer-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ authors = ["YdrMaster <[email protected]>"]
[dependencies]
common = { path = "../common" }
tensor = { path = "../tensor" }
causal-lm = { path = "../causal-lm" }
transformer = { path = "../transformer" }
itertools = "0.12"
gemm = "0.17"
intel-mkl-src = { version = "0.8", features = ["mkl-dynamic-lp64-iomp"] }

Expand Down
199 changes: 198 additions & 1 deletion transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,210 @@
mod kernel;

use common::{utok, Blob};
use causal_lm::{CausalLM, QueryContext};
use common::{upos, utok, Blob};
use gemm::f16;
use itertools::izip;
use kernel::CpuKernels;
use tensor::{reslice, slice, split, udim, DataType, LocalSplitable, Tensor};
use transformer::{pos, Kernels, LayerBuffer, LayerCache, Llama2, Memory, Request, SampleArgs};

/// CPU implementation of a causal language model.
///
/// Thin newtype over the loaded model weights (`Memory`), which provides the
/// `Llama2` accessors (hyperparameters and weight tensors) used by the
/// `CausalLM` impl below.
pub struct Transformer(Memory);

impl CausalLM for Transformer {
    // Host-memory blobs back every tensor produced by this implementation.
    type Storage = Blob;

    /// The end-of-sequence token id, read from the loaded model config.
    #[inline]
    fn eos_token(&self) -> utok {
        self.0.eos_token_id()
    }

    /// Allocates an empty KV cache sized for the model's full context window.
    ///
    /// Shape: `[nlayers, 2, nkvh, max_seq_len, d / nh]` — one K and one V
    /// plane (the axis of length 2) per layer per key-value head, with
    /// `d / nh` the per-head dimension. Contents are whatever a fresh
    /// `Blob` holds; callers fill positions as they decode.
    fn new_cache(&self) -> Tensor<Self::Storage> {
        let dt = self.0.data_type();
        let nlayers = self.0.num_hidden_layers() as udim;
        let nkvh = self.0.num_key_value_heads() as udim;
        let max_seq_len = self.0.max_position_embeddings() as udim;
        let d = self.0.hidden_size() as udim;
        let nh = self.0.num_attention_heads() as udim;

        Tensor::alloc(dt, &[nlayers, 2, nkvh, max_seq_len, d / nh], Blob::new)
    }

    /// Copies a cache, duplicating only the first `pos` sequence positions.
    ///
    /// The result has the full `cache` shape; entries past `pos` are left
    /// as fresh `Blob` contents (not copied), since they hold no decoded
    /// state yet.
    fn duplicate_cache(&self, cache: &Tensor<Self::Storage>, pos: upos) -> Tensor<Self::Storage> {
        // Slice-pattern destructuring; the literal `2` doubles as an
        // assertion that the K/V axis has the expected extent.
        let &[_nlayers, 2, _nkvh, max_seq_len, _dh] = cache.shape() else {
            panic!()
        };
        assert!(pos <= max_seq_len);
        // Take everything on every axis except the sequence axis, which is
        // truncated to the first `pos` positions.
        let slice = [
            slice![=>],
            slice![=>],
            slice![=>],
            slice![=>pos],
            slice![=>],
        ];

        let mut ans = Tensor::alloc(cache.data_type(), cache.shape(), Blob::new);
        cache
            .as_ref()
            .slice(&slice)
            .map_physical(|u| &**u)
            .reform_to(&mut ans.as_mut().slice(&slice).map_physical(|u| &mut **u));
        ans
    }

    /// Gathers token embeddings into an `[nt, d]` tensor, one row per
    /// input token, by indexing into the embedding table.
    fn token_embed(&self, queries: impl IntoIterator<Item = utok>) -> Tensor<Self::Storage> {
        let dt = self.0.data_type();
        let d = self.0.hidden_size() as udim;
        let kernels = CpuKernels::new(&self.0);

        let tokens = queries.into_iter().collect::<Vec<_>>();
        let nt = tokens.len() as udim;

        let mut x0 = Tensor::alloc(dt, &[nt, d], Blob::new);
        kernels.gather(&mut x0, &self.0.embed_tokens(), tokens);
        x0
    }

    /// Runs all transformer layers over a batch of queries whose tokens
    /// are packed back-to-back in `token_embedded` (`[nt, d]`), updating
    /// each query's KV cache in place and returning the final hidden
    /// states in the same packed `[nt, d]` layout.
    fn forward<'a>(
        &self,
        queries: impl IntoIterator<Item = QueryContext<'a, Self::Storage>>,
        token_embedded: Tensor<Self::Storage>,
    ) -> Tensor<Self::Storage>
    where
        Self: 'a,
    {
        // One pass over the queries collects per-query sequence lengths and
        // the totals/maxima used to size the scratch buffers below.
        let mut queries = queries.into_iter().collect::<Vec<_>>();
        let mut nt = 0;
        let mut max_seq_len = 0;
        let mut max_att_len = 0;
        let seq_len = queries
            .iter()
            .map(|q| {
                let seq = q.seq_len();
                let att = q.att_len();
                nt += seq;
                max_seq_len = max_seq_len.max(seq);
                max_att_len = max_att_len.max(att);
                seq
            })
            .collect::<Vec<_>>();

        let dt = self.0.data_type();
        let d = self.0.hidden_size() as udim;
        let nh = self.0.num_attention_heads() as udim;
        let nkvh = self.0.num_key_value_heads() as udim;
        let dh = d / nh; // per-head dimension
        let dkv = nkvh * dh; // total K (or V) width across kv heads
        let di = self.0.intermediate_size() as udim;
        let head_group = nh / nkvh; // query heads sharing one kv head (GQA)
        let head_div = (dh as f32).sqrt().recip(); // attention scale 1/sqrt(dh)
        let kernels = CpuKernels::new(&self.0);

        // One scratch buffer serves both the attention QKV projection
        // (d + 2*dkv wide) and the MLP gate/up projection (2*di wide);
        // `reusing` is sized for whichever is larger.
        let reusing = (d + dkv + dkv).max(di + di);
        let mut state_buf = Tensor::alloc(dt, &[nt, d + reusing], Blob::new);
        // Splits `state_buf` along axis 1 into an `[nt, d]` part (x1) and an
        // `[nt, reusing]` part; a macro because the split borrows the buffer
        // mutably and must be re-taken inside the layer loop.
        macro_rules! state {
            () => {
                split!(state_buf.as_mut().map_physical(|u| LocalSplitable::from(&mut **u)); [1]: d, reusing)
            };
        }

        let mut q_buf = Blob::new((nh * max_seq_len * dh) as usize * dt.size());
        let mut att_buf = Blob::new((nh * max_seq_len * max_att_len) as usize * dt.size());
        // Position index of every packed token, for rotary embedding.
        let pos = causal_lm::pos(&queries, nt);
        let pos = pos.as_ref().map_physical(|u| reslice(u));

        let mut x = token_embedded;
        for layer in 0..self.0.num_hidden_layers() {
            let (mut x1, qkv) = state!();
            let mut qkv = qkv.slice(&[slice![=>], slice![=> d + dkv + dkv]]);

            let input_layernorm = self.0.input_layernorm(layer);
            kernels.rms_norm(&mut x1, &x, &input_layernorm);

            // NOTE(review): `mat_mul(out, beta, a, b, alpha)` is read here as
            // out = beta*out + alpha*(a @ b) — confirm against the kernel impl.
            let w_qkv = self.0.w_qkv(layer).transpose(&[1, 0]);
            kernels.mat_mul(&mut qkv, 0., &x1, &w_qkv, 1.);

            let (q, k, v) = split!(qkv; [1]: d, dkv, dkv);
            let mut q = q.reshape(&[nt, nh, dh]);
            let mut k = k.reshape(&[nt, nkvh, dh]);
            let v = v.reshape(&[nt, nkvh, dh]);
            // `o` is a view of `x1`'s storage: the per-query loop below
            // writes attention output through it, so afterwards `x1` holds
            // the attention result.
            let o = x1.reshape(&[nt, nh, dh]);

            kernels.rotary_embedding(&mut q, &pos);
            kernels.rotary_embedding(&mut k, &pos);

            // Move heads to the leading axis, then split the packed token
            // axis back into one tensor per query.
            let q = q.transpose(&[1, 0, 2]).split(1, &seq_len);
            let k = k.transpose(&[1, 0, 2]).split(1, &seq_len);
            let v = v.transpose(&[1, 0, 2]).split(1, &seq_len);
            let o = o.transpose(&[1, 0, 2]).split(1, &seq_len);

            for (query, q, k, v, mut o) in izip!(&mut queries, q, k, v, o) {
                let pos = query.pos();
                let seq_len = query.seq_len();
                let att_len = query.att_len();
                let (mut k_cache, mut v_cache) = query.cache(layer);

                // Where the new K/V rows land in the cache, and the prefix
                // of the cache attention reads (history + new tokens).
                let cat_slice = &[slice![=>], slice![pos =>=> seq_len], slice![=>]];
                let att_slice = &[slice![=>], slice![ => att_len], slice![=>]];

                let mut q_att = Tensor::new(dt, &[nh, seq_len, dh], &mut q_buf[..]);
                let mut k_cat = k_cache.as_mut().slice(cat_slice).map_physical(|u| &mut **u);
                let mut v_cat = v_cache.as_mut().slice(cat_slice).map_physical(|u| &mut **u);
                // `reform` materializes the strided views contiguously:
                // q into scratch, new k/v into the cache.
                kernels.reform(&mut q_att, &q);
                kernels.reform(&mut k_cat, &k);
                kernels.reform(&mut v_cat, &v);

                // Fold the query-head group into the row axis so each kv
                // head's K/V is matmul'd once against all its query heads.
                let q_att = q_att.reshape(&[nkvh, head_group * seq_len, dh]);
                let k_att = k_cache.slice(att_slice).transpose(&[0, 2, 1]);
                let v_att = v_cache.slice(att_slice);

                let shape_att0 = &[nkvh, head_group * seq_len, att_len];
                let shape_att1 = &[nkvh * head_group, seq_len, att_len];

                // scores = softmax(q @ k^T / sqrt(dh)); softmax is applied
                // per (head, position) row via the shape_att1 view.
                let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]);
                kernels.mat_mul(&mut att, 0., &q_att, &k_att, head_div);
                let mut att = att.reshape(shape_att1);
                kernels.softmax(&mut att);
                // Reuse the q scratch for the attention output (att @ v).
                let mut x2 = q_att;
                kernels.mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att, 1.);

                kernels.reform(&mut o, &x2.reshape(&[nh, seq_len, dh]));
            }

            // Re-split the scratch: x1 (same storage as above) now holds the
            // attention output; gate_up is the MLP scratch region.
            let (mut x1, gate_up) = state!();
            let mut gate_up = gate_up.slice(&[slice![=>], slice![=> di + di]]);

            // Residual add: x += attn_out @ Wo (beta = 1 accumulates into x).
            let wo = self.0.self_attn_o_proj(layer).transpose(&[1, 0]);
            kernels.mat_mul(&mut x, 1., &x1, &wo, 1.);

            let post_layernorm = self.0.post_attention_layernorm(layer);
            kernels.rms_norm(&mut x1, &x, &post_layernorm);

            let w_gate_up = self.0.mlp_gate_up(layer).transpose(&[1, 0]);
            kernels.mat_mul(&mut gate_up, 0., &x1, &w_gate_up, 1.);

            // SwiGLU: gate = silu(gate) * up, then the residual down-proj.
            let (mut gate, up) = split!(gate_up; [1]: di, di);
            kernels.swiglu(&mut gate, &up);

            let mlp_down = self.0.mlp_down(layer).transpose(&[1, 0]);
            kernels.mat_mul(&mut x, 1., &gate, &mlp_down, 1.);
        }

        x
    }

    /// Projects hidden states to logits for the selected positions.
    /// Not yet implemented in this commit.
    fn decode(
        &self,
        decoding: impl IntoIterator<Item = causal_lm::DecodingMeta>,
        hidden_state: Tensor<Self::Storage>,
    ) -> Tensor<Self::Storage> {
        todo!()
    }

    /// Samples next tokens from logits. Not yet implemented in this commit.
    fn sample(&self, logits: Tensor<Self::Storage>) -> Vec<utok> {
        todo!()
    }
}

impl transformer::Transformer for Transformer {
type Cache = Blob;

Expand Down

0 comments on commit a65e109

Please sign in to comment.