From d066026d616fb7aa973b3cca010f75f9df6c8352 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Mon, 18 Mar 2024 12:04:38 +0800 Subject: [PATCH] =?UTF-8?q?refactor(transformer):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E9=9A=8F=E6=9C=BA=E9=87=87=E6=A0=B7=E7=9A=84=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- Cargo.lock | 40 ++++++++++++++++++------------- Cargo.toml | 1 + service/src/cpu.rs | 21 ++++++---------- service/src/lib.rs | 9 ------- service/src/nvidia.rs | 37 +++++++++++++--------------- tensor/Cargo.toml | 4 ++-- transformer-cpu/src/kernel/mod.rs | 4 ++-- transformer-cpu/src/lib.rs | 38 ++++++++++++++++++++++++++--- transformer-nvidia/src/lib.rs | 39 ++++++++++++++++++++++++++---- transformer/Cargo.toml | 4 ++-- transformer/src/lib.rs | 2 ++ transformer/src/sample.rs | 9 +++++++ 12 files changed, 135 insertions(+), 73 deletions(-) create mode 100644 transformer/src/sample.rs diff --git a/Cargo.lock b/Cargo.lock index 9511c70d..cf5a3951 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -93,7 +93,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.52", + "syn 2.0.53", "which", ] @@ -126,7 +126,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -163,9 +163,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.2" +version = "4.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b230ab84b0ffdf890d5a10abdbc8b83ae1c4918275daea1ab8801f71536b2651" +checksum = "949626d00e063efc93b6dca932419ceb5432f99769911c0b995f7e884c778813" dependencies = [ "clap_builder", "clap_derive", @@ -185,14 +185,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.0" +version = "4.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47" +checksum = "90239a040c80f5e14809ca132ddc4176ab33d5e17e49691793296e3fcb34d72f" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -302,10 +302,10 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -469,6 +469,12 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "home" version = "0.5.9" @@ -696,7 +702,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -710,9 +716,9 @@ dependencies = [ [[package]] name = "pulp" -version = "0.18.8" +version = "0.18.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "091bad01115892393939669b38f88ff2b70838e969a7ac172a9d06d05345a732" +checksum = "03457ac216146f43f921500bac4e892d5cd32b0479b929cbfc90f95cd6c599c2" dependencies = [ "bytemuck", "libm", @@ -875,7 +881,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -959,9 +965,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.52" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -1010,7 +1016,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 92f6c5d0..69690b86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ resolver = "2" [workspace.dependencies] find_cuda_helper = "0.2" half = "2.4" +rayon = "1.9" serde_json = "1.0" serde = "1.0" log = "0.4" diff --git a/service/src/cpu.rs b/service/src/cpu.rs index 03cce940..7a2e8a9b 100644 --- a/service/src/cpu.rs +++ b/service/src/cpu.rs @@ -1,9 +1,7 @@ -use crate::{argmax, Command}; +use crate::Command; use common::utok; -use half::f16; use std::{collections::HashMap, path::Path, time::Instant}; -use tensor::reslice; -use transformer_cpu::{LayerCache, Memory, Request, Transformer}; +use transformer_cpu::{LayerCache, Memory, Request, SampleArgs, Transformer}; pub struct CpuTask { transformer: Transformer, @@ -43,22 +41,17 @@ impl CpuTask { let eos = self.transformer.eos_token_id(); let time = Instant::now(); - let mut logits = self + let mut token = self .transformer - .decode(vec![ctx.request(&prompt, max_seq_len)]) + .decode(vec![ctx.request(&prompt, max_seq_len)], SampleArgs::Top)[0] .1; info!("prefill transformer ... {:?}", time.elapsed()); - loop { - let token = argmax(reslice::(logits.as_slice())); - if token == eos { - break; - } + while token != eos { responsing.send(token).unwrap(); - - logits = self + token = self .transformer - .decode(vec![ctx.request(&[token], max_seq_len)]) + .decode(vec![ctx.request(&[token], max_seq_len)], SampleArgs::Top)[0] .1; } } diff --git a/service/src/lib.rs b/service/src/lib.rs index 11edcc45..c0624cfb 100644 --- a/service/src/lib.rs +++ b/service/src/lib.rs @@ -165,12 +165,3 @@ impl SessionContext { } } } - -fn argmax(logits: &[T]) -> utok { - logits - .iter() - .enumerate() - .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) - .unwrap() - .0 as _ -} diff --git a/service/src/nvidia.rs b/service/src/nvidia.rs index bbea2215..111b087c 100644 --- a/service/src/nvidia.rs +++ b/service/src/nvidia.rs @@ -1,11 +1,9 @@ -use crate::{argmax, Command}; +use crate::Command; use common::utok; -use half::f16; use std::{ collections::HashMap, fs::File, io::Read, path::Path, sync::mpsc::Receiver, time::Instant, }; -use tensor::reslice; -use transformer_cpu::{Llama2, Memory}; +use transformer_cpu::{Llama2, Memory, SampleArgs}; use transformer_nvidia::{ cuda::{ContextGuard, Stream}, LayerCache, Request, Transformer, @@ -49,25 +47,24 @@ pub fn task(model_dir: impl AsRef, receiver: Receiver, ctx: &Cont .or_insert_with_key(|&id| SessionContext::new(&transformer, id, &transfer)); let time = Instant::now(); - let mut logits = transformer - .decode(vec![ctx.request(&prompt, max_seq_len)], &compute, &transfer) - .1; + let mut token = transformer.decode( + vec![ctx.request(&prompt, max_seq_len)], + SampleArgs::Top, + &compute, + &transfer, + )[0] + .1; info!("prefill transformer ... {:?}", time.elapsed()); - loop { - let token = argmax(reslice::(logits.as_slice())); - if token == eos { - break; - } + while token != eos { responsing.send(token).unwrap(); - - logits = transformer - .decode( - vec![ctx.request(&[token], max_seq_len)], - &compute, - &transfer, - ) - .1; + token = transformer.decode( + vec![ctx.request(&[token], max_seq_len)], + SampleArgs::Top, + &compute, + &transfer, + )[0] + .1; } } Command::Drop { id } => { diff --git a/tensor/Cargo.toml b/tensor/Cargo.toml index 301f3db5..3861ce5e 100644 --- a/tensor/Cargo.toml +++ b/tensor/Cargo.toml @@ -7,8 +7,8 @@ authors = ["YdrMaster "] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -half.workspace = true smallvec = "1.13" nalgebra = "0.32" -rayon = "1.9" +half.workspace = true +rayon.workspace = true serde.workspace = true diff --git a/transformer-cpu/src/kernel/mod.rs b/transformer-cpu/src/kernel/mod.rs index 92eb7f2b..37d16d55 100644 --- a/transformer-cpu/src/kernel/mod.rs +++ b/transformer-cpu/src/kernel/mod.rs @@ -14,8 +14,8 @@ pub(super) use swiglu::swiglu; macro_rules! slice { ($blob:expr; $width:expr; [$line:expr]) => { - $blob[$line as usize * $width..][..$width] + $blob[$line as usize * $width as usize..][..$width as usize] }; } -use slice; +pub(super) use slice; diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs index 8adf0fc7..c32d1f95 100644 --- a/transformer-cpu/src/lib.rs +++ b/transformer-cpu/src/lib.rs @@ -2,13 +2,14 @@ mod kernel; mod storage; use common::utok; +use gemm::f16; use kernel::{gather, mat_mul, rms_norm, rotary_embedding, softmax, swiglu}; use storage::Storage; use tensor::{reslice, slice, udim, DataType, Tensor}; pub type Request<'a, Id> = transformer::Request<'a, Id, Storage>; pub type LayerCache = transformer::LayerCache; -pub use transformer::{save, Llama2, Memory}; +pub use transformer::{save, Llama2, Memory, SampleArgs}; pub struct Transformer(Box); @@ -33,7 +34,11 @@ impl Transformer { self.0.eos_token_id() } - pub fn decode(&mut self, mut requests: Vec>) -> (Vec, Tensor>) { + pub fn decode( + &mut self, + mut requests: Vec>, + sample: SampleArgs, + ) -> Vec<(Id, utok)> { requests.sort_unstable_by_key(|t| t.tokens.len()); // println!("tokens:"); @@ -226,7 +231,25 @@ impl Transformer { mat_mul(&mut logits, 0., &x, &lm_head, 1.); // println!("logits:\n{}", logits.access()); - (requests.into_iter().map(|r| r.id).collect(), logits) + let logits: &[f16] = reslice(logits.as_slice()); + requests + .into_iter() + .enumerate() + .map(|(i, r)| { + let logits = &kernel::slice!(logits; voc; [i]); + ( + r.id, + match sample { + SampleArgs::Top => argmax(logits), + SampleArgs::Random { + temperature: _, + top_k: _, + top_p: _, + } => todo!(), + }, + ) + }) + .collect() } } @@ -257,3 +280,12 @@ fn test_build() { let t1 = Instant::now(); println!("build transformer {:?}", t1 - t0); } + +fn argmax(logits: &[T]) -> utok { + logits + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .unwrap() + .0 as _ +} diff --git a/transformer-nvidia/src/lib.rs b/transformer-nvidia/src/lib.rs index ef349678..9dd6e53e 100644 --- a/transformer-nvidia/src/lib.rs +++ b/transformer-nvidia/src/lib.rs @@ -7,15 +7,18 @@ mod storage; #[macro_use] extern crate log; +use ::half::f16; +use common::utok; use cublas::Cublas; use cuda::{AsRaw, CudaDataType::half, Stream}; use kernel::{gather, mat_mul, FusedSoftmax, Reform, RmsNormalization, RotaryEmbedding, Swiglu}; use parameters::{LayersParameters, ModelParameters}; use storage::Storage; -use tensor::{slice, udim, DataType, Tensor}; +use tensor::{reslice, slice, udim, DataType, Tensor}; pub type Request<'a, 'b, Id> = transformer::Request<'a, Id, Storage<'b>>; pub type LayerCache<'a> = transformer::LayerCache>; +use transformer::SampleArgs; pub use transformer::{Llama2, Memory}; pub extern crate cuda; @@ -57,10 +60,11 @@ impl<'ctx> Transformer<'ctx> { pub fn decode( &mut self, - mut requests: Vec>, + mut requests: Vec>, + sample: SampleArgs, compute: &Stream<'ctx>, transfer: &Stream<'ctx>, - ) -> (Vec, Tensor>) { + ) -> Vec<(Id, utok)> { requests.sort_unstable_by_key(|t| t.tokens.len()); // println!("tokens:"); @@ -286,7 +290,25 @@ impl<'ctx> Transformer<'ctx> { logits_dev.physical().copy_out(logits.physical_mut()); // println!("logits:\n{}", logits); - (requests.into_iter().map(|r| r.id).collect(), logits) + let logits: &[f16] = reslice(logits.as_slice()); + requests + .into_iter() + .enumerate() + .map(|(i, r)| { + let logits = &logits[i * voc as usize..][..voc as usize]; + ( + r.id, + match sample { + SampleArgs::Top => argmax(logits), + SampleArgs::Random { + temperature: _, + top_k: _, + top_p: _, + } => todo!(), + }, + ) + }) + .collect() } } @@ -309,3 +331,12 @@ fn map_tensor(tensor: &Tensor) -> Tensor> { }) } } + +fn argmax(logits: &[T]) -> utok { + logits + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .unwrap() + .0 as _ +} diff --git a/transformer/Cargo.toml b/transformer/Cargo.toml index 58487778..7a9ae32d 100644 --- a/transformer/Cargo.toml +++ b/transformer/Cargo.toml @@ -9,9 +9,9 @@ authors = ["YdrMaster "] [dependencies] common = { path = "../common" } tensor = { path = "../tensor" } -half.workspace = true -rayon = "1.9" memmap2 = "0.9" safetensors = "0.4" +half.workspace = true +rayon.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true diff --git a/transformer/src/lib.rs b/transformer/src/lib.rs index 87b0853f..7e25f15f 100644 --- a/transformer/src/lib.rs +++ b/transformer/src/lib.rs @@ -6,8 +6,10 @@ mod cache; mod host_memory; mod parameters; mod request; +mod sample; pub use cache::LayerCache; pub use host_memory::HostMemory; pub use parameters::{save, Llama2, Memory, SafeTensorError}; pub use request::Request; +pub use sample::SampleArgs; diff --git a/transformer/src/sample.rs b/transformer/src/sample.rs new file mode 100644 index 00000000..0de8d8f2 --- /dev/null +++ b/transformer/src/sample.rs @@ -0,0 +1,9 @@ +#[derive(Clone, PartialEq, Debug)] +pub enum SampleArgs { + Top, + Random { + temperature: f32, + top_k: usize, + top_p: f32, + }, +}