From 3f4e442ffe498986216d371f60f5d9cdf7826534 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Tue, 19 Mar 2024 22:26:55 +0800 Subject: [PATCH] =?UTF-8?q?feat(transform):=20=E6=95=B4=E5=90=88=20blas=20?= =?UTF-8?q?=E5=B0=81=E8=A3=85=E4=BB=A5=E7=AE=80=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- transformer-cpu/src/kernel/mat_mul.rs | 339 +++++++++-------------- transformer-nvidia/src/kernel/mat_mul.rs | 108 ++------ transformer/src/blas.rs | 68 +++++ transformer/src/lib.rs | 2 + 4 files changed, 234 insertions(+), 283 deletions(-) create mode 100644 transformer/src/blas.rs diff --git a/transformer-cpu/src/kernel/mat_mul.rs b/transformer-cpu/src/kernel/mat_mul.rs index 780715ce..31dce21e 100644 --- a/transformer-cpu/src/kernel/mat_mul.rs +++ b/transformer-cpu/src/kernel/mat_mul.rs @@ -1,6 +1,11 @@ use gemm::{f16, gemm}; -use std::ops::{Deref, DerefMut}; -use tensor::{expand_indices, idx_strides, DataType, Tensor}; +use std::{ + ffi::{c_longlong, c_void}, + mem::swap, + ops::{Deref, DerefMut}, +}; +use tensor::{DataType, Tensor}; +use transformer::{BetweenF32, Matrix}; /// c = a x b /// @@ -14,100 +19,88 @@ where V: Deref, { let dt = c.data_type(); - assert_eq!(a.data_type(), dt); - assert_eq!(b.data_type(), dt); - #[cfg(detected_mkl)] - { - if dt == DataType::F32 { - mkl::gemm(c, beta, a, b, alpha); - return; - } + #[inline] + fn base(tensor: &impl Deref) -> *mut c_void { + tensor.as_ptr() as _ } - let rank = c.shape().len(); - assert_eq!(a.shape().len(), rank); - assert_eq!(b.shape().len(), rank); - - let (batch, mn) = c.shape().split_at(rank - 2); - let (a_batch, mk) = a.shape().split_at(rank - 2); - let (b_batch, kn) = b.shape().split_at(rank - 2); - assert_eq!(a_batch, batch); - assert_eq!(b_batch, batch); - - let m = mn[0]; - let n = mn[1]; - let k = mk[1]; - assert_eq!(mk, &[m, k]); - assert_eq!(kn, &[k, n]); + let mut c = Matrix::new(c, base); + let mut a = Matrix::new(a, base); + let mut b = Matrix::new(b, base); + assert_eq!(c.r, a.r); // m + assert_eq!(c.c, b.c); // n + assert_eq!(a.c, b.r); // k - let dst_strides = &c.strides()[rank - 2..]; - let dst_cs = dst_strides[1] as isize; - let dst_rs = dst_strides[0] as isize; + let batch = c.batch; + if !a.match_batch(batch) || !b.match_batch(batch) { + panic!("Invalid batch size"); + } - let lhs_strides = &a.strides()[rank - 2..]; - let lhs_cs = lhs_strides[1] as isize; - let lhs_rs = lhs_strides[0] as isize; + if c.rs == 1 { + // Nothing to do + } else if c.cs == 1 { + c.transpose(); + a.transpose(); + b.transpose(); + swap(&mut a, &mut b); + } else { + panic!("Matrix is not contiguous"); + }; - let rhs_strides = &b.strides()[rank - 2..]; - let rhs_cs = rhs_strides[1] as isize; - let rhs_rs = rhs_strides[0] as isize; + assert_eq!(c.batch, a.batch); + assert_eq!(c.batch, b.batch); + match dt { + DataType::F16 => gemm_as_blas::(c, beta, alpha, a, b), + #[cfg(not(detected_mkl))] + DataType::F32 => gemm_as_blas::(c, beta, alpha, a, b), + #[cfg(detected_mkl)] + DataType::F32 => mkl::gemm(dt, c, beta, alpha, a, b), + _ => unreachable!(), + } +} - let (batch, idx_strides) = idx_strides(batch); - for i in 0..batch { - let indices = expand_indices(i, &idx_strides, &[0, 0, 1]); - let indices = indices.as_view(); - let dst = c.locate_mut(&indices).unwrap(); - let lhs = a.locate(&indices).unwrap(); - let rhs = b.locate(&indices).unwrap(); - match dt { - DataType::F32 => unsafe { - gemm( - m as _, - n as _, - k as _, - dst.cast::(), - dst_cs, - 
dst_rs, - beta != 0., - lhs.cast::(), - lhs_cs, - lhs_rs, - rhs.cast::(), - rhs_cs, - rhs_rs, - beta, - alpha, - false, - false, - false, - gemm::Parallelism::Rayon(0), - ) - }, - DataType::F16 => unsafe { - gemm( - m as _, - n as _, - k as _, - dst.cast::(), - dst_cs, - dst_rs, - beta != 0., - lhs.cast::(), - lhs_cs, - lhs_rs, - rhs.cast::(), - rhs_cs, - rhs_rs, - f16::from_f32(beta), - f16::from_f32(alpha), - false, - false, - false, - gemm::Parallelism::Rayon(0), - ) - }, - _ => unreachable!(), +fn gemm_as_blas(c: Matrix, beta: f32, alpha: f32, a: Matrix, b: Matrix) { + let batch = c.batch; + assert_eq!(a.batch, batch); + assert_eq!(b.batch, batch); + + let m = c.r; + let n = c.c; + let k = a.c; + assert_eq!(a.r, m); + assert_eq!(b.r, k); + assert_eq!(b.c, n); + + let c_ = c.base.cast::(); + let a_ = a.base.cast::(); + let b_ = b.base.cast::(); + for i in 0..batch as c_longlong { + unsafe { + let c_ = c_.offset((i * c.stride) as isize); + let a_ = a_.offset((i * a.stride) as isize); + let b_ = b_.offset((i * b.stride) as isize); + gemm( + m as _, + n as _, + k as _, + c_, + c.cs as _, + c.rs as _, + beta != 0., + a_, + a.cs as _, + a.rs as _, + b_, + b.cs as _, + b.rs as _, + T::cast(beta), + T::cast(alpha), + false, + false, + false, + gemm::Parallelism::Rayon(0), + ) } } } @@ -115,11 +108,8 @@ where #[cfg(detected_mkl)] mod mkl { use gemm::f16; - use std::{ - mem::swap, - ops::{Deref, DerefMut}, - }; - use tensor::Tensor; + use tensor::DataType; + use transformer::Matrix; const COL_MAJOR: i32 = 102; #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -166,136 +156,79 @@ mod mkl { ); } - pub fn gemm(c: &mut Tensor, beta: f32, a: &Tensor, b: &Tensor, alpha: f32) - where - T: DerefMut, - U: Deref, - V: Deref, - { - let dt = c.data_type(); - let mut c = Matrix::from(&*c); - let mut a = Matrix::from(a); - let mut b = Matrix::from(b); - assert_eq!(c.r, a.r); // m - assert_eq!(c.c, b.c); // n - assert_eq!(a.c, b.r); // k - + pub fn gemm(dt: DataType, c: Matrix, beta: f32, alpha: f32, a: Matrix, b: Matrix) { let batch = c.batch; - if !a.match_batch(batch) || !b.match_batch(batch) { - panic!("Invalid batch size"); - } - - if c.rs == 1 { - // Nothing to do - } else if c.cs == 1 { - c.transpose(); - a.transpose(); - b.transpose(); - swap(&mut a, &mut b); - } else { - panic!("Matrix is not contiguous"); - }; + assert_eq!(a.batch, batch); + assert_eq!(b.batch, batch); let m = c.r; let n = c.c; let k = a.c; + assert_eq!(a.r, m); + assert_eq!(b.r, k); + assert_eq!(b.c, n); - let (lda, at) = a.ld_op(); - let (ldb, bt) = b.ld_op(); - let ldc = c.cs; - - assert_eq!(c.batch, a.batch); - assert_eq!(c.batch, b.batch); + #[inline] + fn trans(m: &Matrix) -> CBLAS_TRANSPOSE { + if m.rs == 1 { + CBLAS_TRANSPOSE::None + } else if m.cs == 1 { + CBLAS_TRANSPOSE::Ordinary + } else { + panic!("Matrix is not contiguous"); + } + } match dt { - tensor::DataType::F16 => unsafe { + DataType::F16 => unsafe { let alpha = f16::from_f32(alpha); let beta = f16::from_f32(beta); for i in 0..batch { - let a = a.ptr.cast::().offset((i * a.stride) as isize); - let b = b.ptr.cast::().offset((i * b.stride) as isize); - let c = c.ptr.cast::().offset((i * c.stride) as isize); + let a_ = a.base.cast::().offset(i as isize * a.stride as isize); + let b_ = b.base.cast::().offset(i as isize * b.stride as isize); + let c_ = c.base.cast::().offset(i as isize * c.stride as isize); cblas_hgemm( - COL_MAJOR, at, bt, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + COL_MAJOR, + trans(&a), + trans(&b), + m, + n, + k, + alpha, + a_, + a.ld(), + b_, + 
b.ld(), + beta, + c_, + c.ld(), ); } }, - tensor::DataType::F32 => unsafe { + DataType::F32 => unsafe { for i in 0..batch { - let a = a.ptr.cast::().offset((i * a.stride) as isize); - let b = b.ptr.cast::().offset((i * b.stride) as isize); - let c = c.ptr.cast::().offset((i * c.stride) as isize); + let a_ = a.base.cast::().offset(i as isize * a.stride as isize); + let b_ = b.base.cast::().offset(i as isize * b.stride as isize); + let c_ = c.base.cast::().offset(i as isize * c.stride as isize); cblas_sgemm( - COL_MAJOR, at, bt, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + COL_MAJOR, + trans(&a), + trans(&b), + m, + n, + k, + alpha, + a_, + a.ld(), + b_, + b.ld(), + beta, + c_, + c.ld(), ); } }, _ => unreachable!(), } } - - struct Matrix { - batch: i32, - stride: i32, - r: i32, - c: i32, - rs: i32, - cs: i32, - ptr: *mut u8, - } - - impl From<&Tensor> for Matrix - where - T: Deref, - { - fn from(tensor: &Tensor) -> Self { - let strides = tensor.strides(); - let ptr = tensor.locate_start().cast_mut(); - match tensor.shape() { - &[r, c] => Self { - batch: 1, - stride: 0, - r: r as _, - c: c as _, - rs: strides[0] as _, - cs: strides[1] as _, - ptr, - }, - &[batch, r, c] => Self { - batch: batch as _, - stride: if batch == 1 { 0 } else { strides[0] as _ }, - r: r as _, - c: c as _, - rs: strides[1] as _, - cs: strides[2] as _, - ptr, - }, - s => panic!("Invalid matrix shape: {s:?}"), - } - } - } - - impl Matrix { - #[inline] - fn match_batch(&self, batch: i32) -> bool { - self.batch == batch || self.batch == 1 - } - - #[inline] - fn transpose(&mut self) { - swap(&mut self.r, &mut self.c); - swap(&mut self.rs, &mut self.cs); - } - - #[inline] - fn ld_op(&self) -> (i32, CBLAS_TRANSPOSE) { - if self.rs == 1 { - (self.cs, CBLAS_TRANSPOSE::None) - } else if self.cs == 1 { - (self.rs, CBLAS_TRANSPOSE::Ordinary) - } else { - panic!("Matrix is not contiguous"); - } - } - } } diff --git a/transformer-nvidia/src/kernel/mat_mul.rs b/transformer-nvidia/src/kernel/mat_mul.rs index 038545ea..8ccb72c8 100644 --- a/transformer-nvidia/src/kernel/mat_mul.rs +++ b/transformer-nvidia/src/kernel/mat_mul.rs @@ -2,11 +2,12 @@ use cuda::{AsRaw, DevSlice}; use half::f16; use std::{ - ffi::{c_int, c_longlong, c_void}, mem::swap, ops::{Deref, DerefMut}, + os::raw::c_void, }; use tensor::{DataType, Tensor}; +use transformer::Matrix; pub fn mat_mul( handle: &Cublas, @@ -25,9 +26,14 @@ pub fn mat_mul( assert_eq!(a.data_type(), dt); assert_eq!(b.data_type(), dt); - let mut c = Matrix::from(&*c); - let mut a = Matrix::from(a); - let mut b = Matrix::from(b); + #[inline] + fn base(tensor: &impl Deref) -> *mut c_void { + unsafe { tensor.as_raw() as _ } + } + + let mut c = Matrix::new(c, base); + let mut a = Matrix::new(a, base); + let mut b = Matrix::new(b, base); assert_eq!(c.r, a.r); // m assert_eq!(c.c, b.c); // n assert_eq!(a.c, b.r); // k @@ -55,98 +61,40 @@ pub fn mat_mul( let n = c.c; let k = a.c; - let (lda, transa) = a.ld_op(); - let (ldb, transb) = b.ld_op(); - let ldc = c.cs; + #[inline] + fn trans(m: &Matrix) -> cublasOperation_t { + if m.rs == 1 { + cublasOperation_t::CUBLAS_OP_N + } else if m.cs == 1 { + cublasOperation_t::CUBLAS_OP_T + } else { + panic!("Matrix is not contiguous"); + } + } cublas!(cublasGemmStridedBatchedEx( handle.as_raw(), - transa, - transb, + trans(&a), + trans(&b), m, n, k, ((&alpha) as *const f16).cast(), - a.ptr, + a.base as _, cudaDataType_t::CUDA_R_16F, - lda, + a.ld(), a.stride, - b.ptr, + b.base as _, cudaDataType_t::CUDA_R_16F, - ldb, + b.ld(), b.stride, ((&beta) as *const 
f16).cast(), - c.ptr, + c.base as _, cudaDataType_t::CUDA_R_16F, - ldc, + c.ld(), c.stride, batch, cublasComputeType_t::CUBLAS_COMPUTE_16F, cublasGemmAlgo_t::CUBLAS_GEMM_DFALT, )); } - -struct Matrix { - batch: c_int, - stride: c_longlong, - r: c_int, - c: c_int, - rs: c_int, - cs: c_int, - ptr: *mut c_void, -} - -impl From<&Tensor> for Matrix -where - T: Deref, -{ - fn from(tensor: &Tensor) -> Self { - let strides = tensor.strides(); - let ptr = (unsafe { tensor.physical().as_raw() } as isize + tensor.bytes_offset()) as _; - match tensor.shape() { - &[r, c] => Self { - batch: 1, - stride: 0, - r: r as _, - c: c as _, - rs: strides[0] as _, - cs: strides[1] as _, - ptr, - }, - &[batch, r, c] => Self { - batch: batch as _, - stride: if batch == 1 { 0 } else { strides[0] as _ }, - r: r as _, - c: c as _, - rs: strides[1] as _, - cs: strides[2] as _, - ptr, - }, - s => panic!("Invalid matrix shape: {s:?}"), - } - } -} - -impl Matrix { - #[inline] - fn match_batch(&self, batch: c_int) -> bool { - self.batch == batch || self.batch == 1 - } - - #[inline] - fn transpose(&mut self) { - swap(&mut self.r, &mut self.c); - swap(&mut self.rs, &mut self.cs); - } - - #[inline] - fn ld_op(&self) -> (c_int, cublasOperation_t) { - if self.rs == 1 { - (self.cs, cublasOperation_t::CUBLAS_OP_N) - } else if self.cs == 1 { - (self.rs, cublasOperation_t::CUBLAS_OP_T) - } else { - panic!("Matrix is not contiguous"); - } - } -} diff --git a/transformer/src/blas.rs b/transformer/src/blas.rs new file mode 100644 index 00000000..73dde096 --- /dev/null +++ b/transformer/src/blas.rs @@ -0,0 +1,68 @@ +use std::{ + ffi::{c_int, c_longlong}, + mem::swap, + os::raw::c_void, +}; +use tensor::Tensor; + +pub struct Matrix { + pub batch: c_int, + pub stride: c_longlong, + pub r: c_int, + pub c: c_int, + pub rs: c_int, + pub cs: c_int, + pub base: *mut c_void, +} + +impl Matrix { + pub fn new(tensor: &Tensor, f: impl FnOnce(&T) -> *mut c_void) -> Self { + let strides = tensor.strides(); + let base = (f(tensor.physical()) as usize + tensor.bytes_offset() as usize) as _; + match tensor.shape() { + &[r, c] => Self { + batch: 1, + stride: 0, + r: r as _, + c: c as _, + rs: strides[0] as _, + cs: strides[1] as _, + base, + }, + &[batch, r, c] => Self { + batch: batch as _, + stride: if batch == 1 { 0 } else { strides[0] as _ }, + r: r as _, + c: c as _, + rs: strides[1] as _, + cs: strides[2] as _, + base, + }, + s => panic!("Invalid matrix shape: {s:?}"), + } + } +} + +impl Matrix { + #[inline] + pub fn match_batch(&self, batch: c_int) -> bool { + self.batch == batch || self.batch == 1 + } + + #[inline] + pub fn transpose(&mut self) { + swap(&mut self.r, &mut self.c); + swap(&mut self.rs, &mut self.cs); + } + + #[inline] + pub fn ld(&self) -> c_int { + if self.rs == 1 { + self.cs + } else if self.cs == 1 { + self.rs + } else { + panic!("Matrix is not contiguous"); + } + } +} diff --git a/transformer/src/lib.rs b/transformer/src/lib.rs index 073e2b43..0987358e 100644 --- a/transformer/src/lib.rs +++ b/transformer/src/lib.rs @@ -2,12 +2,14 @@ #![deny(warnings)] +mod blas; mod cache; mod host_memory; mod parameters; mod request; mod sample; +pub use blas::Matrix; pub use cache::LayerCache; pub use host_memory::HostMemory; pub use parameters::{save, Llama2, Memory, SafeTensorError};
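
Note: both backends require C to be contiguous along one axis. When C is
column-major (rs == 1) the views are used as-is; when C is row-major
(cs == 1) the kernels transpose all three views and swap A and B, i.e.
they compute C^T = B^T * A^T, which writes the row-major C through its
transposed view. The standalone example below checks that identity with a
tiny strided matmul; View and naive_gemm are names local to this sketch,
not the crate's API.

    #[derive(Clone, Copy)]
    struct View {
        r: usize,  // rows
        c: usize,  // columns
        rs: usize, // row stride, in elements
        cs: usize, // column stride, in elements
    }

    impl View {
        fn transpose(self) -> Self {
            View { r: self.c, c: self.r, rs: self.cs, cs: self.rs }
        }
    }

    // c[i, j] = alpha * sum_k a[i, k] * b[k, j] + beta * c[i, j]
    fn naive_gemm(
        c: &mut [f32], cv: View,
        a: &[f32], av: View,
        b: &[f32], bv: View,
        alpha: f32, beta: f32,
    ) {
        assert_eq!(cv.r, av.r); // m
        assert_eq!(cv.c, bv.c); // n
        assert_eq!(av.c, bv.r); // k
        for i in 0..cv.r {
            for j in 0..cv.c {
                let mut acc = 0.0;
                for k in 0..av.c {
                    acc += a[i * av.rs + k * av.cs] * b[k * bv.rs + j * bv.cs];
                }
                let idx = i * cv.rs + j * cv.cs;
                c[idx] = alpha * acc + beta * c[idx];
            }
        }
    }

    fn main() {
        // A: 2x3 row-major, B: 3x2 row-major, C: 2x2 row-major (cs == 1).
        let a = [1.0f32, 2., 3., 4., 5., 6.];
        let b = [7.0f32, 8., 9., 10., 11., 12.];
        let av = View { r: 2, c: 3, rs: 3, cs: 1 };
        let bv = View { r: 3, c: 2, rs: 2, cs: 1 };
        let cv = View { r: 2, c: 2, rs: 2, cs: 1 };

        // Direct row-major product.
        let mut direct = [0.0f32; 4];
        naive_gemm(&mut direct, cv, &a, av, &b, bv, 1.0, 0.0);

        // The patch's fallback for a row-major C: transpose every view and
        // swap the operands, so the same buffer receives C^T = B^T * A^T.
        let mut swapped = [0.0f32; 4];
        naive_gemm(
            &mut swapped, cv.transpose(),
            &b, bv.transpose(),
            &a, av.transpose(),
            1.0, 0.0,
        );

        assert_eq!(direct, swapped); // both hold [58.0, 64.0, 139.0, 154.0]
    }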
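
Note: with a column-major BLAS (COL_MAJOR = 102 for CBLAS; cuBLAS is always
column-major), each operand is described by a transpose flag plus a leading
dimension. The new Matrix derives both from its strides: rs == 1 means the
data is already column-major (no transpose, ld = cs), while cs == 1 means
row-major data that is passed as transposed (ld = rs). The sketch below
folds Matrix::ld() and the trans() helpers back into one function, in the
spirit of the removed ld_op(); Op is a stand-in for CBLAS_TRANSPOSE /
cublasOperation_t and is not part of the patch.

    // Illustrative only: leading dimension and transpose flag for a
    // column-major BLAS call, derived from the row/column strides.
    #[derive(Debug, PartialEq)]
    enum Op {
        None,      // data is already column-major
        Transpose, // row-major data, handed to BLAS as transposed
    }

    fn ld_op(rs: i32, cs: i32) -> (i32, Op) {
        if rs == 1 {
            (cs, Op::None) // column stride is the leading dimension
        } else if cs == 1 {
            (rs, Op::Transpose) // row stride is the leading dimension
        } else {
            panic!("Matrix is not contiguous");
        }
    }

    fn main() {
        // A 2x3 row-major matrix (rs = 3, cs = 1) is handed to BLAS as a
        // transposed 3x2 column-major matrix with leading dimension 3.
        assert_eq!(ld_op(3, 1), (3, Op::Transpose));
        // A 2x3 column-major matrix (rs = 1, cs = 2) needs no transpose.
        assert_eq!(ld_op(1, 2), (2, Op::None));
    }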
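
Note: batching is folded into the same view. A 3-D tensor maps to
(batch, stride); a 2-D tensor (and a 3-D tensor with batch == 1) gets
stride = 0, and match_batch accepts either an equal batch or a batch of 1,
so a broadcast operand's pointer never moves because base + i * 0 stays at
base. A slice-based sketch of that offset rule follows; batch_view is local
to this example, not the crate's API.

    // Illustrative only: the per-batch offset convention used by the
    // kernels (base + i * stride). stride == 0 keeps every batch index
    // on the same matrix.
    fn match_batch(self_batch: i32, batch: i32) -> bool {
        self_batch == batch || self_batch == 1
    }

    fn batch_view(data: &[f32], matrix_len: usize, stride: usize, i: usize) -> &[f32] {
        let start = i * stride;
        &data[start..start + matrix_len]
    }

    fn main() {
        assert!(match_batch(1, 3)); // a single matrix may pair with a batch of 3
        let single = vec![1.0f32; 4];   // one 2x2 matrix: batch = 1, stride = 0
        let batched = vec![2.0f32; 12]; // three 2x2 matrices: stride = 4
        for i in 0..3 {
            let s = batch_view(&single, 4, 0, i);  // always the same matrix
            let b = batch_view(&batched, 4, 4, i); // advances one matrix per step
            assert_eq!(s.len(), b.len());
        }
    }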