From 6286fea73041e52ec6ff5541ba67338888c5dac2 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 20 Mar 2024 04:55:50 +0800 Subject: [PATCH] =?UTF-8?q?feat(transformer-cpu):=20=E6=B5=8B=E8=AF=95=20m?= =?UTF-8?q?kl=20=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- transformer-cpu/src/kernel/mat_mul.rs | 78 +++++++++++++++++++-------- transformer-cpu/src/lib.rs | 24 ++++----- transformer/src/blas.rs | 1 + 3 files changed, 68 insertions(+), 35 deletions(-) diff --git a/transformer-cpu/src/kernel/mat_mul.rs b/transformer-cpu/src/kernel/mat_mul.rs index 31dce21e..0e04f327 100644 --- a/transformer-cpu/src/kernel/mat_mul.rs +++ b/transformer-cpu/src/kernel/mat_mul.rs @@ -50,7 +50,21 @@ where assert_eq!(c.batch, a.batch); assert_eq!(c.batch, b.batch); + // const LAYER: usize = 40; + // static mut I: usize = 0; + // unsafe { + // if I == 0 { + // println!(); + // // #[cfg(detected_mkl)] + // // { + // // println!("MKL threads: {}", mkl::mkl_get_max_threads()); + // // println!("MKL dynamic: {}", mkl::mkl_get_dynamic()); + // // } + // } + // } + // let time = std::time::Instant::now(); match dt { + // DataType::F16 => mkl::gemm(dt, c, beta, alpha, a, b), DataType::F16 => gemm_as_blas::(c, beta, alpha, a, b), #[cfg(not(detected_mkl))] DataType::F32 => gemm_as_blas::(c, beta, alpha, a, b), @@ -58,6 +72,17 @@ where DataType::F32 => mkl::gemm(dt, c, beta, alpha, a, b), _ => unreachable!(), } + // unsafe { + // if I % 6 == 0 { + // println!(); + // } + // println!("{}:{} {}", I / 6, I % 6, time.elapsed().as_micros()); + // if I == LAYER * 6 { + // I = 0; + // } else { + // I += 1; + // } + // } } fn gemm_as_blas(c: Matrix, beta: f32, alpha: f32, a: Matrix, b: Matrix) { @@ -108,9 +133,10 @@ fn gemm_as_blas(c: Matrix, beta: f32, alpha: f32, a: Ma #[cfg(detected_mkl)] mod mkl { use gemm::f16; + use std::ffi::c_int; use tensor::DataType; use transformer::Matrix; - const COL_MAJOR: i32 = 102; + const COL_MAJOR: c_int = 102; #[derive(Clone, Copy, Debug, Eq, PartialEq)] #[repr(C)] @@ -120,39 +146,45 @@ mod mkl { Ordinary = 112, } + #[allow(unused)] extern "C" { + pub fn mkl_get_max_threads() -> c_int; + pub fn mkl_get_dynamic() -> c_int; + pub fn mkl_set_num_threads(nt: c_int); + pub fn mkl_set_num_threads_local(nt: c_int); + fn cblas_hgemm( - layout: i32, + layout: c_int, transa: CBLAS_TRANSPOSE, transb: CBLAS_TRANSPOSE, - m: i32, - n: i32, - k: i32, + m: c_int, + n: c_int, + k: c_int, alpha: f16, a: *const f16, - lda: i32, + lda: c_int, b: *const f16, - ldb: i32, + ldb: c_int, beta: f16, c: *mut f16, - ldc: i32, + ldc: c_int, ); fn cblas_sgemm( - layout: i32, + layout: c_int, transa: CBLAS_TRANSPOSE, transb: CBLAS_TRANSPOSE, - m: i32, - n: i32, - k: i32, + m: c_int, + n: c_int, + k: c_int, alpha: f32, a: *const f32, - lda: i32, + lda: c_int, b: *const f32, - ldb: i32, + ldb: c_int, beta: f32, c: *mut f32, - ldc: i32, + ldc: c_int, ); } @@ -183,10 +215,10 @@ mod mkl { DataType::F16 => unsafe { let alpha = f16::from_f32(alpha); let beta = f16::from_f32(beta); - for i in 0..batch { - let a_ = a.base.cast::().offset(i as isize * a.stride as isize); - let b_ = b.base.cast::().offset(i as isize * b.stride as isize); - let c_ = c.base.cast::().offset(i as isize * c.stride as isize); + for i in 0..batch as isize { + let a_ = a.base.cast::().offset(i * a.stride as isize); + let b_ = b.base.cast::().offset(i * b.stride as isize); + let c_ = c.base.cast::().offset(i * c.stride as isize); cblas_hgemm( COL_MAJOR, trans(&a), @@ -206,10 +238,10 @@ mod mkl { } }, DataType::F32 => unsafe { - for i in 0..batch { - let a_ = a.base.cast::().offset(i as isize * a.stride as isize); - let b_ = b.base.cast::().offset(i as isize * b.stride as isize); - let c_ = c.base.cast::().offset(i as isize * c.stride as isize); + for i in 0..batch as isize { + let a_ = a.base.cast::().offset(i * a.stride as isize); + let b_ = b.base.cast::().offset(i * b.stride as isize); + let c_ = c.base.cast::().offset(i * c.stride as isize); cblas_sgemm( COL_MAJOR, trans(&a), diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs index d31a376d..09ad69f1 100644 --- a/transformer-cpu/src/lib.rs +++ b/transformer-cpu/src/lib.rs @@ -151,9 +151,9 @@ impl Transformer { let v_att = v_cache.as_ref().slice(att_slice); let k_att = unsafe { k_att.map_physical(|u| &**u) }; let v_att = unsafe { v_att.map_physical(|u| &**u) }; - // println!("layer {layer} q attention:\n{}", q_att); - // println!("layer {layer} k attention:\n{}", k_att.access()); - // println!("layer {layer} v attention:\n{}", v_att.access()); + // println!("layer {layer} q attention:\n{q_att}"); + // println!("layer {layer} k attention:\n{k_att}"); + // println!("layer {layer} v attention:\n{v_att}"); let shape_att0 = &[nkvh, head_group * seq_len, att_len]; let shape_att1 = &[nkvh * head_group, seq_len, att_len]; @@ -171,26 +171,26 @@ impl Transformer { let wo = self.0.self_attn_o_proj(layer).transpose(&[1, 0]); mat_mul(&mut x0, 1., &x1, &wo, 1.); - // println!("layer {layer} o_proj:\n{}", x0.access()); + // println!("layer {layer} o_proj:\n{x0}"); let post_layernorm = self.0.post_attention_layernorm(layer); rms_norm(&mut x1, &x0, &post_layernorm, epsilon); - // println!("layer {layer} post norm:\n{}", x1.access()); + // println!("layer {layer} post norm:\n{x1}"); let w_gate_up = self.0.mlp_gate_up(layer).transpose(&[1, 0]); mat_mul(&mut gate_up, 0., &x1, &w_gate_up, 1.); let mut gate_up = gate_up.split(1, &[di as _, di as _]); let up = gate_up.pop().unwrap(); let mut gate = gate_up.pop().unwrap(); - // println!("layer {layer} gate:\n{}", gate.access()); - // println!("layer {layer} up:\n{}", up.access()); + // println!("layer {layer} gate:\n{gate}"); + // println!("layer {layer} up:\n{up}"); swiglu(&mut gate, &up); - // println!("layer {layer} swiglu:\n{}", gate.access()); + // println!("layer {layer} swiglu:\n{gate}"); let mlp_down = self.0.mlp_down(layer).transpose(&[1, 0]); mat_mul(&mut x0, 1., &gate, &mlp_down, 1.); - // println!("layer {layer} down:\n{}", x0.access()); + // println!("layer {layer} down:\n{x0}"); } let tokens = { @@ -218,7 +218,7 @@ impl Transformer { vec![0u8; (tokens.len * voc) as usize * dt.size()], ); let mut x = x0.slice(&[tokens, slice![all]]); - // println!("decode slice:\n{}", x.access()); + // println!("decode slice:\n{x}"); // 复制一个 x 以实现原地归一化 let x_ = unsafe { @@ -226,11 +226,11 @@ impl Transformer { .map_physical(|u| std::slice::from_raw_parts(u.as_ptr(), u.len())) }; rms_norm(&mut x, &x_, &self.0.model_norm(), epsilon); - // println!("model norm:\n{}", x.access()); + // println!("model norm:\n{x}"); let lm_head = self.0.lm_head().transpose(&[1, 0]); mat_mul(&mut logits, 0., &x, &lm_head, 1.); - // println!("logits:\n{}", logits.access()); + // println!("logits:\n{logits}"); macro_rules! sample { ($ty:ty) => {{ diff --git a/transformer/src/blas.rs b/transformer/src/blas.rs index 73dde096..631b2ee6 100644 --- a/transformer/src/blas.rs +++ b/transformer/src/blas.rs @@ -5,6 +5,7 @@ }; use tensor::Tensor; +#[derive(Clone, Debug)] pub struct Matrix { pub batch: c_int, pub stride: c_longlong,