From 862e797f15e6da00fb61af59e27aaf482c84ccb6 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Thu, 4 Jul 2024 19:35:01 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E6=88=90=E5=AF=92=E6=AD=A6?= =?UTF-8?q?=E7=BA=AA=E6=A8=A1=E5=9E=8B=E6=90=AD=E5=BB=BA=EF=BC=8C=E9=80=82?= =?UTF-8?q?=E9=85=8D=20c=20=E7=AE=97=E5=AD=90=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 94 ++++--- Cargo.toml | 3 +- devices/cambricon-mlu/Cargo.toml | 4 +- devices/cambricon-mlu/src/gather.rs | 28 +++ devices/cambricon-mlu/src/lib.rs | 184 +++++++------- devices/cambricon-mlu/src/sample.rs | 19 ++ models/llama/cambricon-mlu/Cargo.toml | 1 + models/llama/cambricon-mlu/src/lib.rs | 350 ++++++++++++++++++++++++-- 8 files changed, 532 insertions(+), 151 deletions(-) create mode 100644 devices/cambricon-mlu/src/gather.rs create mode 100644 devices/cambricon-mlu/src/sample.rs diff --git a/Cargo.lock b/Cargo.lock index b945a621..ea9b5691 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -281,13 +281,24 @@ checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" [[package]] name = "cndrv" version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e214fcc4c0f5219ea135f371289c24614a39dc8bd300ef7a1b13093bdca5f28a" +source = "git+https://github.com/InfiniTensor/cndrv?rev=d354eee#d354eee4f37bc64ac3d56358bb6263316982d29e" dependencies = [ "bindgen", "build-script-cfg", "log", - "search-neuware-tools", + "search-neuware-tools 0.0.0 (git+https://github.com/InfiniTensor/cndrv?rev=d354eee)", +] + +[[package]] +name = "cnnl" +version = "0.0.0" +source = "git+https://github.com/InfiniTensor/cndrv?rev=d354eee#d354eee4f37bc64ac3d56358bb6263316982d29e" +dependencies = [ + "bindgen", + "build-script-cfg", + "cndrv", + "digit-layout", + "search-neuware-tools 0.0.0 (git+https://github.com/InfiniTensor/cndrv?rev=d354eee)", ] [[package]] @@ -320,7 +331,7 @@ dependencies = [ [[package]] name = "common" version = "0.1.0" -source = "git+https://github.com/YdrMaster/operators-rs?rev=189c2d5#189c2d5cfa525a751ef9d80499a90789412072ae" +source = "git+https://github.com/YdrMaster/operators-rs?rev=f5d2b6b#f5d2b6b0a04cd6c26eefb1e579734dded663d06b" dependencies = [ "digit-layout", ] @@ -331,11 +342,12 @@ version = "0.0.0" dependencies = [ "bindgen", "build-script-cfg", - "cndrv", "common 0.0.0", "common-devices", "digit-layout", - "search-neuware-tools", + "operators", + "sample", + "search-neuware-tools 0.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "tensor", ] @@ -871,7 +883,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" dependencies = [ "cfg-if", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -910,8 +922,9 @@ dependencies = [ "causal-lm", "common 0.0.0", "common-cn", + "digit-layout", "llama", - "search-neuware-tools", + "search-neuware-tools 0.0.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1177,9 +1190,11 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "operators" version = "0.0.0" -source = "git+https://github.com/YdrMaster/operators-rs?rev=189c2d5#189c2d5cfa525a751ef9d80499a90789412072ae" +source = "git+https://github.com/YdrMaster/operators-rs?rev=f5d2b6b#f5d2b6b0a04cd6c26eefb1e579734dded663d06b" dependencies = [ "build-script-cfg", + "cndrv", + "cnnl", "common 0.1.0", "cublas", "cuda", @@ -1188,7 +1203,7 @@ dependencies = [ "half", "log", "search-cuda-tools", - "search-neuware-tools", + "search-neuware-tools 0.0.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -1451,6 +1466,11 @@ version = "0.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c83ae237201de85c7ece8ad034bc749a6f229f42f07f62a30008b30a3f1e2231" +[[package]] +name = "search-neuware-tools" +version = "0.0.0" +source = "git+https://github.com/InfiniTensor/cndrv?rev=d354eee#d354eee4f37bc64ac3d56358bb6263316982d29e" + [[package]] name = "seq-macro" version = "0.3.5" @@ -1783,7 +1803,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1803,18 +1823,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -1825,9 +1845,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -1837,9 +1857,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -1849,15 +1869,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -1867,9 +1887,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -1879,9 +1899,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -1891,9 +1911,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -1903,9 +1923,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "xtask" @@ -1926,7 +1946,7 @@ dependencies = [ "mixtral", "mixtral-cpu", "search-cuda-tools", - "search-neuware-tools", + "search-neuware-tools 0.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "service", "simple_logger", "tensor", diff --git a/Cargo.toml b/Cargo.toml index bc56e30b..ea75446d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,9 +38,8 @@ tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] } digit-layout = "0.0" build-script-cfg = "0.0" -operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "189c2d5", default-features = false } +operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "f5d2b6b", default-features = false } nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" } search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "877df52" } -cndrv = "0.1" search-neuware-tools = "0.0" diff --git a/devices/cambricon-mlu/Cargo.toml b/devices/cambricon-mlu/Cargo.toml index c8a85ab0..19b17bf4 100644 --- a/devices/cambricon-mlu/Cargo.toml +++ b/devices/cambricon-mlu/Cargo.toml @@ -10,9 +10,9 @@ authors = ["YdrMaster "] common = { path = "../../common" } common-devices = { path = "../common" } tensor = { path = "../../tensor" } +sample = { path = "../../sample" } digit-layout="0.0" -cndrv.workspace = true -# operators = { workspace = true, features = ["cambricon-mlu"] } +operators = { workspace = true, features = ["cambricon-mlu"] } [build-dependencies] build-script-cfg.workspace = true diff --git a/devices/cambricon-mlu/src/gather.rs b/devices/cambricon-mlu/src/gather.rs new file mode 100644 index 00000000..3a0b04ad --- /dev/null +++ b/devices/cambricon-mlu/src/gather.rs @@ -0,0 +1,28 @@ +use common::utok; +use operators::cndrv::{DevByte, Queue}; +use std::ops::{Deref, DerefMut}; +use tensor::Tensor; + +pub fn gather(x: &mut Tensor, table: &Tensor, tokens: I, queue: &Queue) +where + T: DerefMut, + U: Deref, + I: IntoIterator, +{ + let &[_, d] = x.shape() else { panic!() }; + + debug_assert_eq!(x.data_layout(), table.data_layout()); + debug_assert_eq!(table.shape().len(), 2); + debug_assert_eq!(table.shape()[1], d); + debug_assert!(x.is_contiguous()); + debug_assert!(table.is_contiguous()); + let d = d as usize * x.data_layout().nbytes(); + + let x = &mut **x.physical_mut(); + let table = table.as_slice(); + for (i, t) in tokens.into_iter().enumerate() { + let dst = &mut x[d * i..][..d]; + let src = &table[d * t as usize..][..d]; + queue.memcpy_h2d(dst, src); + } +} diff --git a/devices/cambricon-mlu/src/lib.rs b/devices/cambricon-mlu/src/lib.rs index 1d47c1b0..312e83c1 100644 --- a/devices/cambricon-mlu/src/lib.rs +++ b/devices/cambricon-mlu/src/lib.rs @@ -1,18 +1,23 @@ #![cfg(detected_neuware)] // Include the bindings include!(concat!(env!("OUT_DIR"), "/bindings.rs")); -pub extern crate cndrv; -use cndrv::{ContextGuard, ContextSpore}; +mod gather; +mod sample; + +use cndrv::{ContextSpore, CurrentCtx, DevByte}; use common::utok; +pub use operators::{cndrv, cambricon_mlu::Handle as Mlu}; +pub use sample::sample_cpu; // pub type CTensor = Tensor; // use tensor::Tensor; use digit_layout::DigitLayout; +use operators::{cndrv::AsRaw, QueueOf}; use std::ops::{Deref, DerefMut}; -pub use tensor::Tensor as rustTensor; -pub use common_devices::{Kernels, SliceOn}; +pub use tensor::{Tensor as rustTensor, slice}; +pub use common_devices::{Kernels, KernelsA, KernelsB, SliceOn}; impl DataLayout { pub fn new(packed: u16, sign: u16, size: u16, mantissa: u16, exponent: u16) -> Self { @@ -38,7 +43,7 @@ impl From for DataLayout { fn to_ctensor(tensor: &rustTensor) -> Tensor where - T: Deref, + T: Deref, { // 获取 strides let strides_vec: Vec = tensor.strides().iter().map(|&x| x as i64).collect(); @@ -68,23 +73,36 @@ where } } -pub struct CambriconKernels; +pub struct CambriconKernels { + mat_mul: *mut MatmulDescriptor, + rms_norm: *mut RMSNormDescriptor, + rope: *mut RotaryEmbeddingDescriptor, + reform: *mut ReformDescriptor, + softmax: *mut CausalSoftmaxDescriptor, + swiglu: *mut SwigluDescriptor, +} -impl CambriconKernels { - fn gather( - &self, - x: &mut rustTensor, - table: &rustTensor, - tokens: I, - stream: *mut ::std::os::raw::c_void, - ) where - T: DerefMut, - U: Deref, - I: IntoIterator, - { - todo!() +impl CambriconKernels { + pub fn new(device: DeviceEnum) -> Self { + let config: *mut std::ffi::c_void = std::ptr::null_mut(); + unsafe { + Self { + mat_mul: createMatmulDescriptor(device, config) as *mut MatmulDescriptor, + rms_norm: createRMSNormDescriptor(device, config) as *mut RMSNormDescriptor, + rope: createRotaryEmbeddingDescriptor(device, config) as *mut RotaryEmbeddingDescriptor, + reform: createReformDescriptor(device, config) as *mut ReformDescriptor, + softmax: createCausalSoftmaxDescriptor(device, config) as *mut CausalSoftmaxDescriptor, + swiglu: createSwigluDescriptor(device, config) as *mut SwigluDescriptor, + } + } } +} + +impl Kernels for CambriconKernels {} + +impl KernelsA for CambriconKernels { + type Handle = Mlu; fn rms_norm( &self, @@ -92,25 +110,20 @@ impl CambriconKernels { x: &rustTensor, w: &rustTensor, epsilon: f32, - stream: *mut ::std::os::raw::c_void, + queue: &QueueOf, ) where - T: DerefMut, - U: Deref, - V: Deref, + T: DerefMut>, + U: Deref>, + V: Deref> { - let device = DeviceEnum::DevCpu; - let config: *mut std::ffi::c_void = std::ptr::null_mut(); - unsafe { - let descriptor = createRMSNormDescriptor(device, config) as *mut RMSNormDescriptor; - + unsafe { let y = to_ctensor(y); let x = to_ctensor(x); let w = to_ctensor(w); - rmsNorm(descriptor, y, x, w, epsilon, stream); + rmsNorm(self.rms_norm, y, x, w, epsilon, queue.as_raw() as *mut ::std::os::raw::c_void); // Destroy the SwigluDescriptor - destroyRMSNormDescriptor(descriptor); destroyTensorDescriptor(x.layout); destroyTensorDescriptor(y.layout); destroyTensorDescriptor(w.layout); @@ -122,23 +135,18 @@ impl CambriconKernels { t: &mut rustTensor, pos: &rustTensor, theta: f32, - stream: *mut ::std::os::raw::c_void, + queue: &QueueOf, ) where - T: DerefMut, - U: Deref, + T: DerefMut>, + U: Deref>, { - let device = DeviceEnum::DevCpu; - let config: *mut std::ffi::c_void = std::ptr::null_mut(); unsafe { - let descriptor = createRotaryEmbeddingDescriptor(device, config) as *mut RotaryEmbeddingDescriptor; - let t = to_ctensor(t); let pos = to_ctensor(pos); - rotaryEmbedding(descriptor, t, pos, theta, stream); + rotaryEmbedding(self.rope, t, pos, theta, queue.as_raw() as *mut ::std::os::raw::c_void); // Destroy the SwigluDescriptor - destroyRotaryEmbeddingDescriptor(descriptor); destroyTensorDescriptor(t.layout); destroyTensorDescriptor(pos.layout); } @@ -151,93 +159,95 @@ impl CambriconKernels { a: &rustTensor, b: &rustTensor, alpha: f32, - stream: *mut ::std::os::raw::c_void, + queue: &QueueOf, ) where - T: DerefMut, - U: Deref, - V: Deref, + T: DerefMut>, + U: Deref>, + V: Deref>, { - let device = DeviceEnum::DevCpu; - let config: *mut std::ffi::c_void = std::ptr::null_mut(); unsafe { - let descriptor = createMatmulDescriptor(device, config) as *mut MatmulDescriptor; - let c = to_ctensor(c); let a = to_ctensor(a); let b = to_ctensor(b); - matmul(descriptor, c, beta, a, b, alpha, stream); + matmul(self.mat_mul, c, beta, a, b, alpha, queue.as_raw() as *mut ::std::os::raw::c_void); // Destroy the SwigluDescriptor - destroyMatmulDescriptor(descriptor); destroyTensorDescriptor(c.layout); destroyTensorDescriptor(a.layout); destroyTensorDescriptor(b.layout); } } - fn reform(&self, dst: &mut rustTensor, src: &rustTensor, stream: *mut ::std::os::raw::c_void) + fn softmax(&self, att: &mut rustTensor, queue: &QueueOf) where - T: DerefMut, - U: Deref, + T: DerefMut>, { - let device = DeviceEnum::DevCpu; - let config: *mut std::ffi::c_void = std::ptr::null_mut(); unsafe { - let descriptor = createReformDescriptor(device, config) as *mut ReformDescriptor; - - let dst = to_ctensor(dst); - let src = to_ctensor(src); + let att = to_ctensor(att); - reform(descriptor, dst, src, stream); + causalSoftmax(self.softmax, att, queue.as_raw() as *mut ::std::os::raw::c_void); // Destroy the SwigluDescriptor - destroyReformDescriptor(descriptor); - destroyTensorDescriptor(dst.layout); - destroyTensorDescriptor(src.layout); + destroyTensorDescriptor(att.layout); } } - fn softmax(&self, att: &mut rustTensor, stream: *mut ::std::os::raw::c_void) + fn swiglu(&self, gate: &mut rustTensor, up: &rustTensor, queue: &QueueOf) where - T: DerefMut, + T: DerefMut>, + U: Deref>, { - let device = DeviceEnum::DevCpu; - let config: *mut std::ffi::c_void = std::ptr::null_mut(); unsafe { - let descriptor = createCausalSoftmaxDescriptor(device, config) as *mut CausalSoftmaxDescriptor; - - let att = to_ctensor(att); + let gate = to_ctensor(gate); + let up = to_ctensor(up); - causalSoftmax(descriptor, att, stream); + swiglu(self.swiglu, gate, up, queue.as_raw() as *mut ::std::os::raw::c_void); // Destroy the SwigluDescriptor - destroyCausalSoftmaxDescriptor(descriptor); - destroyTensorDescriptor(att.layout); + destroyTensorDescriptor(gate.layout); + destroyTensorDescriptor(up.layout); } + } + +} + +impl KernelsB for CambriconKernels { + type Handle = Mlu; + + fn gather( + &self, + x: &mut rustTensor, + table: &rustTensor, + tokens: I, + queue: &QueueOf, + ) where + T: DerefMut>, + U: Deref, + I: IntoIterator, + { + gather::gather(x, table, tokens, queue); } - fn swiglu(&self, gate: &mut rustTensor, up: &rustTensor, stream: *mut ::std::os::raw::c_void) + + fn reform(&self, dst: &mut rustTensor, src: &rustTensor, queue: &QueueOf) where - T: DerefMut, - U: Deref, + T: DerefMut>, + U: Deref>, { - let device = DeviceEnum::DevCpu; - let config: *mut std::ffi::c_void = std::ptr::null_mut(); - unsafe { - let descriptor = createSwigluDescriptor(device, config) as *mut SwigluDescriptor; - - let gate = to_ctensor(gate); - let up = to_ctensor(up); + unsafe { + let dst = to_ctensor(dst); + let src = to_ctensor(src); - swiglu(descriptor, gate, up, stream); + reform(self.reform, dst, src, queue.as_raw() as *mut ::std::os::raw::c_void); // Destroy the SwigluDescriptor - destroySwigluDescriptor(descriptor); - destroyTensorDescriptor(gate.layout); - destroyTensorDescriptor(up.layout); + destroyTensorDescriptor(dst.layout); + destroyTensorDescriptor(src.layout); } } + + } pub struct DropOption(Option); @@ -265,7 +275,7 @@ impl AsMut for DropOption { impl DropOption { #[inline] - pub fn sprout<'ctx>(&mut self, ctx: &'ctx ContextGuard) -> ::Resource<'ctx> { + pub fn sprout<'ctx>(&mut self, ctx: &'ctx CurrentCtx) -> ::Resource<'ctx> { self.0.take().unwrap().sprout(ctx) } } diff --git a/devices/cambricon-mlu/src/sample.rs b/devices/cambricon-mlu/src/sample.rs new file mode 100644 index 00000000..ad7c109b --- /dev/null +++ b/devices/cambricon-mlu/src/sample.rs @@ -0,0 +1,19 @@ +use common::{f16, utok, Blob}; +use sample::SampleArgs; +use operators::cndrv::{DevByte, Queue, memcpy_d2h}; +use tensor::reslice; + +pub fn sample_cpu( + args: impl IntoIterator, + logits: &[DevByte], + voc: usize, + _queue: &Queue, +) -> Vec { + let mut host = Blob::new(logits.len()); + memcpy_d2h(&mut host, logits); + + let logits: &[f16] = reslice(&host); + args.into_iter() + .map(|(i, arg)| arg.random(&logits[voc * i..][..voc])) + .collect() +} diff --git a/models/llama/cambricon-mlu/Cargo.toml b/models/llama/cambricon-mlu/Cargo.toml index fde45d2e..f101de1f 100644 --- a/models/llama/cambricon-mlu/Cargo.toml +++ b/models/llama/cambricon-mlu/Cargo.toml @@ -9,6 +9,7 @@ common = { path = "../../../common" } common-cn = { path = "../../../devices/cambricon-mlu" } causal-lm = { path = "../../../causal-lm" } llama = { path = "../common" } +digit-layout.workspace = true [build-dependencies] build-script-cfg.workspace = true diff --git a/models/llama/cambricon-mlu/src/lib.rs b/models/llama/cambricon-mlu/src/lib.rs index 73d2eefd..712b2eaf 100644 --- a/models/llama/cambricon-mlu/src/lib.rs +++ b/models/llama/cambricon-mlu/src/lib.rs @@ -3,71 +3,375 @@ mod resource; use causal_lm::{CausalLM, DecodingMeta, Model, QueryContext, SampleMeta}; -use common::{upos, utok, FileLoadError}; -use common_cn::rustTensor as Tensor; -use std::path::Path; +use common::{upos, utok, Blob, FileLoadError}; +use common_cn::{sample_cpu, slice, KernelsA, KernelsB, cndrv::{memcpy_d2h, Context, ContextResource, ContextSpore, CurrentCtx, DevByte, DevMem, DevMemSpore, Device, HostMemSpore, Queue, QueueSpore}, rustTensor as Tensor, CambriconKernels, DeviceEnum, Kernels, Mlu}; +use std::{iter::repeat, mem::ManuallyDrop, ops::Deref, path::Path, slice::from_raw_parts, sync::Arc}; +use llama::{ComputeConst, InferenceConfig, LayerStorage, SliceOn, Weight}; +use digit_layout::types::F16; pub use common_cn::{cndrv, synchronize}; pub use resource::Cache; -pub struct Transformer; +pub struct Transformer(ManuallyDrop); +pub struct Internal { + config: InferenceConfig, + + resource: Arc, + compute: QueueSpore, + kernels: CambriconKernels, + + embed_tokens: Tensor, + layers: Vec>, + lm_layernorm: Tensor, + lm_head: Tensor, +} impl Model for Transformer { - type Meta = (); + type Meta = Device; type Error = FileLoadError; - fn load(_model_dir: impl AsRef, _meta: Self::Meta) -> Result { - todo!() + #[inline] + fn load(model_dir: impl AsRef, meta: Self::Meta ) -> Result { + // let time = Instant::now(); + let host = llama::Storage::load_safetensors(model_dir)?; + // info!("load host: {:?}", time.elapsed()); + let resource = Arc::new(meta.acquire_shared()); + resource.apply(|ctx| { + let page_lock = |u: &Weight| { + let mut host = ctx.malloc_host::(u.len()); + host.copy_from_slice(u); + host.sporulate() + }; + + Ok(Self (ManuallyDrop::new(Internal { + resource: resource.clone(), + compute: ctx.queue().sporulate(), + kernels: CambriconKernels::new( + DeviceEnum::DevCambriconMlu + ), + embed_tokens: host + .embed_tokens + .as_ref() + .map_physical(page_lock), + layers: host + .layers + .iter() + .map(|l| l.map(|u| ctx.from_host(&u).sporulate())) + .collect::>(), + lm_layernorm: host + .lm_layernorm + .map_physical(|u| ctx.from_host(&u).sporulate()), + lm_head: host + .lm_head + .map_physical(|u| ctx.from_host(&u).sporulate()), + + config: host.config, + }) )) + }) + + } +} + +impl Transformer { + #[inline] + fn cache(&self, len: usize) -> Cache { + Cache::new(&self.0.resource, len) + } + + #[inline] + fn tensor(&self, shape: &[u32]) -> Tensor { + Tensor::alloc(self.0.config.dt, shape, |len| self.cache(len)) } } impl CausalLM for Transformer { type Storage = Cache; + #[inline] fn max_seq_len(&self) -> upos { - todo!() + self.0.config.max_seq_len } + #[inline] fn eos_token(&self) -> utok { - todo!() + self.0.config.eos_token } fn new_cache(&self) -> Tensor { - todo!() + self.0.config.new_cache(|len| self.cache(len)) } - fn duplicate_cache(&self, _cache: &Tensor, _pos: upos) -> Tensor { - todo!() + fn duplicate_cache(&self, cache: &Tensor, pos: upos) -> Tensor { + self.0.config.duplicate_cache( + cache, + pos, + |len| self.cache(len), + |dst, src| { + self.0.resource.apply(|ctx| { + self.0.kernels.reform( + &mut dst.map_physical(|u| &mut **u.mem.sprout_mut(ctx)), + &src.map_physical(|u| &**u.mem.sprout_ref(ctx)), + &ctx.queue(), + ); + }) + }, + ) } - fn token_embed(&self, _queries: impl IntoIterator) -> Tensor { - todo!() + fn token_embed(&self, queries: impl IntoIterator) -> Tensor { + let tokens = queries.into_iter().collect::>(); + let nt = tokens.len() as u32; + let d = self.0.config.d; + + let mut x = self.tensor(&[nt, d]); + self.0.resource.apply(|ctx| { + self.0.kernels.gather( + &mut x + .as_mut() + .map_physical(|u| &mut **u.mem.sprout_mut(ctx)), + &self.0.embed_tokens.as_ref().map_physical(|u| &**u), + tokens, + &ctx.queue(), + ) + }); + x } fn forward<'a>( &self, - _queries: impl IntoIterator>, - _token_embedded: Tensor, + queries: impl IntoIterator>, + token_embedded: Tensor, ) -> Tensor where Self: 'a, { - todo!() + self.0.resource.apply(|ctx| { + let stream = ComputeStream { + nh: self.0.config.nh, + nkvh: self.0.config.nkvh, + di: self.0.config.di, + epsilon: self.0.config.epsilon, + theta: self.0.config.theta, + kernels: &self.0.kernels, + compute: self.0.compute.sprout_ref(ctx), + layers: &self.0.layers, + }; + ::forward(&stream, queries, token_embedded) + }) } fn decode( &self, - _decoding: impl IntoIterator, - _hidden_state: Tensor, + decoding: impl IntoIterator, + mut hidden_state: Tensor, ) -> Tensor { - todo!() + self.0.resource.apply(|ctx| { + let compute = self.0.compute.sprout_ref(ctx); + let mut x = hidden_state + .as_mut() + .map_physical(|u| &mut **u.mem.sprout_mut(ctx)); + let range = + DecodingMeta::select(&mut x, decoding, |dst, src| compute.memcpy_d2d(dst, src)); + if range.is_empty() { + return self.tensor(&[0, self.0.config.d]); + } + + let lm_layernorm = self + .0 + .lm_layernorm + .as_ref() + .map_physical(|u| &**u.sprout_ref(ctx)); + let lm_head = self + .0 + .lm_head + .as_ref() + .map_physical(|u| &**u.sprout_ref(ctx)); + + let mut x = x.slice(&[slice![range.start => range.end], slice![=>]]); + let mut logits = self.tensor(&[x.shape()[0], lm_head.shape()[1]]); + + // 复制一个 x 以实现原地归一化 + let x_ = x + .as_ref() + .map_physical(|u| unsafe { from_raw_parts(u.as_ptr(), u.len()) }); + self.0 + .kernels + .rms_norm(&mut x, &x_, &lm_layernorm, self.0.config.epsilon, compute); + self.0.kernels.mat_mul( + &mut logits + .as_mut() + .map_physical(|u| &mut **u.mem.sprout_mut(ctx)), + 0., + &x, + &lm_head, + 1., + compute, + ); + + logits + }) } fn sample( &self, - _args: impl IntoIterator, - _logits: Tensor, + args: impl IntoIterator, + logits: Tensor, ) -> Vec { - todo!() + assert_eq!(logits.data_layout(), F16); + let &[_nt, voc] = logits.shape() else { + panic!() + }; + let voc = voc as usize; + + self.0.resource.apply(|ctx| { + let compute = self.0.compute.sprout_ref(ctx); + sample_cpu( + args.into_iter() + .flat_map(|meta| repeat(meta.args).take(meta.num_decode)) + .enumerate(), + logits.take_physical().mem.sprout_ref(ctx), + voc, + compute, + ) + }) + } +} + +impl Drop for Transformer { + #[inline] + fn drop(&mut self) { + let Internal { + config: _, + resource, + compute, + kernels: _, + embed_tokens, + layers, + lm_layernorm, + lm_head, + } = unsafe { ManuallyDrop::take(&mut self.0) }; + resource.apply(|ctx| { + compute.sprout(ctx); + embed_tokens.take_physical().sprout(ctx); + lm_layernorm.take_physical().sprout(ctx); + lm_head.take_physical().sprout(ctx); + for layer in layers { + layer.att_layernorm.take_physical().sprout(ctx); + layer.att_qkv.take_physical().sprout(ctx); + layer.att_o.take_physical().sprout(ctx); + layer.mlp_layernorm.take_physical().sprout(ctx); + layer.mlp_gate_up.take_physical().sprout(ctx); + layer.mlp_down.take_physical().sprout(ctx); + } + }); + } +} + +struct ComputeStream<'a> { + nh: u32, + nkvh: u32, + di: u32, + epsilon: f32, + theta: f32, + kernels: &'a CambriconKernels, + compute: &'a Queue<'a>, + layers: &'a [LayerStorage], +} + +impl<'a> llama::ComputeStream for ComputeStream<'a> { + type Handle = Mlu; + type Storage = Cache; + type Buf<'m> = DevMem<'m>; + type Pos<'m> = DevMem<'m>; + + #[inline] + fn malloc(&self, len: usize) -> Self::Buf<'_> { + self.compute.ctx().malloc::(len) + } + #[inline] + fn map_pos<'b>(&self, pos: &'b [u32]) -> Self::Pos<'b> + where + Self: 'b, + { + self.compute.ctx().from_host(pos) + } + #[inline] + fn map_storage<'b>(&'b self, storage: &'b mut Self::Storage) -> &'b mut SliceOn { + storage.mem.sprout_mut(self.compute.ctx()) + } + #[inline] + fn kernels(&self) -> &impl Kernels { + self.kernels + } + #[inline] + fn queue(&self) -> &llama::QueueOf { + self.compute + } + #[inline] + fn constant(&self) -> ComputeConst { + ComputeConst { + nh: self.nh, + nkvh: self.nkvh, + di: self.di, + epsilon: self.epsilon, + theta: self.theta, + } + } + + fn debug(tensor: &Tensor) + where + T: Deref>, + { + println!( + "{}", + tensor.as_ref().map_physical(|s| { + let mut host = Blob::new(s.len()); + memcpy_d2h(&mut host, s); + host + }) + ); + } + + fn layers( + &self, + ) -> impl Iterator::Byte>> + { + self.layers.iter().map(|l|LlamaLayer( self.queue().ctx(), l)) + } +} + +macro_rules! access { + ($self:expr, $name:ident) => { + $self + .1 + .$name + .as_ref() + .map_physical(|p|&**p.sprout_ref(&$self.0)) + }; +} + +struct LlamaLayer<'a>(&'a CurrentCtx,&'a LayerStorage); + +impl<'a> llama::LLamaLayer for LlamaLayer<'a> { + type Byte = DevByte; + type Storage<'m> = &'m[DevByte] where Self: 'm; + + fn att_layernorm(&self) -> Tensor> { + access!(self, att_layernorm) + } + fn att_qkv(&self) -> Tensor> { + access!(self, att_qkv) + } + fn att_o(&self) -> Tensor> { + access!(self, att_o) + } + fn mlp_layernorm(&self) -> Tensor> { + access!(self, mlp_layernorm) + } + fn mlp_gate_up(&self) -> Tensor> { + access!(self, mlp_gate_up) + } + fn mlp_down(&self) -> Tensor> { + access!(self, mlp_down) } }