feat: add MLA
onenewcode committed Feb 24, 2025
1 parent c469a04 commit ad2ff0b
Showing 3 changed files with 101 additions and 74 deletions.
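The diff swaps the MiniCPM3 worker's dense attention path for multi-head latent attention (MLA): rather than materializing full per-head K/V, it keeps the low-rank latent kv_lora and absorbs the up-projection weights into the query and output paths (the q_absorb / out_absorb halves of attn_kvb in compute.rs). As a reading of the code, assuming the usual DeepSeek-V2-style MLA formulation, the per-head computation sketched in LaTeX (symbols mirror the tensor names in the diff; this is an interpretation, not the commit's documentation):

\tilde{q}_h = q_h^{\mathrm{nope}} W_h^{UK}                                              % mat_mul with q_absorb
s_{h,t} = \bigl( \tilde{q}_h c_t^{\top} + q_h^{\mathrm{rope}} (k_t^{\mathrm{rope}})^{\top} \bigr) / \sqrt{d_k}   % c_t = kv_lora row t
o_h = \operatorname{softmax}(s_h)\, C\, W_h^{UV}                                        % C stacks the kv_lora rows; W^{UV} is out_absorb

Because C is shared across heads (the kv_lora broadcast in the code), this formulation caches d_kv_lora + d_h values per token instead of n_h * (d_k + d_v).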
3 changes: 2 additions & 1 deletion models/minicpm3/common-cpu/src/lib.rs
@@ -35,14 +35,15 @@ where
type Hardware = Cpu;
type TopoNode = N;
type Rope = op!(rope);
type Attention = op!(attention);
type AttentionMLA = op!(attention_mla);
type RmsNorm = op!(rms_norm);
type Add = op!(add);
type MatMul = op!(mat_mul);
type Swiglu = op!(swiglu);
type Rearrange = op!(rearrange);
type Scale = op!(scale);
type AttnKVCached = op!(attention_kv_cached);
type FuesdSoftmax = op!(fuesd_softmax);
type AllReduce = R;

fn debug<T>(tensor: &Tensor<T>, _queue: &QueueOf<Self::Hardware>)
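The lib.rs change above only rebinds the CPU worker's operator set: Attention gives way to AttentionMLA, and Scale plus FuesdSoftmax bindings are added. A stripped-down sketch of the associated-type pattern at work; Kernel, CpuAttentionMla, and CpuOps are hypothetical illustrations, not the crate's actual items:

trait Kernel {
    fn launch(&self);
}

#[derive(Default)]
struct CpuAttentionMla;

impl Kernel for CpuAttentionMla {
    fn launch(&self) { /* dispatch the MLA attention kernel */ }
}

// A worker is generic over an operator bundle, so swapping the attention
// implementation is a one-line change to an associated type.
trait Operators {
    type AttentionMla: Kernel + Default;
}

struct CpuOps;

impl Operators for CpuOps {
    type AttentionMla = CpuAttentionMla;
}

fn run_worker<O: Operators>() {
    O::AttentionMla::default().launch();
}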
171 changes: 99 additions & 72 deletions models/minicpm3/common/src/compute.rs
@@ -3,13 +3,15 @@ use gguf::ggml_quants::digit_layout::types as ty;
use gguf::ggml_quants::digit_layout::DigitLayout;
use half::f16;
use itertools::Itertools;
use operators::fuesd_softmax;
use operators::fuesd_softmax::FusedSoftmax;
use operators::scale;
use operators::scale::Scale;
use operators::{
add::{self, Add},
all_reduce::{self, AllReduce, ReduceOp},
attention::{self, Attention},
attention_kv_cached::AttnKVCached,
attention_mla::{self, AttentionMLA},
fuesd_softmax::AttnMask,
mat_mul::{self, MatMul},
rearrange::{self, Rearrange},
@@ -19,20 +19,22 @@ use operators::{
ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
};
use std::ops::{Deref, DerefMut};
use std::process::Output;
use tensor::split_mut;
use tensor::{split, Tensor};

pub trait Operators {
type Hardware: Hardware;
type TopoNode: TopoNode<Self::Hardware>;
type Attention: Attention<Self::Hardware>;
type AttentionMLA: AttentionMLA<Self::Hardware>;
type AttnKVCached: AttnKVCached<Self::Hardware>;
type Rope: Rope<Self::Hardware>;
type RmsNorm: RmsNorm<Self::Hardware>;
type Add: Add<Self::Hardware>;
type MatMul: MatMul<Self::Hardware>;
type Swiglu: Swiglu<Self::Hardware>;
type Scale: Scale<Self::Hardware>;
type FuesdSoftmax: FusedSoftmax<Self::Hardware>;
type Rearrange: Rearrange<Self::Hardware>;
type AllReduce: AllReduce<Self::Hardware, Self::TopoNode>;

@@ -81,12 +85,13 @@ pub struct Minicpm3Worker<Ops: Operators, W> {
dt_pos: DigitLayout,
add: Ops::Add,
attn_kv_cached: Ops::AttnKVCached,
attention: Ops::Attention,
attention_mla: Ops::AttentionMLA,
rope: Ops::Rope,
rms_norm: Ops::RmsNorm,
mat_mul: Ops::MatMul,
scale: Ops::Scale,
swiglu: Ops::Swiglu,
fuesd_softmax: Ops::FuesdSoftmax,
rearrange: Ops::Rearrange,
all_reduce: Ops::AllReduce,
}
@@ -108,7 +113,8 @@ impl<Ops: Operators, W> Minicpm3Worker<Ops, W> {
add: Ops::Add::new(processor),
all_reduce: Ops::AllReduce::new(node),
dt_pos: ty::U64,
attention: Ops::Attention::new(processor),
attention_mla: Ops::AttentionMLA::new(processor),
fuesd_softmax: Ops::FuesdSoftmax::new(processor),
}
}

@@ -165,12 +171,11 @@ where

let gate_up = tensor(&[nt, di * 2]);
// workspace: x + x1 + q (can probably be dropped) + q3 + kv_pe + attn
let workspace_size = *x1.get() * 3 + *gate_up.get();
let workspace_size = *x1.get() * 20 + *gate_up.get();
let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
let (buf, workspace) = workspace.split_at_mut(*x1.get());
let mut x1 = x1.map(|_| buf);


let queue = queue_alloc.queue();

let sin = sin_cos.clone().index(0, 0);
@@ -205,17 +210,15 @@ where
let w = self.weights.attn_qa_norm(iblk, queue);
self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
{
// q [1, 768] q1 [1, 3840] kv_pe [1,288] kv [1, 5120] k [1, 3840] attn [1, 2560]
let q1 = tensor(&[nt, nh * dk]);
let (buf, workspace) = workspace.split_at_mut(*q1.get());
let mut q1 = q1.map(|_| buf);
let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
drop(q);
// q3 feeds the attn computation, but part of it (q_rope) still needs the rotary embedding applied

let mut q3 = q1.tile(1, &[nh, dk]);
let q2 = unsafe { q3.map_slice_static_mut() };
split_mut!(q2=>_q, q_rope;[dnope, dh]@ 2);
split_mut!(q2=>q_nope, q_rope;[dnope, dh]@ 2);

// kv_pe [1,288]
let kv_pe = tensor(&[nt, dkv_lora + dh]);
@@ -224,62 +227,69 @@ where

let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;

drop(q);
split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);

self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
let mut k_rope = k_rope.tile(1, &[1, dh]);
self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
let k_rope = k_rope.broadcast(1, nh);

let inplace = unsafe { kv_lora.map_slice_static() };
let w = self.weights.attn_kva_norm(iblk, queue);
self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?;
// kv [1, 5120]
let kv = tensor(&[nt, nh * (dnope + dv)]);
let (buf, workspace) = workspace.split_at_mut(*kv.get());
let mut kv = kv.map(|_| buf);
let w = self.weights.attn_kvb(iblk, queue).transpose(&[1, 0]);

self.mat_mul(&mut kv, 0., &kv_lora, &w, 1., workspace, queue_alloc)?;

let kv = kv.tile(1, &[nh, dnope + dv]);

split_mut!(kv => k_nope ,v ; [dnope , dv ] @ 2);

// k [1, 3840]
let k = tensor(&[nt, nh, dk]);
let (buf, workspace) = workspace.split_at_mut(*k.get());
let k = k.map(|_| buf);

split_mut!(k => k_nope_r ,k_rope_r ; [dnope, dh] @ 2);
let kv_b_proj = unsafe {
self.weights
.attn_kvb(iblk, queue)
.tile(0, &[nh, dnope + dv])
.map_slice_static()
};
split!(kv_b_proj=> q_absorb , out_absorb ; [dnope, dv] @ 1);
let inplace = unsafe { q_nope.map_slice_static() };

let q_nope_0 = q_nope.map_slice().transpose(&[1, 0]);
let q_nope_1 = tensor(&[nh, nt, dkv_lora]);
let (buf, workspace) = workspace.split_at_mut(*q_nope_1.get());
let mut q_nope = q_nope_1.map(|_| buf);
self.mat_mul(
&mut q_nope,
0.,
&q_nope_0,
&q_absorb,
1.,
workspace,
queue_alloc,
)?;

self.rope(&mut q_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
let mut k_rope = k_rope.tile(1, &[1, dh]);
self.rope(&mut k_rope, &pos, &sin, &cos, workspace, queue_alloc)?;
let k_rope = k_rope.broadcast(1, nh);
self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;

let pos = requests.last().unwrap().pos as f32;
let mut q = q3.transpose(&[1, 0]);
let k = k.map_slice().transpose(&[1, 0]);
let v = v.map_slice_mut().transpose(&[1, 0]);
// perform attention
let attn = tensor(&[nt, nh, dv]);
let (buf, workspace) = workspace.split_at_mut(*attn.get());
let mut attn = attn.map(|_| buf);

let mut attn = unsafe { attn.map_slice_mut().transpose(&[1, 0]) };
let pos = requests.last().unwrap().pos as f32;
drop(q3);
// attn_output
let attn_output = tensor(&[nt, nh, dv]);
let (buf, workspace) = workspace.split_at_mut(*attn_output.get());
let mut attn_output = attn_output.map(|_| buf);
let q_rope = q_rope.transpose(&[1, 0]);
let k_rope = k_rope.transpose(&[1, 0]);
let kv_lora = kv_lora.map_slice().tile(0, &[1, 1]).broadcast(0, nh);
let mut o = unsafe {
attn_output.map_slice_static_mut().transpose(&[1, 0])
};
self.attnention(
&mut q,
&k,
&v,
&mut attn,
pos as usize,
&mut q_nope,
&kv_lora,
&out_absorb,
&q_rope,
&k_rope,
&mut o,
1,
workspace,
queue_alloc,
)?;

let o = attn.transpose(&[1, 0]).merge(1..3).unwrap();
let o = attn_output.map_slice().merge(1..3).unwrap();
let w = self.weights.attn_o(iblk, queue);

self.mat_mul(&mut x1, 0., &o, &w, s, workspace, queue_alloc)?;
let inplace = unsafe { x.map_slice_static() };
self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
@@ -301,17 +311,6 @@ where

self.swiglu(&mut gate, &up, workspace, queue_alloc)?;

fn print_first_10_elements(ptr: *const f16) {
assert!(!ptr.is_null(), "Pointer must not be null");

unsafe {
for i in 0..10 {
// read and print the first 10 elements one by one
let element = ptr.offset(i as isize).read();
println!("Element {}: {:?}", i, element);
}
}
}

let w = self.weights.ffn_down(iblk, queue);
self.mat_mul(&mut x1, 0., &gate, &w, s, workspace, queue_alloc)?;
@@ -460,31 +459,39 @@ where
queue_alloc,
)
}
fn attnention<Q, K, V, O, QA>(
fn attnention<Q, KV, A, QR, KR, O, QA>(
&self,
q: &mut Tensor<Q>,
k: &Tensor<K>,
v: &Tensor<V>,
kv: &Tensor<KV>,
a: &Tensor<A>,
qr: &Tensor<QR>,
kr: &Tensor<KR>,
o: &mut Tensor<O>,
pos: usize,
workspace: &mut [ByteOf<Ops::Hardware>],
queue_alloc: &QA,
) -> Result<(), LaunchError>
where
Q: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
K: Deref<Target = [ByteOf<Ops::Hardware>]>,
V: Deref<Target = [ByteOf<Ops::Hardware>]>,
KV: Deref<Target = [ByteOf<Ops::Hardware>]>,
A: Deref<Target = [ByteOf<Ops::Hardware>]>,
QR: Deref<Target = [ByteOf<Ops::Hardware>]>,
KR: Deref<Target = [ByteOf<Ops::Hardware>]>,
O: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
QA: QueueAlloc<Hardware = Ops::Hardware>,
{
self.attention.launch(
&attention::Args {
self.attention_mla.launch(
&attention_mla::Args {
q_layout: q.layout(),
q_base: q.base_mut(),
k_layout: k.layout(),
k_base: k.base(),
v_layout: v.layout(),
v_base: v.base(),
kv_layout: kv.layout(),
kv_base: kv.base(),
absorb_layout: a.layout(),
absorb_base: a.base(),
qr_layout: qr.layout(),
qr_base: qr.base(),
kr_layout: kr.layout(),
kr_base: kr.base(),
o_layout: o.layout(),
o_base: o.base_mut(),
mask: AttnMask::Causal,
@@ -594,6 +601,26 @@ where
queue_alloc,
)
}
fn softmax<A, QA>(
&self,
a: &mut Tensor<A>,
workspace: &mut [ByteOf<Ops::Hardware>],
queue_alloc: &QA,
) -> Result<(), LaunchError>
where
A: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
QA: QueueAlloc<Hardware = Ops::Hardware>,
{
self.fuesd_softmax.launch(
&fuesd_softmax::Args {
att_mask: AttnMask::Causal,
att_layout: a.layout(),
att_base: a.base_mut(),
},
workspace,
queue_alloc,
)
}
fn all_reduce<X, QA>(
&self,
x: &mut Tensor<X>,
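The absorbed projection in compute.rs (q_nope [nh, nt, dnope] times q_absorb [nh, dnope, dkv_lora] giving [nh, nt, dkv_lora]) is one batched matmul per head; the worker runs it through Ops::MatMul. A dependency-free sketch of the same shape bookkeeping, assuming contiguous row-major buffers (illustrative only, not the worker's code path):

// Batched per-head matmul: out[h] = q_nope[h] x q_absorb[h].
fn absorb_q(
    q_nope: &[f32],   // [nh, nt, dnope], row-major
    q_absorb: &[f32], // [nh, dnope, dkv_lora], row-major
    nh: usize,
    nt: usize,
    dnope: usize,
    dkv_lora: usize,
) -> Vec<f32> {
    let mut out = vec![0.0_f32; nh * nt * dkv_lora]; // [nh, nt, dkv_lora]
    for h in 0..nh {
        for t in 0..nt {
            for j in 0..dkv_lora {
                let mut acc = 0.0_f32;
                for k in 0..dnope {
                    acc += q_nope[(h * nt + t) * dnope + k]
                        * q_absorb[(h * dnope + k) * dkv_lora + j];
                }
                out[(h * nt + t) * dkv_lora + j] = acc;
            }
        }
    }
    out
}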
1 change: 0 additions & 1 deletion tensor/src/split.rs
@@ -14,7 +14,6 @@ impl<T> Splitable for &[T] {
self
}
}

impl<T> Splitable for &mut [T] {
#[inline]
fn split(&self) -> Self {
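The Splitable impls touched here are what let split! and split_mut! in compute.rs hand out several disjoint views of one tensor buffer, as in split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1). For flat slices the standard library gives the same disjointness guarantee; a minimal analogue, with a hypothetical helper name and illustrative widths (256 + 32):

// Split one row of width dkv_lora + dh into two disjoint mutable parts,
// mirroring what split_mut! does along a single tensor axis.
fn split_row(row: &mut [f32], dkv_lora: usize) -> (&mut [f32], &mut [f32]) {
    row.split_at_mut(dkv_lora) // (kv_lora, k_rope)
}

fn main() {
    let mut row = vec![0.0_f32; 256 + 32];
    let (kv_lora, k_rope) = split_row(&mut row, 256);
    kv_lora[0] = 1.0; // both halves are independently mutable
    k_rope[0] = 2.0;
    assert_eq!((kv_lora.len(), k_rope.len()), (256, 32));
}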
