From 015484f57a44efd5e0fd44a4216c0f1eac8dc978 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 15 Jan 2025 12:37:30 +0800 Subject: [PATCH] =?UTF-8?q?feat(llama):=20=E6=94=AF=E6=8C=81=20qwen2=20?= =?UTF-8?q?=E6=8E=A8=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- models/llama/common-cpu/src/lib.rs | 23 ++++++++++++++--------- models/llama/common/src/compute.rs | 24 +++++++++++++++++++++++- models/llama/common/src/lib.rs | 5 +++++ models/llama/common/src/storage.rs | 3 ++- models/llama/infini/src/lib.rs | 2 ++ models/llama/nvidia-gpu/src/lib.rs | 1 + models/llama/opencl/src/lib.rs | 1 + 7 files changed, 48 insertions(+), 11 deletions(-) diff --git a/models/llama/common-cpu/src/lib.rs b/models/llama/common-cpu/src/lib.rs index f2f8444..f879d53 100644 --- a/models/llama/common-cpu/src/lib.rs +++ b/models/llama/common-cpu/src/lib.rs @@ -225,7 +225,7 @@ impl WeightLoader for Weights<'_> { let LlamaBlkStorage { attn_norm, attn_qkv, - attn_qkv_bias: _, + attn_qkv_bias, attn_o, ffn_norm, ffn_gate_inp, @@ -233,18 +233,21 @@ impl WeightLoader for Weights<'_> { ffn_down, } = &blks[iblk]; - use BlkWeight::{AttnNorm, AttnO, AttnQKV, FfnDown, FfnGateInp, FfnGateUp, FfnNorm}; + use BlkWeight::{ + AttnNorm, AttnO, AttnQKV, AttnQKVBias, FfnDown, FfnGateInp, FfnGateUp, FfnNorm, + }; use Dequant::{Borrowed, Cached}; #[rustfmt::skip] match which { - AttnNorm => return Borrowed(attn_norm ), - AttnQKV if dt_mat == dt_embd => return Borrowed(attn_qkv ), - AttnO if dt_mat == dt_embd => return Borrowed(attn_o ), - FfnNorm => return Borrowed(ffn_norm ), - FfnGateInp if dt_mat == dt_embd => return Borrowed(ffn_gate_inp), - FfnGateUp if dt_mat == dt_embd => return Borrowed(ffn_gate_up ), - FfnDown if dt_mat == dt_embd => return Borrowed(ffn_down ), + AttnNorm => return Borrowed(attn_norm ), + AttnQKV if dt_mat == dt_embd => return Borrowed(attn_qkv ), + AttnQKVBias if dt_mat == dt_embd => return Borrowed(attn_qkv_bias), + AttnO if dt_mat == dt_embd => return Borrowed(attn_o ), + FfnNorm => return Borrowed(ffn_norm ), + FfnGateInp if dt_mat == dt_embd => return Borrowed(ffn_gate_inp ), + FfnGateUp if dt_mat == dt_embd => return Borrowed(ffn_gate_up ), + FfnDown if dt_mat == dt_embd => return Borrowed(ffn_down ), _ => {} }; @@ -266,6 +269,7 @@ impl WeightLoader for Weights<'_> { *cached_weight_iblk = iblk; match which { AttnQKV => dequant(dt_mat, dt_embd, attn_qkv, &mut cache[..size_qkv]), + AttnQKVBias => todo!("dequant attn qkv bias"), AttnO => dequant(dt_mat, dt_embd, attn_o, &mut cache[..size_o]), FfnGateInp => todo!("dequant ffn gate inp"), FfnGateUp | FfnDown => { @@ -285,6 +289,7 @@ impl WeightLoader for Weights<'_> { weight_cache.borrow(), match which { AttnQKV => 0..size_qkv, + AttnQKVBias => todo!("dequant attn qkv bias"), AttnO => 0..size_o, FfnGateInp => todo!("dequant ffn gate inp"), FfnGateUp => 0..size_gate_up, diff --git a/models/llama/common/src/compute.rs b/models/llama/common/src/compute.rs index 99f95fc..ef405e1 100644 --- a/models/llama/common/src/compute.rs +++ b/models/llama/common/src/compute.rs @@ -57,6 +57,7 @@ pub trait Operators { pub enum BlkWeight { AttnNorm, AttnQKV, + AttnQKVBias, AttnO, FfnNorm, FfnGateInp, @@ -151,6 +152,7 @@ where max_att_len, } = args; let LlamaMeta { + attn_qkv_bias, nblk, nh, nkvh, @@ -205,8 +207,16 @@ where let (buf, workspace) = workspace.split_at_mut(*qkv.get()); let mut qkv = qkv.clone().map(|_| buf); + let qkv_add = if attn_qkv_bias { + let bias = self.weights.attn_qkv_bias(iblk, queue).broadcast(0, nt); + self.rearrange(&mut qkv, &bias, workspace, queue_alloc)?; + 1. + } else { + 0. + }; + let w = self.weights.attn_qkv(iblk, queue); - self.mat_mul(&mut qkv, 0., &x1, &w, 1., workspace, queue_alloc)?; + self.mat_mul(&mut qkv, qkv_add, &x1, &w, 1., workspace, queue_alloc)?; drop(w); let qkv = qkv.tile(1, &[nh + nkvh + nkvh, dh]); @@ -593,6 +603,7 @@ where struct WeightDecorator { norm: Tensor, attn_qkv: Tensor, + attn_qkv_bias: Tensor, attn_o: Tensor, ffn_gate_inp: Tensor, ffn_gate_up: Tensor, @@ -608,6 +619,7 @@ impl LlamaMeta { WeightDecorator { norm: self.norm(), attn_qkv: self.attn_qkv(Computation), + attn_qkv_bias: self.attn_qkv_bias(Computation), attn_o: self.attn_o(Computation), ffn_gate_inp: self.ffn_gate_inp(Computation), ffn_gate_up: self.ffn_gate_up(Computation), @@ -640,6 +652,16 @@ impl WeightDecorator { self.attn_qkv.clone().map(|_| w) } + #[inline] + pub fn attn_qkv_bias<'a>( + &'a self, + iblk: usize, + queue: &'a QueueOf, + ) -> Tensor> { + let w = self.weights.load_blk(BlkWeight::AttnQKVBias, iblk, queue); + self.attn_qkv_bias.clone().map(|_| w) + } + #[inline] pub fn attn_o<'a>( &'a self, diff --git a/models/llama/common/src/lib.rs b/models/llama/common/src/lib.rs index 800e988..d9c8a0d 100644 --- a/models/llama/common/src/lib.rs +++ b/models/llama/common/src/lib.rs @@ -99,6 +99,11 @@ impl LlamaMeta { self.mat((nh + nkvh + nkvh) * dh, d, usage) } + pub fn attn_qkv_bias(&self, usage: TensorUsage) -> Tensor { + let &Self { nh, nkvh, dh, .. } = self; + self.mat((nh + nkvh + nkvh) * dh, 1, usage) + } + pub fn attn_o(&self, usage: TensorUsage) -> Tensor { let &Self { nh, d, dh, .. } = self; self.mat(d, nh * dh, usage) diff --git a/models/llama/common/src/storage.rs b/models/llama/common/src/storage.rs index 1e3301e..236a4fc 100644 --- a/models/llama/common/src/storage.rs +++ b/models/llama/common/src/storage.rs @@ -176,7 +176,8 @@ impl<'w> BlkStorage<&'w [u8]> { attn_qkv_bias: if len == count { borrow(self.attn_qkv_bias) } else { - todo!() + let part = self.attn_qkv_bias.len() / count; + borrow(&self.attn_qkv_bias[start * part..][..len * part]) }, attn_o: if len == count { borrow(self.attn_o) diff --git a/models/llama/infini/src/lib.rs b/models/llama/infini/src/lib.rs index 1980945..ca3ff9d 100644 --- a/models/llama/infini/src/lib.rs +++ b/models/llama/infini/src/lib.rs @@ -100,6 +100,7 @@ impl Weights { load! { attn_norm attn_qkv + attn_qkv_bias attn_o ffn_norm ffn_gate_inp @@ -162,6 +163,7 @@ impl WeightLoader for Weights { match which { BlkWeight::AttnNorm => &blk.attn_norm, BlkWeight::AttnQKV => &blk.attn_qkv, + BlkWeight::AttnQKVBias => &blk.attn_qkv_bias, BlkWeight::AttnO => &blk.attn_o, BlkWeight::FfnNorm => &blk.ffn_norm, BlkWeight::FfnGateInp => &blk.ffn_gate_inp, diff --git a/models/llama/nvidia-gpu/src/lib.rs b/models/llama/nvidia-gpu/src/lib.rs index 13d0c97..11222dd 100644 --- a/models/llama/nvidia-gpu/src/lib.rs +++ b/models/llama/nvidia-gpu/src/lib.rs @@ -318,6 +318,7 @@ impl<'ctx> WeightLoader for Weights<'ctx> { let cache = match which { BlkWeight::AttnNorm => &self.blks.attn_norm, BlkWeight::AttnQKV => &self.blks.attn_qkv, + BlkWeight::AttnQKVBias => &self.blks.attn_qkv_bias, BlkWeight::AttnO => &self.blks.attn_o, BlkWeight::FfnNorm => &self.blks.ffn_norm, BlkWeight::FfnGateInp => &self.blks.ffn_gate_inp, diff --git a/models/llama/opencl/src/lib.rs b/models/llama/opencl/src/lib.rs index 2351579..6c4d88a 100644 --- a/models/llama/opencl/src/lib.rs +++ b/models/llama/opencl/src/lib.rs @@ -126,6 +126,7 @@ impl WeightLoader for Weights { match which { BlkWeight::AttnNorm => &blk.attn_norm, BlkWeight::AttnQKV => &blk.attn_qkv, + BlkWeight::AttnQKVBias => &blk.attn_qkv_bias, BlkWeight::AttnO => &blk.attn_o, BlkWeight::FfnNorm => &blk.ffn_norm, BlkWeight::FfnGateInp => &blk.ffn_gate_inp,