Commit 166d2ea

Simple structural optimization (简单优化结构)
onenewcode committed Feb 14, 2025
1 parent 88e0b82 commit 166d2ea
Showing 1 changed file with 58 additions and 68 deletions.
126 changes: 58 additions & 68 deletions in models/minicpm3/common/src/compute.rs
@@ -170,17 +170,14 @@ where
         let dnope = dk - dh;
         let tensor = |shape: &[usize]| Tensor::new(dt_embd, shape);
         let x1 = tensor(x.shape());
-        let q = tensor(&[nt, dq_lora]);
-        let kv_pe = tensor(&[nt, dh + dkv_lora]);

         let gate_up = tensor(&[nt, di * 2]);
+        // space: x + x1 + q (probably removable) + q3 + kv_pe + attn
         let workspace_size = *x1.get() * 3 + *gate_up.get();
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);
-        let (buf, workspace) = workspace.split_at_mut(*q.get());
-        let mut q = q.map(|_| buf);
-        let (buf, workspace) = workspace.split_at_mut(*kv_pe.get());
-        let mut kv_pe = kv_pe.map(|_| buf);

+        // perform attention
         let attn = tensor(&[nt, nh, dv]);
         let (buf, workspace) = workspace.split_at_mut(*attn.get());
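Review note: this hunk stops reserving `q` and `kv_pe` up front and instead carves each scratch buffer out of the single `Workspace` at the point where it is first needed; the total `workspace_size` stays one up-front reservation. A minimal sketch of the carving pattern, reduced to a plain byte slice; the crate's `Workspace::split_at_mut` and `*tensor.get()` (byte size) are assumed to behave analogously:

```rust
/// Peel a `size`-byte buffer off the front of the workspace, leaving the
/// tail for the next allocation. Buffers are handed out in declaration
/// order, exactly like the chain of `split_at_mut` calls in the diff.
fn carve<'a>(workspace: &mut &'a mut [u8], size: usize) -> &'a mut [u8] {
    let (head, tail) = std::mem::take(workspace).split_at_mut(size);
    *workspace = tail;
    head
}

fn demo(mut ws: &mut [u8]) {
    let x1 = carve(&mut ws, 16); // buffer for x1
    let attn = carve(&mut ws, 32); // buffer for attn
    assert_eq!((x1.len(), attn.len()), (16, 32));
}
```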
@@ -191,52 +188,53 @@ where
         // norm
         let w = self.weights.attn_norm(iblk, queue);
         self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
-        // if iblk==1{
-        //     Ops::debug(&x1, queue);
-        //     todo!();
-        // }
         drop(w);
-        {
-            let w = self.weights.attn_qa(iblk, queue).transpose(&[1, 0]);
-            self.mat_mul(&mut q, 0., &x1, &w, 1., workspace, queue_alloc)?;
-
-            let inplace = unsafe { q.map_slice_static() };
-            let w = self.weights.attn_qa_norm(iblk, queue);
-            self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
-        }
-        let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
+        let q = tensor(&[nt, dq_lora]);
+        let (buf, workspace) = workspace.split_at_mut(*q.get());
+        let mut q = q.map(|_| buf);
+        let w = self.weights.attn_qa(iblk, queue).transpose(&[1, 0]);
+        self.mat_mul(&mut q, 0., &x1, &w, 1., workspace, queue_alloc)?;
+
+        let inplace = unsafe { q.map_slice_static() };
+        let w = self.weights.attn_qa_norm(iblk, queue);
+        self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
+
+        // q [1, 768] q1 [1, 3840] kv_pe [1, 288] kv [1, 5120] k [1, 3840] attn [1, 2560]
         let q1 = tensor(&[nt, nh * dk]);
         let (buf, workspace) = workspace.split_at_mut(*q1.get());
         let mut q1 = q1.map(|_| buf);
+        let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
         self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
-        let q3 = q1.tile(1, &[nh, dk]);
-        let parts = [dnope, dh];
-        let mut parts = q3.split(2, &parts);
-        let _ = parts.next().unwrap();
-        let mut q_rope_0 = parts.next().unwrap();
-        assert!(parts.next().is_none());
-        drop(parts);
-        drop(q);
+        // q3 is what the attention step consumes, but the rope part of q3 still needs the position embedding applied
+        let mut q3 = q1.tile(1, &[nh, dk]);
+        let q2 = unsafe { q3.map_slice_static_mut() };
+        split_mut!(q2 => _q, q_rope; [dnope, dh] @ 2);
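Review note: the shape comment above (`q [1, 768] q1 [1, 3840] ...`) pins down the per-layer dimensions of the low-rank (MLA-style) attention: `attn_qa`/`attn_qb` factor the Q projection through a rank-768 bottleneck, `attn_kva`/`attn_kvb` factor K/V through rank 256, and a 32-wide rope slice is shared. The constants below are inferred from that comment alone, not read from the model file:

```rust
// Dimension bookkeeping implied by the shape comment in this hunk
// (assumed values; the comment's shapes are for nt = 1 token).
const NH: usize = 40; // attention heads
const DH: usize = 32; // rope (positional) columns per head
const DNOPE: usize = 64; // non-rope columns per head: dnope = dk - dh
const DK: usize = DNOPE + DH; // 96
const DV: usize = 64; // value columns per head
const DKV_LORA: usize = 256; // kv low-rank width
const DQ_LORA: usize = 768; // q low-rank width, q: [1, 768]

fn main() {
    assert_eq!(NH * DK, 3840); // q1, k  [1, 3840]
    assert_eq!(DKV_LORA + DH, 288); // kv_pe  [1, 288]
    assert_eq!(NH * (DNOPE + DV), 5120); // kv     [1, 5120]
    assert_eq!(NH * DV, 2560); // attn   [1, 2560]
}
```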

+        // kv_pe [1, 288]
+        let kv_pe = tensor(&[nt, dkv_lora + dh]);
+        let (buf, workspace) = workspace.split_at_mut(*kv_pe.get());
+        let mut kv_pe = kv_pe.map(|_| buf);

         let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
         self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;

-        split_mut!(kv_pe => kv_lora_0, k_rope_0; [dkv_lora, dh] @ 1);
+        split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);

-        // kv_pe
-        let kv_lora_1 = tensor(&[nt, dkv_lora]);
-        let (buf, workspace) = workspace.split_at_mut(*kv_lora_1.get());
-        let mut kv_lora_1 = kv_lora_1.map(|_| buf);
+        let inplace = unsafe { kv_lora.map_slice_static() };
         let w = self.weights.attn_kva_norm(iblk, queue);
-        self.rms_norm(&mut kv_lora_1, &kv_lora_0, &w, workspace, queue_alloc)?;
-
-        let kv_0 = tensor(&[nt, nh * (dnope + dv)]);
-        let (buf, workspace) = workspace.split_at_mut(*kv_0.get());
-        let mut kv_0 = kv_0.map(|_| buf);
+        self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?;
+        // kv [1, 5120]
+        let kv = tensor(&[nt, nh * (dnope + dv)]);
+        let (buf, workspace) = workspace.split_at_mut(*kv.get());
+        let mut kv = kv.map(|_| buf);
         let w = self.weights.attn_kvb(iblk, queue).transpose(&[1, 0]);
-        self.mat_mul(&mut kv_0, 0., &kv_lora_1, &w, 1., workspace, queue_alloc)?;
+        self.mat_mul(&mut kv, 0., &kv_lora, &w, 1., workspace, queue_alloc)?;

-        let kv_1 = kv_0.tile(1, &[nh, dnope + dv]);
+        let kv = kv.tile(1, &[nh, dnope + dv]);

-        split_mut!(kv_1 => k_nope, v; [dnope, dv] @ 2);
+        split_mut!(kv => k_nope, v; [dnope, dv] @ 2);

         /// longrope
         pub fn longrope(
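Review note: `split_mut!` (used for `q2`, `kv_pe`, and `kv` above, and for `k` below) replaces the old iterator dance (`split(2, &parts)`, repeated `parts.next().unwrap()`, `assert!`, `drop(parts)`) with named, disjoint mutable views along one axis. A rough illustration on a flat row; this is a hypothetical simplification, since the real macro yields strided tensor views:

```rust
/// What `split_mut!(kv => k_nope, v; [dnope, dv] @ 2)` amounts to for a
/// single (token, head) row: two non-overlapping mutable views over the
/// last axis, usable at the same time.
fn split_row(row: &mut [f32], dnope: usize, dv: usize) -> (&mut [f32], &mut [f32]) {
    assert_eq!(row.len(), dnope + dv);
    row.split_at_mut(dnope)
}
```

Borrow-checker-wise this is the same trick as `slice::split_at_mut`: both halves stay mutable because they provably do not alias.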
@@ -276,23 +274,21 @@ where
         let long_factor = cast(long_factor.base().cast());
         let short_factor = cast(short_factor.base().cast());

-        // k dk
+        // k [1, 3840]
         let k = tensor(&[nt, nh, dk]);
         let (buf, workspace) = workspace.split_at_mut(*k.get());
         let mut k = k.map(|_| buf);
-        let parts = [dnope, dh];
-        let mut parts = k.split(2, &parts);
-        let mut k_nope_r = parts.next().unwrap();
-        let mut k_rope_r = parts.next().unwrap();
-        assert!(parts.next().is_none());
+        split_mut!(k => k_nope_r, k_rope_r; [dnope, dh] @ 2);

         let pos = requests.last().unwrap().pos as f32;
         let (max_pos, origin_max_pos) = (100f32, 100f32);

+        // q embedding
         (0..nh).for_each(|i| {
             let mut tmp_q = unsafe {
                 std::slice::from_raw_parts_mut(
-                    q_rope_0.base_mut().cast::<f32>().offset((i * 32) as isize),
+                    q_rope.base_mut().cast::<f32>().offset((i * 32) as isize),
                     32,
                 )
             };
@@ -306,30 +302,23 @@
                     origin_max_pos,
                 );
             });
-        // println!("q {:?}",k_rope_0.shape());
-        // todo!();
-        // k embedding
-
-        {
-            let mut k_rope_1 = unsafe {
-                std::slice::from_raw_parts_mut(k_rope_0.base_mut().cast::<f32>(), 32)
-            };
-            longrope(
-                &mut k_rope_1,
-                pos,
-                self.meta.theta,
-                long_factor,
-                short_factor,
-                max_pos,
-                origin_max_pos,
-            );
-        }
+        // k embedding
+        let mut k_rope_1 =
+            unsafe { std::slice::from_raw_parts_mut(k_rope.base_mut().cast::<f32>(), 32) };
+        longrope(
+            &mut k_rope_1,
+            pos,
+            self.meta.theta,
+            long_factor,
+            short_factor,
+            max_pos,
+            origin_max_pos,
+        );

         // TODO: not yet verified
         // perform broadcast and copy
-        let k_rope_2 = k_rope_0.tile(1, &[1, dh]).broadcast(1, nh);
-        self.rearrange(&mut k_rope_r, &k_rope_2, workspace, queue_alloc)?;
+        let k_rope = k_rope.tile(1, &[1, dh]).broadcast(1, nh);
+        self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
         self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;

         let mut q = q3.transpose(&[1, 0]);
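Review note: `longrope`'s body is collapsed in this diff, so the sketch below is an assumption based on the published LongRoPE-style "rope scaling with long and short factors" formulation, not this crate's code: per-frequency factors divide the rotary frequencies, and an extra magnitude scale kicks in once the context runs past the original window. The name `longrope_sketch`, the pair layout, and the exact scaling rule are all illustrative.

```rust
/// Rotate one head's rope slice (pairs of adjacent f32) to position `pos`.
/// `factors[i]` rescales the i-th pair's frequency; `long_factor` applies
/// once the run exceeds the original training window, `short_factor` otherwise.
fn longrope_sketch(
    x: &mut [f32],
    pos: f32,
    theta: f32,
    long_factor: &[f32],
    short_factor: &[f32],
    max_pos: f32,
    origin_max_pos: f32,
) {
    let half = x.len() / 2; // number of (cos, sin) pairs; needs factors.len() >= half
    let factors = if max_pos > origin_max_pos { long_factor } else { short_factor };
    // Magnitude correction used by LongRoPE-style scaling when extrapolating.
    let scale = if max_pos > origin_max_pos {
        (1.0 + (max_pos / origin_max_pos).ln() / origin_max_pos.ln()).sqrt()
    } else {
        1.0
    };
    for (i, pair) in x.chunks_exact_mut(2).enumerate() {
        let freq = theta.powf(-(i as f32) / half as f32) / factors[i];
        let (sin, cos) = (pos * freq).sin_cos();
        let (a, b) = (pair[0], pair[1]);
        pair[0] = (a * cos - b * sin) * scale;
        pair[1] = (a * sin + b * cos) * scale;
    }
}
```

Also worth noting in the hunk above: only one `dh`-wide rope key is embedded per token; `tile(1, &[1, dh]).broadcast(1, nh)` then fans that single key out across all `nh` heads before `rearrange` copies it into `k`, so the positional key is shared between heads while each head keeps its own `k_nope`.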
@@ -393,7 +382,8 @@ where
         if logits.shape()[0] == 0 {
             return Ok(());
         }
-
+        Ops::debug(&x, queue);
+        todo!();
         // gather the tokens to be sampled
         // NOTICE: sorting requests by seq len in ascending order before input reduces data-movement cost
         let mut dst = 0;
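Review note: the code following this point (per the comments) compacts the rows that actually need sampling to the front of the logits buffer, with `dst` tracking the write cursor. A rough sketch of that gather and of why the NOTICE holds; the row layout and names here are illustrative, not the crate's tensor code:

```rust
/// Move the rows marked for sampling to the front, preserving order.
/// In the real kernel rows live in one contiguous buffer, so when requests
/// are pre-sorted by seq len each kept row is already close to its target
/// slot and the copy distance shrinks, which is the point of the NOTICE.
fn gather_sample_rows(rows: &mut Vec<Vec<f32>>, keep: &[bool]) {
    let mut dst = 0;
    for src in 0..rows.len() {
        if keep[src] {
            rows.swap(dst, src); // no-op when dst == src
            dst += 1;
        }
    }
    rows.truncate(dst);
}
```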
