Commit
refactor(devices): replace mat mul
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Jun 13, 2024
1 parent 508fc84 commit 749057c
Showing 6 changed files with 76 additions and 252 deletions.
38 changes: 20 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -33,7 +33,7 @@ rayon = "1.10"
tokio = { version = "1.38", features = ["rt-multi-thread", "sync"] }
build-script-cfg = "0.0"

operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "ea23709", default-features = false }
operators = { git = "https://github.com/YdrMaster/operators-rs", rev = "647dceb", default-features = false }

cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a4b3cf5" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "a4b3cf5" }
2 changes: 1 addition & 1 deletion devices/cpu/Cargo.toml
@@ -10,7 +10,7 @@ authors = ["YdrMaster <[email protected]>"]
common = { path = "../../common" }
tensor = { path = "../../tensor" }
kernel-lib = { path = "../../kernel-lib" }
gemm = "0.17"
gemm = "0.18"
intel-mkl-src = { version = "0.8", features = ["mkl-dynamic-lp64-iomp"] }
operators = { workspace = true, features = ["common-cpu"] }

153 changes: 30 additions & 123 deletions devices/cpu/src/mat_mul.rs
@@ -1,137 +1,45 @@
use common::{f16, BetweenF32};
use gemm::gemm;
use kernel_lib::Matrix;
use std::{
ffi::{c_longlong, c_void},
mem::swap,
ops::{Deref, DerefMut},
use crate::layout;
use operators::{
common_cpu::ThisThread,
mat_mul::{
common_cpu::{Operator as MatMul, Scheme as MatMulScheme},
LayoutAttrs,
},
Operator, Scheme, F16,
};
use tensor::{DataType, Tensor};
use std::ops::{Deref, DerefMut};
use tensor::Tensor;

/// c = a x b
///
/// - c: [N0, N1, ... , N_, m, n]
/// - a: [N0, N1, ... , N_, m, k]
/// - b: [N0, N1, ... , N_, k, n]
pub fn mat_mul<T, U, V>(c: &mut Tensor<T>, beta: f32, a: &Tensor<U>, b: &Tensor<V>, alpha: f32)
where
T: DerefMut<Target = [u8]>,
U: Deref<Target = [u8]>,
V: Deref<Target = [u8]>,
{
let dt = c.data_type();

#[inline]
fn base(tensor: &impl Deref<Target = [u8]>) -> *mut c_void {
tensor.as_ptr() as _
}

let mut c = Matrix::new(c, base);
let mut a = Matrix::new(a, base);
let mut b = Matrix::new(b, base);
assert_eq!(c.r, a.r); // m
assert_eq!(c.c, b.c); // n
assert_eq!(a.c, b.r); // k

let batch = c.batch;
if !a.match_batch(batch) || !b.match_batch(batch) {
panic!("Invalid batch size");
}

if c.rs == 1 {
// Nothing to do
} else if c.cs == 1 {
c.transpose();
a.transpose();
b.transpose();
swap(&mut a, &mut b);
} else {
panic!("Matrix is not contiguous");
};

assert_eq!(c.batch, a.batch);
assert_eq!(c.batch, b.batch);
// const LAYER: usize = 40;
// static mut I: usize = 0;
// unsafe {
// if I == 0 {
// println!();
// // #[cfg(detected_mkl)]
// // {
// // println!("MKL threads: {}", mkl::mkl_get_max_threads());
// // println!("MKL dynamic: {}", mkl::mkl_get_dynamic());
// // }
// }
// }
// let time = std::time::Instant::now();
match dt {
// DataType::F16 => mkl::gemm(dt, c, beta, alpha, a, b),
DataType::F16 => gemm_as_blas::<f16>(c, beta, alpha, a, b),
#[cfg(not(detected_mkl))]
DataType::F32 => gemm_as_blas::<f32>(c, beta, alpha, a, b),
#[cfg(detected_mkl)]
DataType::F32 => mkl::gemm(dt, c, beta, alpha, a, b),
_ => unreachable!(),
}
// unsafe {
// if I % 6 == 0 {
// println!();
// }
// println!("{}:{} {}", I / 6, I % 6, time.elapsed().as_micros());
// if I == LAYER * 6 {
// I = 0;
// } else {
// I += 1;
// }
// }
}

fn gemm_as_blas<T: 'static + BetweenF32>(c: Matrix, beta: f32, alpha: f32, a: Matrix, b: Matrix) {
let batch = c.batch;
assert_eq!(a.batch, batch);
assert_eq!(b.batch, batch);

let m = c.r;
let n = c.c;
let k = a.c;
assert_eq!(a.r, m);
assert_eq!(b.r, k);
assert_eq!(b.c, n);

let c_ = c.base.cast::<T>();
let a_ = a.base.cast::<T>();
let b_ = b.base.cast::<T>();
for i in 0..batch as c_longlong {
unsafe {
let c_ = c_.offset((i * c.stride) as isize);
let a_ = a_.offset((i * a.stride) as isize);
let b_ = b_.offset((i * b.stride) as isize);
gemm(
m as _,
n as _,
k as _,
c_,
c.cs as _,
c.rs as _,
beta != 0.,
a_,
a.cs as _,
a.rs as _,
b_,
b.cs as _,
b.rs as _,
T::cast(beta),
T::cast(alpha),
false,
false,
false,
gemm::Parallelism::Rayon(0),
)
}
}
MatMulScheme::new(
&MatMul::new(&F16).unwrap(),
LayoutAttrs {
c: layout(c),
a: layout(a),
b: layout(b),
},
)
.unwrap()
.launch(
&(
c.physical_mut().as_mut_ptr(),
beta,
a.physical().as_ptr(),
b.physical().as_ptr(),
alpha,
),
&ThisThread,
);
}

#[cfg(detected_mkl)]
#[allow(unused)]
mod mkl {
use gemm::f16;
use kernel_lib::Matrix;
@@ -147,7 +55,6 @@ mod mkl {
Ordinary = 112,
}

#[allow(unused)]
extern "C" {
pub fn mkl_get_max_threads() -> c_int;
pub fn mkl_get_dynamic() -> c_int;
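Both the removed gemm-based implementation and the new operators-rs MatMul scheme follow the same BLAS-style contract, c = beta * c + alpha * (a x b), applied per batch element over the leading dimensions described in the doc comment. The sketch below is a minimal, self-contained reference of that contract in plain f32 with illustrative shapes and names; it does not use the project's Tensor type or the operators crate.

/// Reference semantics of `mat_mul` for one batch element:
/// c[i][j] = beta * c[i][j] + alpha * sum_p a[i][p] * b[p][j]
fn mat_mul_ref(
    c: &mut [f32], // m x n, row-major
    beta: f32,
    a: &[f32], // m x k, row-major
    b: &[f32], // k x n, row-major
    alpha: f32,
    (m, k, n): (usize, usize, usize),
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = beta * c[i * n + j] + alpha * acc;
        }
    }
}

fn main() {
    // a: [2, 3], b: [3, 2], c: [2, 2]
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    let mut c = [0.0; 4];
    mat_mul_ref(&mut c, 0.0, &a, &b, 1.0, (2, 3, 2));
    assert_eq!(c, [4.0, 5.0, 10.0, 11.0]);
}

The argument order (c, beta, a, b, alpha) mirrors the mat_mul signature shown in the diff, and the beta/alpha roles follow the removed gemm call, which used beta to scale the existing c (read_dst is beta != 0.) and alpha to scale the product.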
