diff --git a/Cargo.lock b/Cargo.lock index 7d8ffe34..d841edda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -467,6 +467,8 @@ name = "causal-lm" version = "0.0.0" dependencies = [ "common", + "half", + "rand", "tensor", ] @@ -1362,9 +1364,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "685a7d121ee3f65ae4fddd72b25a04bb36b6af81bc0828f7d5434c0fe60fa3a2" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" dependencies = [ "libc", ] @@ -1994,9 +1996,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.33" +version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3cc72858054fcff6d7dea32df2aeaee6a7c24227366d7ea429aada2f26b16ad" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ "bitflags 2.5.0", "errno", @@ -2027,9 +2029,9 @@ checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" [[package]] name = "rustls-webpki" -version = "0.102.2" +version = "0.102.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" +checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" dependencies = [ "ring", "rustls-pki-types", @@ -2521,10 +2523,12 @@ dependencies = [ name = "transformer-cpu" version = "0.0.0" dependencies = [ + "causal-lm", "common", "gemm", "intel-mkl-src", "intel-mkl-tool", + "itertools", "tensor", "transformer", ] diff --git a/causal-lm/Cargo.toml b/causal-lm/Cargo.toml index 41d8b82e..53af23c1 100644 --- a/causal-lm/Cargo.toml +++ b/causal-lm/Cargo.toml @@ -9,3 +9,5 @@ authors = ["YdrMaster "] [dependencies] common = { path = "../common" } tensor = { path = "../tensor" } +rand = "0.8" +half.workspace = true diff --git 
a/causal-lm/src/lib.rs b/causal-lm/src/lib.rs index bf470c6a..5cf857e3 100644 --- a/causal-lm/src/lib.rs +++ b/causal-lm/src/lib.rs @@ -1,10 +1,12 @@ //! 提供因果语言模型的特性定义。 -// #![deny(warnings, missing_docs)] +#![deny(warnings, missing_docs)] mod query_context; +mod sample; pub use query_context::QueryContext; +pub use sample::SampleArgs; use common::{upos, utok}; use tensor::{udim, Tensor}; @@ -36,7 +38,7 @@ pub trait CausalLM { hidden_state: Tensor<Self::Storage>, ) -> Tensor<Self::Storage>; /// 对 logits 进行采样。 - fn sample(&self, logits: Tensor<Self::Storage>) -> Vec<utok>; + fn sample(&self, logits: Tensor<Self::Storage>, args: SampleArgs) -> Vec<utok>; } /// 解码的要求。 @@ -50,7 +52,7 @@ pub struct DecodingMeta { /// 生成位置张量。 #[inline] pub fn pos<'a, S: 'a>( - queries: impl IntoIterator<Item = &'a QueryContext<'a, S>>, + queries: impl IntoIterator<Item = &'a QueryContext<'a, S>>, nt_hint: udim, ) -> Tensor<Vec<upos>> { let mut ans = Vec::with_capacity(nt_hint as usize); diff --git a/causal-lm/src/sample.rs b/causal-lm/src/sample.rs new file mode 100644 index 00000000..b02aa8cf --- /dev/null +++ b/causal-lm/src/sample.rs @@ -0,0 +1,157 @@ +#![allow(missing_docs)] + +use common::utok; +use std::{cmp::Ordering, collections::BinaryHeap, fmt::Debug}; + +#[derive(Clone, PartialEq, Debug)] +pub struct SampleArgs { + pub temperature: f32, + pub top_k: usize, + pub top_p: f32, +} + +impl SampleArgs { + #[inline] + fn is_argmax(&self) -> bool { + self.temperature <= 0. || self.top_k < 2 || self.top_p <= 0.
+ } + + pub fn random<T>(&self, logits: &[T]) -> utok + where + T: BetweenF32 + PartialOrd, + { + if self.is_argmax() { + return logits + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .unwrap() + .0 as _; + } + + #[derive(Clone, Copy, PartialEq, Debug)] + struct Probability { + val: f32, + tok: utok, + } + impl Eq for Probability {} + impl PartialOrd for Probability { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } + } + impl Ord for Probability { + #[inline] + fn cmp(&self, other: &Self) -> Ordering { + match self.val.partial_cmp(&other.val).unwrap() { + Ordering::Equal => self.tok.cmp(&other.tok), + ord => ord.reverse(), + } + } + } + impl<T: BetweenF32> From<(usize, &T)> for Probability { + #[inline] + fn from((i, p): (usize, &T)) -> Self { + Self { + val: p.get(), + tok: i as _, + } + } + } + + // top-k & max + let logits = if self.top_k < logits.len() { + let mut buf = BinaryHeap::with_capacity(self.top_k + 1); + for it in logits.iter().enumerate() { + buf.push(Probability::from(it)); + if buf.len() > self.top_k { + buf.pop(); + } + } + buf.into_sorted_vec() + } else { + let mut buf = logits + .iter() + .enumerate() + .map(Probability::from) + .collect::<Vec<_>>(); + buf.sort_unstable(); + buf + }; + let max = logits[0].val; + // temperature & sum + let (logits, sum) = { + let mut logits = logits; + let mut sum = 0.; + for pi in logits.iter_mut() { + pi.val = ((pi.val - max) / self.temperature).exp(); + sum += pi.val; + } + (logits, sum) + }; + // top p + let logits = if self.top_p < 1. { + let i = logits + .iter() + .scan(self.top_p * sum, |top_p, pi| { + if *top_p > 0. { + *top_p -= pi.val; + Some(()) + } else { + None + } + }) + .count(); + &logits[..i] + } else { + &logits[..] + }; + // random + let mut rand = rand::random::<f32>() * sum; + logits + .iter() + .find(|pi| { + rand -= pi.val; + rand <= 0.
+ }) + .unwrap_or(logits.last().unwrap()) + .tok + } +} + +pub trait BetweenF32 { + fn zero() -> Self; + fn cast(f: f32) -> Self; + fn get(&self) -> f32; +} + +impl BetweenF32 for f32 { + #[inline] + fn zero() -> Self { + 0. + } + #[inline] + fn cast(f: f32) -> Self { + f + } + #[inline] + fn get(&self) -> f32 { + *self + } +} + +impl BetweenF32 for half::f16 { + #[inline] + fn zero() -> Self { + Self::ZERO + } + #[inline] + fn cast(f: f32) -> Self { + Self::from_f32(f) + } + #[inline] + fn get(&self) -> f32 { + Self::to_f32(*self) + } +} diff --git a/transformer-cpu/Cargo.toml b/transformer-cpu/Cargo.toml index 762fd6a9..587bd192 100644 --- a/transformer-cpu/Cargo.toml +++ b/transformer-cpu/Cargo.toml @@ -9,7 +9,9 @@ authors = ["YdrMaster <ydrml@hotmail.com>"] [dependencies] common = { path = "../common" } tensor = { path = "../tensor" } +causal-lm = { path = "../causal-lm" } transformer = { path = "../transformer" } +itertools = "0.12" gemm = "0.17" intel-mkl-src = { version = "0.8", features = ["mkl-dynamic-lp64-iomp"] } diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs index 4dc3a21b..ee8b1733 100644 --- a/transformer-cpu/src/lib.rs +++ b/transformer-cpu/src/lib.rs @@ -1,13 +1,285 @@ mod kernel; -use common::{utok, Blob}; +use causal_lm::{CausalLM, DecodingMeta, QueryContext}; +use common::{upos, utok, Blob}; use gemm::f16; +use itertools::izip; use kernel::CpuKernels; +use std::slice::from_raw_parts; use tensor::{reslice, slice, split, udim, DataType, LocalSplitable, Tensor}; use transformer::{pos, Kernels, LayerBuffer, LayerCache, Llama2, Memory, Request, SampleArgs}; pub struct Transformer(Memory); +impl CausalLM for Transformer { + type Storage = Blob; + + #[inline] + fn eos_token(&self) -> utok { + self.0.eos_token_id() + } + + fn new_cache(&self) -> Tensor<Self::Storage> { + let dt = self.0.data_type(); + let nlayers = self.0.num_hidden_layers() as udim; + let nkvh = self.0.num_key_value_heads() as udim; + let max_seq_len = self.0.max_position_embeddings() as udim; +
let d = self.0.hidden_size() as udim; + let nh = self.0.num_attention_heads() as udim; + + Tensor::alloc(dt, &[nlayers, 2, nkvh, max_seq_len, d / nh], Blob::new) + } + + fn duplicate_cache(&self, cache: &Tensor<Self::Storage>, pos: upos) -> Tensor<Self::Storage> { + let &[_nlayers, 2, _nkvh, max_seq_len, _dh] = cache.shape() else { + panic!() + }; + assert!(pos <= max_seq_len); + let slice = [ + slice![=>], + slice![=>], + slice![=>], + slice![=>pos], + slice![=>], + ]; + + let mut ans = Tensor::alloc(cache.data_type(), cache.shape(), Blob::new); + cache + .as_ref() + .slice(&slice) + .map_physical(|u| &**u) + .reform_to(&mut ans.as_mut().slice(&slice).map_physical(|u| &mut **u)); + ans + } + + fn token_embed(&self, queries: impl IntoIterator<Item = utok>) -> Tensor<Self::Storage> { + let dt = self.0.data_type(); + let d = self.0.hidden_size() as udim; + let kernels = CpuKernels::new(&self.0); + + let tokens = queries.into_iter().collect::<Vec<_>>(); + let nt = tokens.len() as udim; + + let mut x0 = Tensor::alloc(dt, &[nt, d], Blob::new); + kernels.gather(&mut x0, &self.0.embed_tokens(), tokens); + x0 + } + + fn forward<'a>( + &self, + queries: impl IntoIterator<Item = QueryContext<'a, Self::Storage>>, + token_embedded: Tensor<Self::Storage>, + ) -> Tensor<Self::Storage> + where + Self: 'a, + { + let mut queries = queries.into_iter().collect::<Vec<_>>(); + let mut nt = 0; + let mut max_seq_len = 0; + let mut max_att_len = 0; + let seq_len = queries + .iter() + .map(|q| { + let seq = q.seq_len(); + let att = q.att_len(); + nt += seq; + max_seq_len = max_seq_len.max(seq); + max_att_len = max_att_len.max(att); + seq + }) + .collect::<Vec<_>>(); + + let dt = self.0.data_type(); + let d = self.0.hidden_size() as udim; + let nh = self.0.num_attention_heads() as udim; + let nkvh = self.0.num_key_value_heads() as udim; + let dh = d / nh; + let dkv = nkvh * dh; + let di = self.0.intermediate_size() as udim; + let head_group = nh / nkvh; + let head_div = (dh as f32).sqrt().recip(); + let kernels = CpuKernels::new(&self.0); + + let reusing = (d + dkv + dkv).max(di + di); + let mut state_buf = Tensor::alloc(dt, &[nt, d +
reusing], Blob::new); + macro_rules! state { + () => { + split!(state_buf.as_mut().map_physical(|u| LocalSplitable::from(&mut **u)); [1]: d, reusing) + }; + } + + let mut q_buf = Blob::new((nh * max_seq_len * dh) as usize * dt.size()); + let mut att_buf = Blob::new((nh * max_seq_len * max_att_len) as usize * dt.size()); + let pos = causal_lm::pos(&queries, nt); + let pos = pos.as_ref().map_physical(|u| reslice(u)); + + let mut x = token_embedded; + for layer in 0..self.0.num_hidden_layers() { + let (mut x1, qkv) = state!(); + let mut qkv = qkv.slice(&[slice![=>], slice![=> d + dkv + dkv]]); + + let input_layernorm = self.0.input_layernorm(layer); + kernels.rms_norm(&mut x1, &x, &input_layernorm); + + let w_qkv = self.0.w_qkv(layer).transpose(&[1, 0]); + kernels.mat_mul(&mut qkv, 0., &x1, &w_qkv, 1.); + + let (q, k, v) = split!(qkv; [1]: d, dkv, dkv); + let mut q = q.reshape(&[nt, nh, dh]); + let mut k = k.reshape(&[nt, nkvh, dh]); + let v = v.reshape(&[nt, nkvh, dh]); + let o = x1.reshape(&[nt, nh, dh]); + + kernels.rotary_embedding(&mut q, &pos); + kernels.rotary_embedding(&mut k, &pos); + + let q = q.transpose(&[1, 0, 2]).split(1, &seq_len); + let k = k.transpose(&[1, 0, 2]).split(1, &seq_len); + let v = v.transpose(&[1, 0, 2]).split(1, &seq_len); + let o = o.transpose(&[1, 0, 2]).split(1, &seq_len); + + for (query, q, k, v, mut o) in izip!(&mut queries, q, k, v, o) { + let pos = query.pos(); + let seq_len = query.seq_len(); + let att_len = query.att_len(); + let (mut k_cache, mut v_cache) = query.cache(layer); + + let slice_cat = &[slice![=>], slice![pos =>=> seq_len], slice![=>]]; + let slice_att = &[slice![=>], slice![ => att_len], slice![=>]]; + let shape_q0 = &[nkvh * head_group, seq_len, dh]; + let shape_q1 = &[nkvh, head_group * seq_len, dh]; + let shape_att0 = &[nkvh, head_group * seq_len, att_len]; + let shape_att1 = &[nkvh * head_group, seq_len, att_len]; + + let mut q_att = Tensor::new(dt, shape_q0, &mut q_buf[..]); + let mut k_cat = 
k_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u); + let mut v_cat = v_cache.as_mut().slice(slice_cat).map_physical(|u| &mut **u); + kernels.reform(&mut q_att, &q); + kernels.reform(&mut k_cat, &k); + kernels.reform(&mut v_cat, &v); + + let q_att = q_att.reshape(shape_q1); + let k_att = k_cache.slice(slice_att).transpose(&[0, 2, 1]); + let v_att = v_cache.slice(slice_att); + + let mut att = Tensor::new(dt, shape_att0, &mut att_buf[..]); + kernels.mat_mul(&mut att, 0., &q_att, &k_att, head_div); + let mut att = att.reshape(shape_att1); + kernels.softmax(&mut att); + let mut x2 = q_att; + kernels.mat_mul(&mut x2, 0., &att.reshape(shape_att0), &v_att, 1.); + + kernels.reform(&mut o, &x2.reshape(&[nh, seq_len, dh])); + } + + let (mut x1, gate_up) = state!(); + let mut gate_up = gate_up.slice(&[slice![=>], slice![=> di + di]]); + + let wo = self.0.self_attn_o_proj(layer).transpose(&[1, 0]); + kernels.mat_mul(&mut x, 1., &x1, &wo, 1.); + + let post_layernorm = self.0.post_attention_layernorm(layer); + kernels.rms_norm(&mut x1, &x, &post_layernorm); + + let w_gate_up = self.0.mlp_gate_up(layer).transpose(&[1, 0]); + kernels.mat_mul(&mut gate_up, 0., &x1, &w_gate_up, 1.); + + let (mut gate, up) = split!(gate_up; [1]: di, di); + kernels.swiglu(&mut gate, &up); + + let mlp_down = self.0.mlp_down(layer).transpose(&[1, 0]); + kernels.mat_mul(&mut x, 1., &gate, &mlp_down, 1.); + } + + x + } + + fn decode( + &self, + decoding: impl IntoIterator<Item = DecodingMeta>, + mut hidden_state: Tensor<Self::Storage>, + ) -> Tensor<Self::Storage> { + let dt = self.0.data_type(); + let d = self.0.hidden_size(); + let voc = self.0.vocab_size() as udim; + let kernels = CpuKernels::new(&self.0); + + let buf = hidden_state.as_mut_slice(); + let len = d * dt.size(); + + let mut iter = decoding.into_iter(); + let mut begin = 0; + let mut src = 0; + let mut dst = 0; + for DecodingMeta { + num_query, + num_decode, + } in iter.by_ref() + { + begin += num_query; + if num_decode > 0 { + src = begin; + dst = begin; + begin -= num_decode; +
break; + } + } + for DecodingMeta { + num_query, + num_decode, + } in iter + { + src += num_query - num_decode; + if src > dst { + for _ in 0..num_decode { + buf.copy_within(src * len..(src + 1) * len, dst * len); + src += 1; + dst += 1; + } + } else { + src += num_decode; + dst += num_decode; + } + } + + if dst == begin { + return Tensor::alloc(dt, &[0, d as _], Blob::new); + } + + let mut x = hidden_state.slice(&[slice![begin => dst], slice![=>]]); + let mut logits = Tensor::alloc(dt, &[x.shape()[0], voc], Blob::new); + + // 复制一个 x 以实现原地归一化 + let x_ = x + .as_ref() + .map_physical(|u| unsafe { from_raw_parts(u.as_ptr(), u.len()) }); + kernels.rms_norm(&mut x, &x_, &self.0.model_norm()); + + let lm_head = self.0.lm_head().transpose(&[1, 0]); + kernels.mat_mul(&mut logits, 0., &x, &lm_head, 1.); + + logits + } + + fn sample(&self, logits: Tensor<Self::Storage>, args: causal_lm::SampleArgs) -> Vec<utok> { + let &[nt, voc] = logits.shape() else { panic!() }; + let dt = logits.data_type(); + + macro_rules! sample { + ($ty:ty) => {{ + let logits: &[$ty] = reslice(logits.as_slice()); + (0..nt).map(|i| args.random(&kernel::slice!(logits; voc; [i]))).collect() + }}; + } + + match dt { + DataType::F16 => sample!(f16), + DataType::F32 => sample!(f32), + _ => unreachable!(), + } + } +} + impl transformer::Transformer for Transformer { type Cache = Blob; @@ -283,7 +555,7 @@ impl Transformer { // 复制一个 x 以实现原地归一化 let x_ = unsafe { x.as_ref() - .map_physical(|u| std::slice::from_raw_parts(u.as_ptr(), u.len())) + .map_physical(|u| from_raw_parts(u.as_ptr(), u.len())) }; kernels.rms_norm(&mut x, &x_, &self.0.model_norm());