diff --git a/Cargo.lock b/Cargo.lock
index 480a8ff1..151849fb 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -369,7 +369,7 @@ dependencies = [
 [[package]]
 name = "common"
 version = "0.1.0"
-source = "git+https://github.com/YdrMaster/operators-rs?rev=ea23709#ea2370907af8ed6d59220ed0d2d8b959208a010c"
+source = "git+https://github.com/YdrMaster/operators-rs?rev=647dceb#647dcebbdbd03b5fed9531d08398c4a484903210"
 
 [[package]]
 name = "common-cpu"
@@ -387,7 +387,7 @@ dependencies = [
 [[package]]
 name = "common-cpu"
 version = "0.1.0"
-source = "git+https://github.com/YdrMaster/operators-rs?rev=ea23709#ea2370907af8ed6d59220ed0d2d8b959208a010c"
+source = "git+https://github.com/YdrMaster/operators-rs?rev=647dceb#647dcebbdbd03b5fed9531d08398c4a484903210"
 dependencies = [
  "common 0.1.0",
 ]
@@ -741,9 +741,9 @@ dependencies = [
 
 [[package]]
 name = "gemm"
-version = "0.17.1"
+version = "0.18.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32"
+checksum = "e400f2ffd14e7548356236c35dc39cad6666d833a852cb8a8f3f28029359bb03"
 dependencies = [
  "dyn-stack",
  "gemm-c32",
@@ -761,9 +761,9 @@ dependencies = [
 
 [[package]]
 name = "gemm-c32"
-version = "0.17.1"
+version = "0.18.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0"
+checksum = "10dc4a6176c8452d60eac1a155b454c91c668f794151a303bf3c75ea2874812d"
 dependencies = [
  "dyn-stack",
  "gemm-common",
@@ -776,9 +776,9 @@ dependencies = [
 
 [[package]]
 name = "gemm-c64"
-version = "0.17.1"
+version = "0.18.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a"
+checksum = "cc2032ce2c0bb150da0256338759a6fb01ca056f6dfe28c4d14af32d7f878f6f"
 dependencies = [
  "dyn-stack",
  "gemm-common",
@@ -791,9 +791,9 @@ dependencies = [
 
 [[package]]
 name = "gemm-common"
-version = "0.17.1"
+version = "0.18.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8"
+checksum = "90fd234fc525939654f47b39325fd5f55e552ceceea9135f3aa8bdba61eabef6"
 dependencies = [
  "bytemuck",
  "dyn-stack",
@@ -811,9 +811,9 @@ dependencies = [
 
 [[package]]
 name = "gemm-f16"
-version = "0.17.1"
+version = "0.18.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4"
+checksum = "3fc3652651f96a711d46b8833e1fac27a864be4bdfa81a374055f33ddd25c0c6"
 dependencies = [
  "dyn-stack",
  "gemm-common",
@@ -829,9 +829,9 @@ dependencies = [
 
 [[package]]
 name = "gemm-f32"
-version = "0.17.1"
+version = "0.18.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113"
+checksum = "acbc51c44ae3defd207e6d9416afccb3c4af1e7cef5e4960e4c720ac4d6f998e"
 dependencies = [
  "dyn-stack",
  "gemm-common",
@@ -844,9 +844,9 @@ dependencies = [
 
 [[package]]
 name = "gemm-f64"
-version = "0.17.1"
+version = "0.18.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0"
+checksum = "3f37fc86e325c2415a4d0cab8324a0c5371ec06fc7d2f9cb1636fcfc9536a8d8"
 dependencies = [
  "dyn-stack",
  "gemm-common",
@@ -1582,10 +1582,11 @@ dependencies = [
 [[package]]
 name = "nvidia-gpu"
 version = "0.1.0"
-source = "git+https://github.com/YdrMaster/operators-rs?rev=ea23709#ea2370907af8ed6d59220ed0d2d8b959208a010c"
+source = "git+https://github.com/YdrMaster/operators-rs?rev=647dceb#647dcebbdbd03b5fed9531d08398c4a484903210"
 dependencies = [
  "build-script-cfg",
  "common 0.1.0",
+ "cublas",
  "cuda",
  "search-cuda-tools",
 ]
@@ -1648,11 +1649,12 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
 [[package]]
 name = "operators"
 version = "0.0.0"
-source = "git+https://github.com/YdrMaster/operators-rs?rev=ea23709#ea2370907af8ed6d59220ed0d2d8b959208a010c"
+source = "git+https://github.com/YdrMaster/operators-rs?rev=647dceb#647dcebbdbd03b5fed9531d08398c4a484903210"
 dependencies = [
  "build-script-cfg",
  "common 0.1.0",
  "common-cpu 0.1.0",
+ "gemm",
  "half",
  "log",
  "nvidia-gpu",
diff --git a/Cargo.toml b/Cargo.toml
index edb5811e..cb226088 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -33,7 +33,7 @@ rayon = "1.10"
 tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
 
 build-script-cfg = "0.0"
-operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "ea23709", default-features = false }
+operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "647dceb", default-features = false }
 
 cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a4b3cf5" }
 cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a4b3cf5" }
diff --git a/devices/cpu/Cargo.toml b/devices/cpu/Cargo.toml
index 0f561113..4183d025 100644
--- a/devices/cpu/Cargo.toml
+++ b/devices/cpu/Cargo.toml
@@ -10,7 +10,7 @@ authors = ["YdrMaster <ydrml@hotmail.com>"]
 common = { path = "../../common" }
 tensor = { path = "../../tensor" }
 kernel-lib = { path = "../../kernel-lib" }
-gemm = "0.17"
+gemm = "0.18"
 intel-mkl-src = { version = "0.8", features = ["mkl-dynamic-lp64-iomp"] }
 operators = { workspace = true, features = ["common-cpu"] }
 
diff --git a/devices/cpu/src/mat_mul.rs b/devices/cpu/src/mat_mul.rs
index 78dfedbb..e55b0181 100644
--- a/devices/cpu/src/mat_mul.rs
+++ b/devices/cpu/src/mat_mul.rs
@@ -1,137 +1,45 @@
-use common::{f16, BetweenF32};
-use gemm::gemm;
-use kernel_lib::Matrix;
-use std::{
-    ffi::{c_longlong, c_void},
-    mem::swap,
-    ops::{Deref, DerefMut},
+use crate::layout;
+use operators::{
+    common_cpu::ThisThread,
+    mat_mul::{
+        common_cpu::{Operator as MatMul, Scheme as MatMulScheme},
+        LayoutAttrs,
+    },
+    Operator, Scheme, F16,
 };
-use tensor::{DataType, Tensor};
+use std::ops::{Deref, DerefMut};
+use tensor::Tensor;
 
 /// c = a x b
-///
-/// - c: [N0, N1, ... , N_, m, n]
-/// - a: [N0, N1, ... , N_, m, k]
-/// - b: [N0, N1, ... , N_, k, n]
 pub fn mat_mul<T, U, V>(c: &mut Tensor<T>, beta: f32, a: &Tensor<U>, b: &Tensor<V>, alpha: f32)
 where
     T: DerefMut<Target = [u8]>,
     U: Deref<Target = [u8]>,
     V: Deref<Target = [u8]>,
 {
-    let dt = c.data_type();
-
-    #[inline]
-    fn base(tensor: &impl Deref<Target = [u8]>) -> *mut c_void {
-        tensor.as_ptr() as _
-    }
-
-    let mut c = Matrix::new(c, base);
-    let mut a = Matrix::new(a, base);
-    let mut b = Matrix::new(b, base);
-    assert_eq!(c.r, a.r); // m
-    assert_eq!(c.c, b.c); // n
-    assert_eq!(a.c, b.r); // k
-
-    let batch = c.batch;
-    if !a.match_batch(batch) || !b.match_batch(batch) {
-        panic!("Invalid batch size");
-    }
-
-    if c.rs == 1 {
-        // Nothing to do
-    } else if c.cs == 1 {
-        c.transpose();
-        a.transpose();
-        b.transpose();
-        swap(&mut a, &mut b);
-    } else {
-        panic!("Matrix is not contiguous");
-    };
-
-    assert_eq!(c.batch, a.batch);
-    assert_eq!(c.batch, b.batch);
-    // const LAYER: usize = 40;
-    // static mut I: usize = 0;
-    // unsafe {
-    //     if I == 0 {
-    //         println!();
-    //         // #[cfg(detected_mkl)]
-    //         // {
-    //         //     println!("MKL threads: {}", mkl::mkl_get_max_threads());
-    //         //     println!("MKL dynamic: {}", mkl::mkl_get_dynamic());
-    //         // }
-    //     }
-    // }
-    // let time = std::time::Instant::now();
-    match dt {
-        // DataType::F16 => mkl::gemm(dt, c, beta, alpha, a, b),
-        DataType::F16 => gemm_as_blas::<f16>(c, beta, alpha, a, b),
-        #[cfg(not(detected_mkl))]
-        DataType::F32 => gemm_as_blas::<f32>(c, beta, alpha, a, b),
-        #[cfg(detected_mkl)]
-        DataType::F32 => mkl::gemm(dt, c, beta, alpha, a, b),
-        _ => unreachable!(),
-    }
-    // unsafe {
-    //     if I % 6 == 0 {
-    //         println!();
-    //     }
-    //     println!("{}:{} {}", I / 6, I % 6, time.elapsed().as_micros());
-    //     if I == LAYER * 6 {
-    //         I = 0;
-    //     } else {
-    //         I += 1;
-    //     }
-    // }
-}
-
-fn gemm_as_blas<T: BetweenF32>(c: Matrix, beta: f32, alpha: f32, a: Matrix, b: Matrix) {
-    let batch = c.batch;
-    assert_eq!(a.batch, batch);
-    assert_eq!(b.batch, batch);
-
-    let m = c.r;
-    let n = c.c;
-    let k = a.c;
-    assert_eq!(a.r, m);
-    assert_eq!(b.r, k);
-    assert_eq!(b.c, n);
-
-    let c_ = c.base.cast::<T>();
-    let a_ = a.base.cast::<T>();
-    let b_ = b.base.cast::<T>();
-    for i in 0..batch as c_longlong {
-        unsafe {
-            let c_ = c_.offset((i * c.stride) as isize);
-            let a_ = a_.offset((i * a.stride) as isize);
-            let b_ = b_.offset((i * b.stride) as isize);
-            gemm(
-                m as _,
-                n as _,
-                k as _,
-                c_,
-                c.cs as _,
-                c.rs as _,
-                beta != 0.,
-                a_,
-                a.cs as _,
-                a.rs as _,
-                b_,
-                b.cs as _,
-                b.rs as _,
-                T::cast(beta),
-                T::cast(alpha),
-                false,
-                false,
-                false,
-                gemm::Parallelism::Rayon(0),
-            )
-        }
-    }
+    MatMulScheme::new(
+        &MatMul::new(&F16).unwrap(),
+        LayoutAttrs {
+            c: layout(c),
+            a: layout(a),
+            b: layout(b),
+        },
+    )
+    .unwrap()
+    .launch(
+        &(
+            c.physical_mut().as_mut_ptr(),
+            beta,
+            a.physical().as_ptr(),
+            b.physical().as_ptr(),
+            alpha,
+        ),
+        &ThisThread,
+    );
 }
 
 #[cfg(detected_mkl)]
+#[allow(unused)]
 mod mkl {
     use gemm::f16;
     use kernel_lib::Matrix;
@@ -147,7 +55,6 @@ mod mkl {
         Ordinary = 112,
     }
 
-    #[allow(unused)]
     extern "C" {
         pub fn mkl_get_max_threads() -> c_int;
         pub fn mkl_get_dynamic() -> c_int;
diff --git a/devices/nvidia/src/lib.rs b/devices/nvidia/src/lib.rs
index 6996d8b6..cf847761 100644
--- a/devices/nvidia/src/lib.rs
+++ b/devices/nvidia/src/lib.rs
@@ -5,18 +5,17 @@ extern crate log;
 pub extern crate cuda;
 
 mod gather;
-mod mat_mul;
 mod reform;
 mod sample;
 
 use common::utok;
-use cublas::{Cublas, CublasSpore};
 use cuda::{
     memcpy_d2h, ComputeCapability, ContextGuard, ContextResource, ContextSpore, CudaDataType,
     DevByte, Device, ModuleSpore, Ptx, Stream,
 };
 use operators::{
     fuesd_softmax::{self, nvidia_gpu as softmax_nv},
+    mat_mul::{self, nvidia_gpu as mat_mul_nv},
     rms_norm::{self, nvidia_gpu as rms_norm_nv},
     rope::{self, nvidia_gpu as rope_nv},
     swiglu::{self, nvidia_gpu as swiglu_nv},
@@ -33,6 +32,7 @@ pub use sample::{sample_cpu, sample_nv};
 pub use tensor::{reslice, reslice_mut, slice, split, udim, DataType, LocalSplitable, Tensor};
 
 pub struct NvidiaKernelsPtx {
+    mat_mul: mat_mul_nv::Operator,
     rms_norm: rms_norm_nv::Operator,
     rope: rope_nv::Operator,
     reform: Arc<Ptx>,
@@ -58,6 +58,7 @@ impl NvidiaKernelsPtx {
             },
         );
         Self {
+            mat_mul: mat_mul_nv::Operator::new(&F16).unwrap(),
             rms_norm: rms_norm_nv::Operator::new(&rms_norm_nv::Config::config_for(
                 &devices[0],
                 F16,
@@ -95,7 +96,7 @@ struct ModuleWapper<T> {
 }
 
 pub struct NvidiaKernels {
-    cublas: CublasSpore,
+    mat_mul: mat_mul_nv::Operator,
     rms_norm: rms_norm_nv::Operator,
     rope: rope_nv::Operator,
     reform: ModuleWapper<ModuleSpore>,
@@ -106,9 +107,8 @@
 impl NvidiaKernelsPtx {
     pub fn load(&self, stream: &Stream) -> NvidiaKernels {
         let ctx = stream.ctx();
-        let cublas = Cublas::new(ctx);
         NvidiaKernels {
-            cublas: cublas.sporulate(),
+            mat_mul: self.mat_mul.clone(),
             rms_norm: self.rms_norm.clone(),
             rope: self.rope.clone(),
             reform: self.reform.clone().load(ctx),
@@ -120,7 +120,6 @@
 impl NvidiaKernels {
     pub fn kill(self, ctx: &ContextGuard) {
-        drop(self.cublas.sprout(ctx));
         drop(self.reform.module.sprout(ctx));
     }
 }
 
@@ -133,7 +132,6 @@ pub struct KernelRuntime<'a> {
 impl NvidiaKernels {
     #[inline]
     pub fn on<'a>(&'a self, stream: &'a Stream) -> KernelRuntime<'a> {
-        self.cublas.sprout_ref(stream.ctx()).set_stream(stream);
         KernelRuntime {
             kernels: self,
             stream,
@@ -194,8 +192,25 @@ impl Kernels for KernelRuntime<'_> {
         U: Deref<Target = [DevByte]>,
         V: Deref<Target = [DevByte]>,
     {
-        let cublas = self.kernels.cublas.sprout_ref(self.stream.ctx());
-        mat_mul::mat_mul(cublas, c, beta, a, b, alpha)
+        mat_mul_nv::Scheme::new(
+            &self.kernels.mat_mul,
+            mat_mul::LayoutAttrs {
+                c: layout(c),
+                a: layout(a),
+                b: layout(b),
+            },
+        )
+        .unwrap()
+        .launch(
+            &(
+                c.physical_mut().as_mut_ptr(),
+                beta,
+                a.physical().as_ptr(),
+                b.physical().as_ptr(),
+                alpha,
+            ),
+            self.stream,
+        );
     }
 
     #[inline]
diff --git a/devices/nvidia/src/mat_mul.rs b/devices/nvidia/src/mat_mul.rs
deleted file mode 100644
index a6a346a0..00000000
--- a/devices/nvidia/src/mat_mul.rs
+++ /dev/null
@@ -1,100 +0,0 @@
-use common::f16;
-use cublas::{bindings::cublasOperation_t, cublas, Cublas};
-use cuda::{AsRaw, DevByte};
-use kernel_lib::Matrix;
-use std::{
-    mem::swap,
-    ops::{Deref, DerefMut},
-    os::raw::c_void,
-};
-use tensor::{DataType, Tensor};
-
-pub fn mat_mul<T, U, V>(
-    handle: &Cublas,
-    c: &mut Tensor<T>,
-    beta: f32,
-    a: &Tensor<U>,
-    b: &Tensor<V>,
-    alpha: f32,
-) where
-    T: DerefMut<Target = [DevByte]>,
-    U: Deref<Target = [DevByte]>,
-    V: Deref<Target = [DevByte]>,
-{
-    let dt = c.data_type();
-    assert_eq!(dt, DataType::F16);
-    assert_eq!(a.data_type(), dt);
-    assert_eq!(b.data_type(), dt);
-
-    #[inline]
-    fn base(tensor: &impl Deref<Target = [DevByte]>) -> *mut c_void {
-        tensor.as_ptr() as _
-    }
-
-    let mut c = Matrix::new(c, base);
-    let mut a = Matrix::new(a, base);
-    let mut b = Matrix::new(b, base);
-    assert_eq!(c.r, a.r); // m
-    assert_eq!(c.c, b.c); // n
-    assert_eq!(a.c, b.r); // k
-
-    let batch = c.batch;
-    if !a.match_batch(batch) || !b.match_batch(batch) {
-        panic!("Invalid batch size");
-    }
-
-    if c.rs == 1 {
-        // Nothing to do
-    } else if c.cs == 1 {
-        c.transpose();
-        a.transpose();
-        b.transpose();
-        swap(&mut a, &mut b);
-    } else {
-        panic!("Matrix is not contiguous");
-    };
-
-    let alpha = f16::from_f32(alpha);
-    let beta = f16::from_f32(beta);
-
-    let m = c.r;
-    let n = c.c;
-    let k = a.c;
-
-    #[inline]
-    fn trans(m: &Matrix) -> cublasOperation_t {
-        if m.rs == 1 {
-            cublasOperation_t::CUBLAS_OP_N
-        } else if m.cs == 1 {
-            cublasOperation_t::CUBLAS_OP_T
-        } else {
-            panic!("Matrix is not contiguous");
-        }
-    }
-
-    cublas!(cublasGemmStridedBatchedEx(
-        handle.as_raw(),
-        trans(&a),
-        trans(&b),
-        m,
-        n,
-        k,
-        ((&alpha) as *const f16).cast(),
-        a.base as _,
-        cudaDataType_t::CUDA_R_16F,
-        a.ld(),
-        a.stride,
-        b.base as _,
-        cudaDataType_t::CUDA_R_16F,
-        b.ld(),
-        b.stride,
-        ((&beta) as *const f16).cast(),
-        c.base as _,
-        cudaDataType_t::CUDA_R_16F,
-        c.ld(),
-        c.stride,
-        batch,
-        cublasComputeType_t::CUBLAS_COMPUTE_16F,
-        cublasGemmAlgo_t::CUBLAS_GEMM_DFALT,
-    ));
-}