
Commit

feat(llama): support qwen2 inference
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Jan 15, 2025
1 parent e2f4b04 commit 015484f
Showing 7 changed files with 48 additions and 11 deletions.
models/llama/common-cpu/src/lib.rs: 23 changes (14 additions & 9 deletions)
@@ -225,26 +225,29 @@ impl WeightLoader for Weights<'_> {
let LlamaBlkStorage {
attn_norm,
attn_qkv,
attn_qkv_bias: _,
attn_qkv_bias,
attn_o,
ffn_norm,
ffn_gate_inp,
ffn_gate_up,
ffn_down,
} = &blks[iblk];

use BlkWeight::{AttnNorm, AttnO, AttnQKV, FfnDown, FfnGateInp, FfnGateUp, FfnNorm};
use BlkWeight::{
AttnNorm, AttnO, AttnQKV, AttnQKVBias, FfnDown, FfnGateInp, FfnGateUp, FfnNorm,
};
use Dequant::{Borrowed, Cached};

#[rustfmt::skip]
match which {
AttnNorm => return Borrowed(attn_norm ),
AttnQKV if dt_mat == dt_embd => return Borrowed(attn_qkv ),
AttnO if dt_mat == dt_embd => return Borrowed(attn_o ),
FfnNorm => return Borrowed(ffn_norm ),
FfnGateInp if dt_mat == dt_embd => return Borrowed(ffn_gate_inp),
FfnGateUp if dt_mat == dt_embd => return Borrowed(ffn_gate_up ),
FfnDown if dt_mat == dt_embd => return Borrowed(ffn_down ),
AttnNorm => return Borrowed(attn_norm ),
AttnQKV if dt_mat == dt_embd => return Borrowed(attn_qkv ),
AttnQKVBias if dt_mat == dt_embd => return Borrowed(attn_qkv_bias),
AttnO if dt_mat == dt_embd => return Borrowed(attn_o ),
FfnNorm => return Borrowed(ffn_norm ),
FfnGateInp if dt_mat == dt_embd => return Borrowed(ffn_gate_inp ),
FfnGateUp if dt_mat == dt_embd => return Borrowed(ffn_gate_up ),
FfnDown if dt_mat == dt_embd => return Borrowed(ffn_down ),
_ => {}
};

@@ -266,6 +269,7 @@ impl WeightLoader for Weights<'_> {
*cached_weight_iblk = iblk;
match which {
AttnQKV => dequant(dt_mat, dt_embd, attn_qkv, &mut cache[..size_qkv]),
AttnQKVBias => todo!("dequant attn qkv bias"),
AttnO => dequant(dt_mat, dt_embd, attn_o, &mut cache[..size_o]),
FfnGateInp => todo!("dequant ffn gate inp"),
FfnGateUp | FfnDown => {
@@ -285,6 +289,7 @@ impl WeightLoader for Weights<'_> {
weight_cache.borrow(),
match which {
AttnQKV => 0..size_qkv,
AttnQKVBias => todo!("dequant attn qkv bias"),
AttnO => 0..size_o,
FfnGateInp => todo!("dequant ffn gate inp"),
FfnGateUp => 0..size_gate_up,
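In the CPU backend the new bias follows the same borrow-vs-dequantize split as the other weights: it is handed out as a borrowed slice only when the stored matrix dtype already matches the compute dtype (`dt_mat == dt_embd`); both dequantization branches are left as `todo!("dequant attn qkv bias")`, so quantized checkpoints cannot use the bias after this commit. A minimal sketch of that dispatch (illustrative types, not this crate's API):

```rust
// Borrow the raw bias when no conversion is needed; the dequantized path is
// intentionally unimplemented, mirroring the todo!() arms in the diff above.
enum Dequant<'a> {
    Borrowed(&'a [u8]),
    // Stands in for the dequant-cache path the real loader uses.
    Cached(Vec<u8>),
}

fn load_bias(dt_mat: u32, dt_embd: u32, raw: &[u8]) -> Dequant<'_> {
    if dt_mat == dt_embd {
        Dequant::Borrowed(raw)
    } else {
        todo!("dequant attn qkv bias")
    }
}
```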
models/llama/common/src/compute.rs: 24 changes (23 additions & 1 deletion)
@@ -57,6 +57,7 @@ pub trait Operators {
pub enum BlkWeight {
AttnNorm,
AttnQKV,
AttnQKVBias,
AttnO,
FfnNorm,
FfnGateInp,
@@ -151,6 +152,7 @@
max_att_len,
} = args;
let LlamaMeta {
attn_qkv_bias,
nblk,
nh,
nkvh,
@@ -205,8 +207,16 @@
let (buf, workspace) = workspace.split_at_mut(*qkv.get());
let mut qkv = qkv.clone().map(|_| buf);

let qkv_add = if attn_qkv_bias {
let bias = self.weights.attn_qkv_bias(iblk, queue).broadcast(0, nt);
self.rearrange(&mut qkv, &bias, workspace, queue_alloc)?;
1.
} else {
0.
};

let w = self.weights.attn_qkv(iblk, queue);
self.mat_mul(&mut qkv, 0., &x1, &w, 1., workspace, queue_alloc)?;
self.mat_mul(&mut qkv, qkv_add, &x1, &w, 1., workspace, queue_alloc)?;
drop(w);

let qkv = qkv.tile(1, &[nh + nkvh + nkvh, dh]);
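A note on the hunk above: the QKV bias is folded into the projection GEMM rather than getting its own add kernel. When `attn_qkv_bias` is set, the bias row is broadcast across the `nt` token rows of the `qkv` buffer via `rearrange`, and the following `mat_mul` runs with accumulate factor `qkv_add = 1.`, so `x1 · W` is added on top of the pre-filled bias; without a bias the factor stays `0.` and the buffer is simply overwritten as before. A minimal standalone sketch of the same arithmetic (plain loops and illustrative names, not this crate's API):

```rust
// Sketch: qkv = x · w + bias, computed by pre-filling the output with the
// broadcast bias ("rearrange") and then accumulating the matrix product with
// beta = 1 ("mat_mul" with qkv_add = 1.).
fn qkv_with_bias(
    x: &[f32],    // nt × d, row-major
    w: &[f32],    // d × m, row-major
    bias: &[f32], // m, one value per fused Q/K/V output feature
    nt: usize,
    d: usize,
    m: usize,
) -> Vec<f32> {
    // Broadcast the bias over every token row.
    let mut qkv: Vec<f32> = (0..nt).flat_map(|_| bias.iter().copied()).collect();
    // Accumulate x · w on top of the pre-filled bias.
    for i in 0..nt {
        for k in 0..d {
            let xik = x[i * d + k];
            for j in 0..m {
                qkv[i * m + j] += xik * w[k * m + j];
            }
        }
    }
    qkv
}
```

This reuses the existing `rearrange` and `mat_mul` operators instead of introducing a dedicated bias-add kernel.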
@@ -593,6 +603,7 @@
struct WeightDecorator<W> {
norm: Tensor<usize>,
attn_qkv: Tensor<usize>,
attn_qkv_bias: Tensor<usize>,
attn_o: Tensor<usize>,
ffn_gate_inp: Tensor<usize>,
ffn_gate_up: Tensor<usize>,
@@ -608,6 +619,7 @@ impl LlamaMeta {
WeightDecorator {
norm: self.norm(),
attn_qkv: self.attn_qkv(Computation),
attn_qkv_bias: self.attn_qkv_bias(Computation),
attn_o: self.attn_o(Computation),
ffn_gate_inp: self.ffn_gate_inp(Computation),
ffn_gate_up: self.ffn_gate_up(Computation),
@@ -640,6 +652,16 @@ impl<W: WeightLoader> WeightDecorator<W> {
self.attn_qkv.clone().map(|_| w)
}

#[inline]
pub fn attn_qkv_bias<'a>(
&'a self,
iblk: usize,
queue: &'a QueueOf<W::Hardware>,
) -> Tensor<W::Weight<'a>> {
let w = self.weights.load_blk(BlkWeight::AttnQKVBias, iblk, queue);
self.attn_qkv_bias.clone().map(|_| w)
}

#[inline]
pub fn attn_o<'a>(
&'a self,
models/llama/common/src/lib.rs: 5 changes (5 additions & 0 deletions)
@@ -99,6 +99,11 @@ impl LlamaMeta {
self.mat((nh + nkvh + nkvh) * dh, d, usage)
}

pub fn attn_qkv_bias(&self, usage: TensorUsage) -> Tensor<usize> {
let &Self { nh, nkvh, dh, .. } = self;
self.mat((nh + nkvh + nkvh) * dh, 1, usage)
}

pub fn attn_o(&self, usage: TensorUsage) -> Tensor<usize> {
let &Self { nh, d, dh, .. } = self;
self.mat(d, nh * dh, usage)
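For reference, the new meta entry above declares `attn_qkv_bias` as a `((nh + 2·nkvh)·dh) × 1` matrix, i.e. one scalar per output row of the fused Q/K/V projection. A quick sanity check with Qwen2-7B-like head counts (the concrete numbers are assumptions for illustration, not taken from this repo):

```rust
// Assumed Qwen2-7B-style configuration: 28 query heads, 4 KV heads, head dim 128.
fn main() {
    let (nh, nkvh, dh) = (28usize, 4, 128);
    let qkv_rows = (nh + nkvh + nkvh) * dh;
    assert_eq!(qkv_rows, 4608); // 3584 for Q + 512 for K + 512 for V
    println!("attn_qkv_bias shape: ({qkv_rows}, 1)");
}
```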
models/llama/common/src/storage.rs: 3 changes (2 additions & 1 deletion)
@@ -176,7 +176,8 @@ impl<'w> BlkStorage<&'w [u8]> {
attn_qkv_bias: if len == count {
borrow(self.attn_qkv_bias)
} else {
todo!()
let part = self.attn_qkv_bias.len() / count;
borrow(&self.attn_qkv_bias[start * part..][..len * part])
},
attn_o: if len == count {
borrow(self.attn_o)
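The storage.rs hunk replaces the old `todo!()` with the same distribution rule the other per-block tensors use: the flat bias buffer is cut into `count` equal parts, and a shard covering `len` of them starting at `start` takes the corresponding contiguous slice. Schematically (simplified, without the crate's `borrow` wrapper):

```rust
// Take the [start, start + len) parts out of `count` equal parts of a flat
// byte buffer, mirroring the attn_qkv_bias arm above.
fn shard(buf: &[u8], start: usize, len: usize, count: usize) -> &[u8] {
    if len == count {
        buf // the whole tensor stays on one device
    } else {
        let part = buf.len() / count;
        &buf[start * part..][..len * part]
    }
}
```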
models/llama/infini/src/lib.rs: 2 changes (2 additions & 0 deletions)
@@ -100,6 +100,7 @@ impl Weights {
load! {
attn_norm
attn_qkv
attn_qkv_bias
attn_o
ffn_norm
ffn_gate_inp
@@ -162,6 +163,7 @@ impl WeightLoader for Weights {
match which {
BlkWeight::AttnNorm => &blk.attn_norm,
BlkWeight::AttnQKV => &blk.attn_qkv,
BlkWeight::AttnQKVBias => &blk.attn_qkv_bias,
BlkWeight::AttnO => &blk.attn_o,
BlkWeight::FfnNorm => &blk.ffn_norm,
BlkWeight::FfnGateInp => &blk.ffn_gate_inp,
models/llama/nvidia-gpu/src/lib.rs: 1 change (1 addition & 0 deletions)
@@ -318,6 +318,7 @@ impl<'ctx> WeightLoader for Weights<'ctx> {
let cache = match which {
BlkWeight::AttnNorm => &self.blks.attn_norm,
BlkWeight::AttnQKV => &self.blks.attn_qkv,
BlkWeight::AttnQKVBias => &self.blks.attn_qkv_bias,
BlkWeight::AttnO => &self.blks.attn_o,
BlkWeight::FfnNorm => &self.blks.ffn_norm,
BlkWeight::FfnGateInp => &self.blks.ffn_gate_inp,
models/llama/opencl/src/lib.rs: 1 change (1 addition & 0 deletions)
@@ -126,6 +126,7 @@ impl WeightLoader for Weights {
match which {
BlkWeight::AttnNorm => &blk.attn_norm,
BlkWeight::AttnQKV => &blk.attn_qkv,
BlkWeight::AttnQKVBias => &blk.attn_qkv_bias,
BlkWeight::AttnO => &blk.attn_o,
BlkWeight::FfnNorm => &blk.ffn_norm,
BlkWeight::FfnGateInp => &blk.ffn_gate_inp,
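The infini, nvidia-gpu, and opencl changes are the mechanical consequence of the new enum variant: each backend's `WeightLoader` match over `which` must now resolve `BlkWeight::AttnQKVBias` to the per-block buffer it loaded, otherwise the match would no longer be exhaustive. Roughly (field and type names simplified, other arms elided):

```rust
// Illustrative only: how each backend maps the new variant onto the buffer
// it holds for the current block.
enum BlkWeight { AttnQKV, AttnQKVBias /* ... other variants elided */ }

struct Blk {
    attn_qkv: Vec<u8>,
    attn_qkv_bias: Vec<u8>,
}

fn select(which: BlkWeight, blk: &Blk) -> &[u8] {
    match which {
        BlkWeight::AttnQKV => &blk.attn_qkv,
        BlkWeight::AttnQKVBias => &blk.attn_qkv_bias,
    }
}
```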
