feat(transformer-cpu): test MKL performance
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 19, 2024
1 parent 3f4e442 commit 6286fea
Showing 3 changed files with 68 additions and 35 deletions.
78 changes: 55 additions & 23 deletions transformer-cpu/src/kernel/mat_mul.rs
@@ -50,14 +50,39 @@ where

assert_eq!(c.batch, a.batch);
assert_eq!(c.batch, b.batch);
+// const LAYER: usize = 40;
+// static mut I: usize = 0;
+// unsafe {
+// if I == 0 {
+// println!();
+// // #[cfg(detected_mkl)]
+// // {
+// // println!("MKL threads: {}", mkl::mkl_get_max_threads());
+// // println!("MKL dynamic: {}", mkl::mkl_get_dynamic());
+// // }
+// }
+// }
+// let time = std::time::Instant::now();
match dt {
+// DataType::F16 => mkl::gemm(dt, c, beta, alpha, a, b),
DataType::F16 => gemm_as_blas::<f16>(c, beta, alpha, a, b),
#[cfg(not(detected_mkl))]
DataType::F32 => gemm_as_blas::<f32>(c, beta, alpha, a, b),
#[cfg(detected_mkl)]
DataType::F32 => mkl::gemm(dt, c, beta, alpha, a, b),
_ => unreachable!(),
}
+// unsafe {
+// if I % 6 == 0 {
+// println!();
+// }
+// println!("{}:{} {}", I / 6, I % 6, time.elapsed().as_micros());
+// if I == LAYER * 6 {
+// I = 0;
+// } else {
+// I += 1;
+// }
+// }
}
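Two things happen in this hunk: f32 matmuls are routed to MKL's CBLAS when the build detects it (`detected_mkl` is a compile-time cfg), while f16 stays on the pure-Rust `gemm` path, and a block of commented-out instrumentation stamps `Instant::now()` before the dispatch and prints the elapsed microseconds for each of the six mat_mul calls per transformer layer. A minimal sketch of the same timing idea as a reusable helper, so it can be switched on without uncommenting code (the `time_gemm` helper and its label format are assumptions, not part of the commit):

use std::time::Instant;

/// Hypothetical helper: run a matmul closure and report its latency.
/// `calls_per_layer` mirrors the commit's hard-coded 6 matmuls per layer.
fn time_gemm<R>(call_index: &mut usize, calls_per_layer: usize, f: impl FnOnce() -> R) -> R {
    let start = Instant::now();
    let result = f();
    let (layer, slot) = (*call_index / calls_per_layer, *call_index % calls_per_layer);
    println!("{layer}:{slot} {} us", start.elapsed().as_micros());
    *call_index += 1;
    result
}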

fn gemm_as_blas<T: 'static + BetweenF32>(c: Matrix, beta: f32, alpha: f32, a: Matrix, b: Matrix) {
@@ -108,9 +133,10 @@ fn gemm_as_blas<T: 'static + BetweenF32>(c: Matrix, beta: f32, alpha: f32, a: Ma
#[cfg(detected_mkl)]
mod mkl {
use gemm::f16;
+use std::ffi::c_int;
use tensor::DataType;
use transformer::Matrix;
-const COL_MAJOR: i32 = 102;
+const COL_MAJOR: c_int = 102;

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
@@ -120,39 +146,45 @@ mod mkl {
Ordinary = 112,
}

+#[allow(unused)]
extern "C" {
+pub fn mkl_get_max_threads() -> c_int;
+pub fn mkl_get_dynamic() -> c_int;
+pub fn mkl_set_num_threads(nt: c_int);
+pub fn mkl_set_num_threads_local(nt: c_int);

fn cblas_hgemm(
-layout: i32,
+layout: c_int,
transa: CBLAS_TRANSPOSE,
transb: CBLAS_TRANSPOSE,
-m: i32,
-n: i32,
-k: i32,
+m: c_int,
+n: c_int,
+k: c_int,
alpha: f16,
a: *const f16,
-lda: i32,
+lda: c_int,
b: *const f16,
-ldb: i32,
+ldb: c_int,
beta: f16,
c: *mut f16,
-ldc: i32,
+ldc: c_int,
);

fn cblas_sgemm(
-layout: i32,
+layout: c_int,
transa: CBLAS_TRANSPOSE,
transb: CBLAS_TRANSPOSE,
-m: i32,
-n: i32,
-k: i32,
+m: c_int,
+n: c_int,
+k: c_int,
alpha: f32,
a: *const f32,
-lda: i32,
+lda: c_int,
b: *const f32,
-ldb: i32,
+ldb: c_int,
beta: f32,
c: *mut f32,
-ldc: i32,
+ldc: c_int,
);
}
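The integer constants follow the standard CBLAS ABI: 101/102 are `CblasRowMajor`/`CblasColMajor` (hence `COL_MAJOR = 102`), and the transpose flags are 111 for no transpose, 112 for an ordinary transpose, and 113 for a conjugate transpose (hence `Ordinary = 112`). The enum's hidden variants are presumably the other two flags; a sketch under that assumption (the non-`Ordinary` variant names are guesses):

/// CBLAS transpose flags; the discriminants are fixed by the CBLAS ABI,
/// the variant names are this crate's own.
#[allow(non_camel_case_types, dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
pub enum CBLAS_TRANSPOSE {
    None = 111,       // CblasNoTrans
    Ordinary = 112,   // CblasTrans
    Conjugated = 113, // CblasConjTrans
}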

@@ -183,10 +215,10 @@ mod mkl {
DataType::F16 => unsafe {
let alpha = f16::from_f32(alpha);
let beta = f16::from_f32(beta);
-for i in 0..batch {
-let a_ = a.base.cast::<f16>().offset(i as isize * a.stride as isize);
-let b_ = b.base.cast::<f16>().offset(i as isize * b.stride as isize);
-let c_ = c.base.cast::<f16>().offset(i as isize * c.stride as isize);
+for i in 0..batch as isize {
+let a_ = a.base.cast::<f16>().offset(i * a.stride as isize);
+let b_ = b.base.cast::<f16>().offset(i * b.stride as isize);
+let c_ = c.base.cast::<f16>().offset(i * c.stride as isize);
cblas_hgemm(
COL_MAJOR,
trans(&a),
@@ -206,10 +238,10 @@ mod mkl {
}
},
DataType::F32 => unsafe {
-for i in 0..batch {
-let a_ = a.base.cast::<f32>().offset(i as isize * a.stride as isize);
-let b_ = b.base.cast::<f32>().offset(i as isize * b.stride as isize);
-let c_ = c.base.cast::<f32>().offset(i as isize * c.stride as isize);
+for i in 0..batch as isize {
+let a_ = a.base.cast::<f32>().offset(i * a.stride as isize);
+let b_ = b.base.cast::<f32>().offset(i * b.stride as isize);
+let c_ = c.base.cast::<f32>().offset(i * c.stride as isize);
cblas_sgemm(
COL_MAJOR,
trans(&a),
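The batched gemm stays a plain loop: matrix `i` of the batch lives `i * stride` elements past `base`, and one CBLAS call is issued per matrix. The freshly declared `mkl_get_max_threads`/`mkl_set_num_threads` bindings exist so the benchmark can inspect and pin MKL's threading, matching the commented-out prints in mat_mul; a sketch of how they might be used before a timing run (the wrapper function itself is an assumption, not in the commit):

/// Hypothetical benchmark preamble using the bindings declared above.
/// Pinning the thread count keeps repeated timing runs comparable.
fn pin_mkl_threads() {
    unsafe {
        println!("MKL threads: {}", mkl_get_max_threads());
        // 1 means MKL may silently use fewer threads than the cap.
        println!("MKL dynamic: {}", mkl_get_dynamic());
        mkl_set_num_threads(mkl_get_max_threads());
    }
}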
24 changes: 12 additions & 12 deletions transformer-cpu/src/lib.rs
@@ -151,9 +151,9 @@ impl Transformer {
let v_att = v_cache.as_ref().slice(att_slice);
let k_att = unsafe { k_att.map_physical(|u| &**u) };
let v_att = unsafe { v_att.map_physical(|u| &**u) };
// println!("layer {layer} q attention:\n{}", q_att);
// println!("layer {layer} k attention:\n{}", k_att.access());
// println!("layer {layer} v attention:\n{}", v_att.access());
// println!("layer {layer} q attention:\n{q_att}");
// println!("layer {layer} k attention:\n{k_att}");
// println!("layer {layer} v attention:\n{v_att}");

let shape_att0 = &[nkvh, head_group * seq_len, att_len];
let shape_att1 = &[nkvh * head_group, seq_len, att_len];
@@ -171,26 +171,26 @@ impl Transformer {

let wo = self.0.self_attn_o_proj(layer).transpose(&[1, 0]);
mat_mul(&mut x0, 1., &x1, &wo, 1.);
// println!("layer {layer} o_proj:\n{}", x0.access());
// println!("layer {layer} o_proj:\n{x0}");

let post_layernorm = self.0.post_attention_layernorm(layer);
rms_norm(&mut x1, &x0, &post_layernorm, epsilon);
// println!("layer {layer} post norm:\n{}", x1.access());
// println!("layer {layer} post norm:\n{x1}");

let w_gate_up = self.0.mlp_gate_up(layer).transpose(&[1, 0]);
mat_mul(&mut gate_up, 0., &x1, &w_gate_up, 1.);
let mut gate_up = gate_up.split(1, &[di as _, di as _]);
let up = gate_up.pop().unwrap();
let mut gate = gate_up.pop().unwrap();
// println!("layer {layer} gate:\n{}", gate.access());
// println!("layer {layer} up:\n{}", up.access());
// println!("layer {layer} gate:\n{gate}");
// println!("layer {layer} up:\n{up}");

swiglu(&mut gate, &up);
// println!("layer {layer} swiglu:\n{}", gate.access());
// println!("layer {layer} swiglu:\n{gate}");

let mlp_down = self.0.mlp_down(layer).transpose(&[1, 0]);
mat_mul(&mut x0, 1., &gate, &mlp_down, 1.);
// println!("layer {layer} down:\n{}", x0.access());
// println!("layer {layer} down:\n{x0}");
}

let tokens = {
@@ -218,19 +218,19 @@ impl Transformer {
vec![0u8; (tokens.len * voc) as usize * dt.size()],
);
let mut x = x0.slice(&[tokens, slice![all]]);
// println!("decode slice:\n{}", x.access());
// println!("decode slice:\n{x}");

// Duplicate x so the normalization can be done in place
let x_ = unsafe {
x.as_ref()
.map_physical(|u| std::slice::from_raw_parts(u.as_ptr(), u.len()))
};
rms_norm(&mut x, &x_, &self.0.model_norm(), epsilon);
// println!("model norm:\n{}", x.access());
// println!("model norm:\n{x}");

let lm_head = self.0.lm_head().transpose(&[1, 0]);
mat_mul(&mut logits, 0., &x, &lm_head, 1.);
// println!("logits:\n{}", logits.access());
// println!("logits:\n{logits}");

macro_rules! sample {
($ty:ty) => {{
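All twelve changed lines in this file are commented-out debug prints: the old `"{}", x.access()` form becomes a Rust 2021 inline capture like `{x0}`, which only works if the tensor views now implement `Display` directly. The surrounding logic is unchanged; in particular, the decode path still conjures a second, immutable view of `x`'s bytes via `from_raw_parts` so that `rms_norm(&mut x, &x_, ...)` can normalize in place, which is sound as long as the kernel reads the whole row (for the mean square) before overwriting it. A sketch of that computation on a plain `f32` slice, learned weight omitted (the helper is illustrative, not this crate's API):

/// In-place RMSNorm on one row: y = x / sqrt(mean(x^2) + eps).
/// (The real kernel also multiplies by a learned weight vector.)
fn rms_norm_in_place(row: &mut [f32], eps: f32) {
    let mean_sq = row.iter().map(|x| x * x).sum::<f32>() / row.len() as f32;
    let scale = (mean_sq + eps).sqrt().recip();
    for x in row.iter_mut() {
        *x *= scale;
    }
}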
1 change: 1 addition & 0 deletions transformer/src/blas.rs
@@ -5,6 +5,7 @@
};
use tensor::Tensor;

+#[derive(Clone, Debug)]
pub struct Matrix {
pub batch: c_int,
pub stride: c_longlong,
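The single blas.rs change derives `Clone` and `Debug` for the `Matrix` descriptor, so it can be duplicated and dumped while profiling. Only `batch` and `stride` are visible here, but `mkl::gemm` also reads `base` and derives the transpose and m/n/k/ld arguments from the descriptor, so the full struct presumably looks roughly like the sketch below (every field except `batch` and `stride` is an assumption):

use std::ffi::{c_int, c_longlong, c_void};

/// Descriptor for a strided batch of matrices handed to CBLAS.
#[derive(Clone, Debug)]
pub struct Matrix {
    pub batch: c_int,       // number of matrices in the batch
    pub stride: c_longlong, // elements between consecutive matrices
    pub r: c_int,           // rows (assumed field name)
    pub c: c_int,           // columns (assumed field name)
    pub ld: c_int,          // leading dimension (assumed field name)
    pub base: *mut c_void,  // pointer to the first matrix
}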
