diff --git a/models/minicpm3/common/src/compute.rs b/models/minicpm3/common/src/compute.rs
index 054ab73..7009012 100644
--- a/models/minicpm3/common/src/compute.rs
+++ b/models/minicpm3/common/src/compute.rs
@@ -170,17 +170,14 @@ where
         let dnope = dk - dh;
         let tensor = |shape: &[usize]| Tensor::new(dt_embd, shape);
         let x1 = tensor(x.shape());
-        let q = tensor(&[nt, dq_lora]);
-        let kv_pe = tensor(&[nt, dh + dkv_lora]);
+        let gate_up = tensor(&[nt, di * 2]);
+        // workspace layout: x + x1 + q (can probably be dropped) + q3 + kv_pe + attn
         let workspace_size = *x1.get() * 3 + *gate_up.get();
         let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
         let (buf, workspace) = workspace.split_at_mut(*x1.get());
         let mut x1 = x1.map(|_| buf);
 
-        let (buf, workspace) = workspace.split_at_mut(*q.get());
-        let mut q = q.map(|_| buf);
-        let (buf, workspace) = workspace.split_at_mut(*kv_pe.get());
-        let mut kv_pe = kv_pe.map(|_| buf);
+        // perform attention
 
         let attn = tensor(&[nt, nh, dv]);
         let (buf, workspace) = workspace.split_at_mut(*attn.get());
@@ -191,52 +188,53 @@ where
             // norm
             let w = self.weights.attn_norm(iblk, queue);
             self.rms_norm(&mut x1, &x, &w, workspace, queue_alloc)?;
-            // if iblk==1{
-            //     Ops::debug(&x1, queue);
-            //     todo!();
-            // }
             drop(w);
+            let q = tensor(&[nt, dq_lora]);
+            let (buf, workspace) = workspace.split_at_mut(*q.get());
+            let mut q = q.map(|_| buf);
+            let w = self.weights.attn_qa(iblk, queue).transpose(&[1, 0]);
+            self.mat_mul(&mut q, 0., &x1, &w, 1., workspace, queue_alloc)?;
+
+            let inplace = unsafe { q.map_slice_static() };
+            let w = self.weights.attn_qa_norm(iblk, queue);
+            self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
             {
-                let w = self.weights.attn_qa(iblk, queue).transpose(&[1, 0]);
-                self.mat_mul(&mut q, 0., &x1, &w, 1., workspace, queue_alloc)?;
-
-                let inplace = unsafe { q.map_slice_static() };
-                let w = self.weights.attn_qa_norm(iblk, queue);
-                self.rms_norm(&mut q, &inplace, &w, workspace, queue_alloc)?;
-
-                let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
+                // shapes: q [1, 768], q1 [1, 3840], kv_pe [1, 288], kv [1, 5120], k [1, 3840], attn [1, 2560]
                 let q1 = tensor(&[nt, nh * dk]);
                 let (buf, workspace) = workspace.split_at_mut(*q1.get());
                 let mut q1 = q1.map(|_| buf);
+                let w = self.weights.attn_qb(iblk, queue).transpose(&[1, 0]);
                 self.mat_mul(&mut q1, 0., &q, &w, 1., workspace, queue_alloc)?;
-                let q3 = q1.tile(1, &[nh, dk]);
-                let parts = [dnope, dh];
-                let mut parts = q3.split(2, &parts);
-                let _ = parts.next().unwrap();
-                let mut q_rope_0 = parts.next().unwrap();
-                assert!(parts.next().is_none());
-                drop(parts);
+                drop(q);
+                // q3 feeds the attn computation, but its rope part still needs the position embedding applied
+                let mut q3 = q1.tile(1, &[nh, dk]);
+                let q2 = unsafe { q3.map_slice_static_mut() };
+                split_mut!(q2 => _q, q_rope; [dnope, dh] @ 2);
+
+                // kv_pe [1, 288]
+                let kv_pe = tensor(&[nt, dkv_lora + dh]);
+                let (buf, workspace) = workspace.split_at_mut(*kv_pe.get());
+                let mut kv_pe = kv_pe.map(|_| buf);
+
                 let w = self.weights.attn_kva(iblk, queue).transpose(&[1, 0]);
                 self.mat_mul(&mut kv_pe, 0., &x1, &w, 1., workspace, queue_alloc)?;
 
-                split_mut!(kv_pe => kv_lora_0, k_rope_0; [dkv_lora, dh] @ 1);
+                split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1);
 
-                // kv_pe
-                let kv_lora_1 = tensor(&[nt, dkv_lora]);
-                let (buf, workspace) = workspace.split_at_mut(*kv_lora_1.get());
-                let mut kv_lora_1 = kv_lora_1.map(|_| buf);
+                let inplace = unsafe { kv_lora.map_slice_static() };
                 let w = self.weights.attn_kva_norm(iblk, queue);
-                self.rms_norm(&mut kv_lora_1, &kv_lora_0, &w, workspace, queue_alloc)?;
-
-                let kv_0 = tensor(&[nt, nh * (dnope + dv)]);
-                let (buf, workspace) = workspace.split_at_mut(*kv_0.get());
-                let mut kv_0 = kv_0.map(|_| buf);
+                self.rms_norm(&mut kv_lora, &inplace, &w, workspace, queue_alloc)?;
+                // kv [1, 5120]
+                let kv = tensor(&[nt, nh * (dnope + dv)]);
+                let (buf, workspace) = workspace.split_at_mut(*kv.get());
+                let mut kv = kv.map(|_| buf);
 
                 let w = self.weights.attn_kvb(iblk, queue).transpose(&[1, 0]);
-                self.mat_mul(&mut kv_0, 0., &kv_lora_1, &w, 1., workspace, queue_alloc)?;
-                let kv_1 = kv_0.tile(1, &[nh, dnope + dv]);
+                self.mat_mul(&mut kv, 0., &kv_lora, &w, 1., workspace, queue_alloc)?;
 
-                split_mut!(kv_1 => k_nope ,v ; [dnope , dv ] @ 2);
+                let kv = kv.tile(1, &[nh, dnope + dv]);
+
+                split_mut!(kv => k_nope, v; [dnope, dv] @ 2);
 
                 /// longrope
                 pub fn longrope(
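Note on the `split_mut!` calls this patch introduces: judging from the hand-written splitting it deletes, the macro is shorthand for splitting a tensor view along one axis into named mutable parts. A rough, hypothetical expansion of `split_mut!(kv_pe => kv_lora, k_rope; [dkv_lora, dh] @ 1)`, modeled on the removed code (the real macro is defined elsewhere in this repo and may differ in detail):

    // split axis 1 of kv_pe into two named mutable views
    let parts = [dkv_lora, dh];
    let mut parts = kv_pe.split(1, &parts);
    let mut kv_lora = parts.next().unwrap();
    let mut k_rope = parts.next().unwrap();
    assert!(parts.next().is_none());
    drop(parts);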
@@ -276,15 +274,13 @@ where
                 let long_factor = cast(long_factor.base().cast());
                 let short_factor = cast(short_factor.base().cast());
 
-                // k dk
+                // k [1, 3840]
                 let k = tensor(&[nt, nh, dk]);
                 let (buf, workspace) = workspace.split_at_mut(*k.get());
                 let mut k = k.map(|_| buf);
-                let parts = [dnope, dh];
-                let mut parts = k.split(2, &parts);
-                let mut k_nope_r = parts.next().unwrap();
-                let mut k_rope_r = parts.next().unwrap();
-                assert!(parts.next().is_none());
+
+                split_mut!(k => k_nope_r, k_rope_r; [dnope, dh] @ 2);
+
                 let pos = requests.last().unwrap().pos as f32;
                 let (max_pos, origin_max_pos) = (100f32, 100f32);
 
@@ -292,7 +288,7 @@ where
                 (0..nh).for_each(|i| {
                     let mut tmp_q = unsafe {
                         std::slice::from_raw_parts_mut(
-                            q_rope_0.base_mut().cast::<f32>().offset((i * 32) as isize),
+                            q_rope.base_mut().cast::<f32>().offset((i * 32) as isize),
                             32,
                         )
                     };
@@ -306,30 +302,23 @@ where
                         origin_max_pos,
                     );
                 });
 
+                // embed k
+
+                let mut k_rope_1 =
+                    unsafe { std::slice::from_raw_parts_mut(k_rope.base_mut().cast::<f32>(), 32) };
+                longrope(
+                    &mut k_rope_1,
+                    pos,
+                    self.meta.theta,
+                    long_factor,
+                    short_factor,
+                    max_pos,
+                    origin_max_pos,
+                );
-                // println!("q {:?}",k_rope_0.shape());
-                // todo!();
-                // embed k
-
-                {
-                    let mut k_rope_1 = unsafe {
-                        std::slice::from_raw_parts_mut(k_rope_0.base_mut().cast::<f32>(), 32)
-                    };
-                    longrope(
-                        &mut k_rope_1,
-                        pos,
-                        self.meta.theta,
-                        long_factor,
-                        short_factor,
-                        max_pos,
-                        origin_max_pos,
-                    );
-                }
-
-                // TODO unconfirmed
                 // broadcast and copy
-                let k_rope_2 = k_rope_0.tile(1, &[1, dh]).broadcast(1, nh);
-                self.rearrange(&mut k_rope_r, &k_rope_2, workspace, queue_alloc)?;
+                let k_rope = k_rope.tile(1, &[1, dh]).broadcast(1, nh);
+                self.rearrange(&mut k_rope_r, &k_rope, workspace, queue_alloc)?;
                 self.rearrange(&mut k_nope_r, &k_nope, workspace, queue_alloc)?;
 
                 let mut q = q3.transpose(&[1, 0]);
@@ -393,7 +382,8 @@ where
         if logits.shape()[0] == 0 {
             return Ok(());
         }
-
+        Ops::debug(&x, queue); // FIXME: leftover debug checkpoint
+        todo!(); // FIXME: unconditional panic; remove before merging
         // gather the tokens to be sampled
         // NOTICE: sorting requests by seq len in ascending order before input reduces data movement
         let mut dst = 0;
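The patch calls a `longrope` helper (its body is elided by the hunk boundaries above) on 32-element f32 slices of the q and k rope dimensions, passing per-dimension long/short scaling factors and an original-context length. Below is a minimal, self-contained sketch of the LongRoPE-style rotation such a helper commonly performs; the interleaved (re, im) pair layout, the factor-selection rule, and the sqrt(1 + ln(s)/ln(origin_max_pos)) attention scale are assumptions taken from the usual formulation, not from this patch.

    /// Hypothetical LongRoPE-style rotary embedding, for illustration only.
    /// `data` holds dh/2 interleaved (re, im) pairs of one head at position `pos`.
    fn longrope(
        data: &mut [f32],
        pos: f32,
        theta: f32,
        long_factor: &[f32],
        short_factor: &[f32],
        max_pos: f32,
        origin_max_pos: f32,
    ) {
        let half = data.len() / 2;
        assert!(long_factor.len() >= half && short_factor.len() >= half);
        // Long factors apply once the position is beyond the original context window.
        let factor = if pos >= origin_max_pos { long_factor } else { short_factor };
        // Magnitude correction for the extended context (common LongRoPE recipe).
        let scale = if max_pos > origin_max_pos {
            (1.0 + (max_pos / origin_max_pos).ln() / origin_max_pos.ln()).sqrt()
        } else {
            1.0
        };
        for i in 0..half {
            // Base RoPE frequency for pair i, slowed by the per-dimension factor.
            let inv_freq = 1.0 / (factor[i] * theta.powf(2.0 * i as f32 / data.len() as f32));
            let (sin, cos) = (pos * inv_freq).sin_cos();
            let (re, im) = (data[2 * i], data[2 * i + 1]);
            data[2 * i] = (re * cos - im * sin) * scale;
            data[2 * i + 1] = (re * sin + im * cos) * scale;
        }
    }

    fn main() {
        // Toy call shaped like the ones in the patch: one 32-element head slice.
        let mut q_rope = [1.0f32; 32];
        let factors = [1.0f32; 16];
        longrope(&mut q_rope, 5.0, 10000.0, &factors, &factors, 100.0, 100.0);
        println!("{:?}", &q_rope[..4]);
    }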