diff --git a/gguf/src/lib.rs b/gguf/src/lib.rs
index f70386f..2a0f31a 100644
--- a/gguf/src/lib.rs
+++ b/gguf/src/lib.rs
@@ -144,6 +144,17 @@ mod macros {
                 Err(e) => panic!("failed to read meta: {e:?}"),
             }
         };
+
+        ($gguf:expr => (usize) $key:expr) => {
+            $gguf.get_usize($key).unwrap()
+        };
+        ($gguf:expr => (usize) $key:expr; $default:expr) => {
+            match $gguf.get_usize($key) {
+                Ok(val) => val,
+                Err(gguf::GGufMetaError::NotExist) => $default,
+                Err(e) => panic!("failed to read meta: {e:?}"),
+            }
+        };
     }
     #[macro_export]
     macro_rules! tensor {
diff --git a/models/clip/common-cpu/src/test_infer.rs b/models/clip/common-cpu/src/infer.rs
similarity index 97%
rename from models/clip/common-cpu/src/test_infer.rs
rename to models/clip/common-cpu/src/infer.rs
index a5ef1f0..85f2cbd 100644
--- a/models/clip/common-cpu/src/test_infer.rs
+++ b/models/clip/common-cpu/src/infer.rs
@@ -25,7 +25,7 @@ fn test_infer() {
     println!("{meta:#?}");
 
     let &ClipMeta {
-        dt_embd,
+        dt,
 
         d_image,
         d_patch,
@@ -42,7 +42,7 @@ fn test_infer() {
     let time = Instant::now();
     let slices = image
         .slice_uhd(9, d_image, d_patch)
-        .normalize(dt_embd, image_mean, image_std);
+        .normalize(dt, image_mean, image_std);
     println!("slice image {:?}", time.elapsed());
 
     let weights = Weights::new(&storage);
diff --git a/models/clip/common-cpu/src/lib.rs b/models/clip/common-cpu/src/lib.rs
index ac3fe16..1a683be 100644
--- a/models/clip/common-cpu/src/lib.rs
+++ b/models/clip/common-cpu/src/lib.rs
@@ -1,6 +1,6 @@
-use clip::{ClipStorage, WeightLoader};
-use operators::{common_cpu::Cpu, conv, QueueOf, TopoNode};
-use std::marker::PhantomData;
+use clip::{BlkWeight, ClipBlkStorage, ClipStorage, Tensor, WeightLoader};
+use operators::{common_cpu::Cpu, conv, ByteOf, QueueOf, TopoNode};
+use std::{marker::PhantomData, ops::Deref};
 
 pub struct Operators(PhantomData);
 
@@ -22,6 +22,18 @@ where
     type Conv = conv::common_cpu::ConvIm2Col;
     type AddRows = op!(add_rows);
     type LayerNorm = op!(layer_norm);
+    type MatMul = op!(mat_mul);
+    type Attention = op!(attention);
+    type Gelu = op!(gelu);
+    type Add = op!(add);
+    type Rearrange = op!(rearrange);
+
+    fn debug<T>(tensor: &Tensor<T>)
+    where
+        T: Deref<Target = [ByteOf<Self::Hardware>]>,
+    {
+        println!("{tensor}")
+    }
 }
 
 impl<'w> Weights<'w> {
@@ -32,18 +44,48 @@ impl WeightLoader for Weights<'_> {
     type Hardware = Cpu;
 
-    type Weight<'s>
+    type Memory<'s>
         = &'s [u8]
     where
         Self: 's;
 
+    fn load_blk(
+        &self,
+        which: BlkWeight,
+        iblk: usize,
+        _queue: &QueueOf<Self::Hardware>,
+    ) -> [Self::Memory<'_>; 2] {
+        let ClipBlkStorage {
+            attn_norm_w,
+            attn_norm_b,
+            attn_qkv_w,
+            attn_qkv_b,
+            attn_o_w,
+            attn_o_b,
+
+            ffn_norm_w,
+            ffn_norm_b,
+            ffn_up_w,
+            ffn_up_b,
+            ffn_down_w,
+            ffn_down_b,
+        } = &self.0.blocks[iblk];
+        match which {
+            BlkWeight::AttnNorm => [attn_norm_w, attn_norm_b],
+            BlkWeight::AttnQKV => [attn_qkv_w, attn_qkv_b],
+            BlkWeight::AttnO => [attn_o_w, attn_o_b],
+            BlkWeight::FfnNorm => [ffn_norm_w, ffn_norm_b],
+            BlkWeight::FfnUp => [ffn_up_w, ffn_up_b],
+            BlkWeight::FfnDown => [ffn_down_w, ffn_down_b],
+        }
+    }
+
     #[inline]
-    fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2] {
+    fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2] {
         [self.0.patch_embd_w, self.0.patch_embd_b]
     }
 
     #[inline]
-    fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a> {
+    fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a> {
         self.0.pos_embd
     }
 
@@ -51,7 +93,7 @@
     fn pre_norm<'a>(
         &'a self,
         _queue: &'a QueueOf<Self::Hardware>,
-    ) -> Option<[Self::Weight<'a>; 2]> {
+    ) -> Option<[Self::Memory<'a>; 2]> {
         self.0.pre_norm
     }
 
@@ -59,10 +101,10 @@
     fn post_norm<'a>(
         &'a self,
         _queue: &'a QueueOf<Self::Hardware>,
-    ) -> Option<[Self::Weight<'a>; 2]> {
+    ) -> Option<[Self::Memory<'a>; 2]> {
         self.0.post_norm
     }
 }
 
 #[cfg(test)]
-mod test_infer;
+mod infer;
diff --git a/models/clip/common/src/compute.rs b/models/clip/common/src/compute.rs
index 53ebfdf..1c96746 100644
--- a/models/clip/common/src/compute.rs
+++ b/models/clip/common/src/compute.rs
@@ -1,15 +1,21 @@
 use super::{args::Args, ClipMeta};
+use itertools::izip;
 use operators::{
+    add::{self, Add},
     add_rows::{self, AddRows},
+    attention::{self, Attention},
     conv::{self, Conv},
+    gelu::{self, Gelu},
     layer_norm::{self, LayerNorm},
-    ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode,
+    mat_mul::{self, MatMul},
+    rearrange::{self, Rearrange},
+    ByteOf, Hardware, LaunchError, Operator, QueueAlloc, QueueOf, TopoNode, Workspace,
 };
 use std::{
     ops::{Deref, DerefMut},
     time::Instant,
 };
-use tensor::Tensor;
+use tensor::{split, Tensor};
 
 pub trait Operators {
     type Hardware: Hardware;
@@ -17,19 +23,45 @@ pub trait Operators {
     type Conv: Conv<Self::Hardware>;
     type AddRows: AddRows<Self::Hardware>;
     type LayerNorm: LayerNorm<Self::Hardware>;
+    type MatMul: MatMul<Self::Hardware>;
+    type Attention: Attention<Self::Hardware>;
+    type Gelu: Gelu<Self::Hardware>;
+    type Add: Add<Self::Hardware>;
+    type Rearrange: Rearrange<Self::Hardware>;
+
+    fn debug<T>(tensor: &Tensor<T>)
+    where
+        T: Deref<Target = [ByteOf<Self::Hardware>]>;
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum BlkWeight {
+    AttnNorm,
+    AttnQKV,
+    AttnO,
+    FfnNorm,
+    FfnUp,
+    FfnDown,
 }
 
 pub trait WeightLoader {
     type Hardware: Hardware;
-    type Weight<'s>: Deref<Target = [ByteOf<Self::Hardware>]> + 's
+    type Memory<'s>: Deref<Target = [ByteOf<Self::Hardware>]> + 's
     where
         Self: 's;
 
-    fn patch_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2];
-    fn pos_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a>;
-    fn pre_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Option<[Self::Weight<'a>; 2]>;
+    fn load_blk(
+        &self,
+        which: BlkWeight,
+        iblk: usize,
+        queue: &QueueOf<Self::Hardware>,
+    ) -> [Self::Memory<'_>; 2];
+
+    fn patch_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2];
+    fn pos_embd<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a>;
+    fn pre_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>) -> Option<[Self::Memory<'a>; 2]>;
     fn post_norm<'a>(&'a self, queue: &'a QueueOf<Self::Hardware>)
-        -> Option<[Self::Weight<'a>; 2]>;
+        -> Option<[Self::Memory<'a>; 2]>;
 }
 
 pub struct ClipWorker {
@@ -38,6 +70,11 @@ pub struct ClipWorker {
     conv: Ops::Conv,
     add_rows: Ops::AddRows,
     layer_norm: Ops::LayerNorm,
+    mat_mul: Ops::MatMul,
+    attention: Ops::Attention,
+    gelu: Ops::Gelu,
+    add: Ops::Add,
+    rearrange: Ops::Rearrange,
     pub debug: bool,
 }
 
@@ -50,6 +87,11 @@ impl ClipWorker {
             conv: Ops::Conv::new(processor),
             add_rows: Ops::AddRows::new(processor),
             layer_norm: Ops::LayerNorm::new(processor),
+            mat_mul: Ops::MatMul::new(processor),
+            attention: Ops::Attention::new(processor),
+            gelu: Ops::Gelu::new(processor),
+            add: Ops::Add::new(processor),
+            rearrange: Ops::Rearrange::new(processor),
             debug: true,
         }
     }
@@ -58,6 +100,23 @@ impl ClipWorker {
     pub const fn meta(&self) -> &ClipMeta {
         &self.meta
     }
+
+    pub fn workspace_size(&self, np: usize) -> usize {
+        let ClipMeta {
+            nh, nkvh, dh, di, ..
+        } = self.meta;
+
+        let embd = self.meta.embd(np);
+        let dt = embd.dt();
+        let embd = embd.take();
+
+        let qkv = Tensor::new(dt, &[np * (nh + nkvh + nkvh), dh]).take();
+        let q = Tensor::new(dt, &[np, nh, dh]).take();
+        let att = Tensor::new(dt, &[nkvh, np, np]).take();
+
+        let up = Tensor::new(dt, &[np, di]).take();
+        embd + (qkv + q + att).max(up)
+    }
 }
 
 impl ClipWorker
 where
@@ -77,9 +136,17 @@
     {
         let time = Instant::now();
         let Args { raw, pos } = args;
-        let queue = queue_alloc.queue();
+        let ClipMeta {
+            dt,
+            nblk,
+            nh,
+            nkvh,
+            dh,
+            di,
+            ..
+        } = self.meta;
 
-        let ClipMeta { dt_embd, .. } = self.meta;
+        let queue = queue_alloc.queue();
 
         let [k, b] = self.weights.patch_embd(queue);
         let &[n, _, h, w] = raw.shape() else {
@@ -89,28 +156,102 @@
             unreachable!()
         };
 
-        let mut embd = Tensor::new(dt_embd, &[n, m, h / hk, w / wk]).map(|s| queue_alloc.alloc(s));
+        let mut embd = Tensor::new(dt, &[n, m, h / hk, w / wk]).map(|s| queue_alloc.alloc(s));
 
         self.conv(&mut embd, &raw, &k, &b, workspace, queue_alloc)?;
+        drop(k);
+        drop(b);
 
-        let mut embd = embd.merge(2..4).unwrap().transpose(&[2, 1]);
+        let embd_ = embd.merge(2..4).unwrap().transpose(&[2, 1]);
+        let mut embd = Tensor::new(embd_.dt(), embd_.shape()).map(|s| queue_alloc.alloc(s));
+        self.rearrange(&mut embd, &embd_, workspace, queue_alloc)?;
 
-        let pos_embd = self.weights.pos_embd(queue);
-        self.add_rows(&mut embd, &pos_embd, &pos, workspace, queue_alloc)?;
+        {
+            let pos_embd = self.weights.pos_embd(queue);
+            self.add_rows(&mut embd, &pos_embd, &pos, workspace, queue_alloc)?
+        }
+
+        let &[batch, size, _] = embd.shape() else {
+            unreachable!()
+        };
+        let batch_split = vec![size; batch];
 
-        if let Some([scale, bias]) = self.weights.pre_norm(queue) {
-            let inplace = unsafe { embd.map_slice_static() };
-            self.layer_norm(&mut embd, &inplace, &scale, &bias, workspace, queue_alloc)?;
+        let np = batch * size;
+        let mut x = embd.merge(0..2).unwrap();
+        let x1 = Tensor::new(x.dt(), x.shape());
+        let qkv = Tensor::new(x.dt(), &[np, (nh + nkvh + nkvh) * dh]);
+        let up = Tensor::new(x.dt(), &[np, di]);
+
+        let workspace_size = self.workspace_size(np);
+        let mut workspace = Workspace::new(queue_alloc, workspace, workspace_size);
+        let (buf, workspace) = workspace.split_at_mut(*x1.get());
+        let mut x1 = x1.map(|_| buf);
+
+        if let Some(wb) = self.weights.pre_norm(queue) {
+            let inplace = unsafe { x.map_slice_static() };
+            self.layer_norm(&mut x, &inplace, wb, workspace, queue_alloc)?
         }
 
-        for _ in 0..self.meta.nblk {}
+        for iblk in 0..nblk {
+            {
+                let wb = self.weights.attn_norm(iblk, queue);
+                self.layer_norm(&mut x1, &x, wb, workspace, queue_alloc)?;
+
+                let (buf, workspace) = workspace.split_at_mut(*qkv.get());
+                let mut qkv = qkv.clone().map(|_| buf);
+
+                let [w, b] = self.weights.attn_qkv(iblk, queue);
+                self.mat_mul(&mut qkv, &x1, (w, Some(b)), workspace, queue_alloc)?;
+
+                let qkv = qkv.tile(1, &[nh + nkvh + nkvh, dh]);
+                split!(qkv => q, k, v; [nh, nkvh, nkvh] @ 1);
+                let mut q = q;
+                let k = k;
+                let v = v;
+                {
+                    let q = q.map_slice_mut().transpose(&[1, 0]);
+                    let k = k.map_slice().transpose(&[1, 0]);
+                    let v = v.map_slice().transpose(&[1, 0]);
+                    let q = q.split(1, &batch_split);
+                    let k = k.split(1, &batch_split);
+                    let v = v.split(1, &batch_split);
+
+                    for (mut q, k, v) in izip!(q, k, v) {
+                        let mut o = unsafe { q.map_slice_static_mut() };
+                        self.attn(&mut q, &k, &v, &mut o, workspace, queue_alloc)?
+                    }
+                }
+                let o = q.map_slice().merge(1..3).unwrap();
+                let [w, b] = self.weights.attn_o(iblk, queue);
+                self.mat_mul(&mut x1, &o, (w, Some(b)), workspace, queue_alloc)?
+            }
+            let inplace = unsafe { x.map_slice_static() };
+            self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?;
 
-        if let Some([scale, bias]) = self.weights.post_norm(queue) {
-            let inplace = unsafe { embd.map_slice_static() };
-            self.layer_norm(&mut embd, &inplace, &scale, &bias, workspace, queue_alloc)?;
+            let wb = self.weights.ffn_norm(iblk, queue);
+            self.layer_norm(&mut x1, &x, wb, workspace, queue_alloc)?;
+            {
+                let (buf, workspace) = workspace.split_at_mut(*up.get());
+                let mut up = up.clone().map(|_| buf);
+
+                let [w, b] = self.weights.ffn_up(iblk, queue);
+                self.mat_mul(&mut up, &x1, (w, Some(b)), workspace, queue_alloc)?;
+
+                self.gelu(&mut up, workspace, queue_alloc)?;
+
+                let [w, b] = self.weights.ffn_down(iblk, queue);
+                self.mat_mul(&mut x1, &up, (w, Some(b)), workspace, queue_alloc)?
+            }
+            let inplace = unsafe { x.map_slice_static() };
+            self.add(&mut x, &inplace, &x1, workspace, queue_alloc)?
+        }
+
+        if let Some(wb) = self.weights.post_norm(queue) {
+            let inplace = unsafe { x.map_slice_static() };
+            self.layer_norm(&mut x, &inplace, wb, workspace, queue_alloc)?
         }
 
         if self.debug {
-            println!("encode {n} x {h} x {w} image in {:?}", time.elapsed());
+            println!("encode {n} x {h} x {w} image in {:?}", time.elapsed())
         }
 
         Ok(())
@@ -186,20 +327,18 @@ where
         )
     }
 
-    fn layer_norm<Y, X, Scale, Bias, QA>(
+    fn layer_norm<Y, X, WB, QA>(
         &self,
         y: &mut Tensor<Y>,
         x: &Tensor<X>,
-        scale: &Tensor<Scale>,
-        bias: &Tensor<Bias>,
+        [w, b]: [Tensor<WB>; 2],
         workspace: &mut [ByteOf<Ops::Hardware>],
         queue_alloc: &QA,
     ) -> Result<(), LaunchError>
     where
        Y: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
        X: Deref<Target = [ByteOf<Ops::Hardware>]>,
-        Scale: Deref<Target = [ByteOf<Ops::Hardware>]>,
-        Bias: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        WB: Deref<Target = [ByteOf<Ops::Hardware>]>,
        QA: QueueAlloc<Hardware = Ops::Hardware>,
     {
         self.layer_norm.launch(
@@ -208,24 +347,177 @@ where
                 y_base: y.base_mut(),
                 x_layout: x.layout(),
                 x_base: x.base(),
-                scale_layout: scale.layout(),
-                scale_base: scale.base(),
-                bias_layout: bias.layout(),
-                bias_base: bias.base(),
+                scale_layout: w.layout(),
+                scale_base: w.base(),
+                bias_layout: b.layout(),
+                bias_base: b.base(),
                 epsilon: self.meta.epsilon,
             },
             workspace,
             queue_alloc,
         )
     }
+
+    fn mat_mul<C, A, WB, QA>(
+        &self,
+        c: &mut Tensor<C>,
+        a: &Tensor<A>,
+        (w, b): (Tensor<WB>, Option<Tensor<WB>>),
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        C: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        A: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        WB: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        let beta = if let Some(b) = b {
+            let n = c.shape()[0];
+            let b = b.broadcast(0, n);
+            self.rearrange(c, &b, workspace, queue_alloc)?;
+            1.
+        } else {
+            0.
+        };
+        self.mat_mul.launch(
+            &mat_mul::Args {
+                c_layout: c.layout(),
+                c_base: c.base_mut(),
+                beta,
+                a_layout: a.layout(),
+                a_base: a.base(),
+                b_layout: w.layout(),
+                b_base: w.base(),
+                alpha: 1.,
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
+
+    fn attn<Q, K, V, O, QA>(
+        &self,
+        q: &mut Tensor<Q>,
+        k: &Tensor<K>,
+        v: &Tensor<V>,
+        o: &mut Tensor<O>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        Q: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        K: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        V: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        O: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.attention.launch(
+            &attention::Args {
+                q_layout: q.layout(),
+                q_base: q.base_mut(),
+                k_layout: k.layout(),
+                k_base: k.base(),
+                v_layout: v.layout(),
+                v_base: v.base(),
+                o_layout: o.layout(),
+                o_base: o.base_mut(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
+
+    fn gelu<X, QA>(
+        &self,
+        x: &mut Tensor<X>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        X: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.gelu.launch(
+            &gelu::Args {
+                layout: x.layout(),
+                base: x.base_mut(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
+
+    fn add<C, A, B, QA>(
+        &self,
+        c: &mut Tensor<C>,
+        a: &Tensor<A>,
+        b: &Tensor<B>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        C: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        A: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        B: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.add.launch(
+            &add::Args {
+                c_layout: c.layout(),
+                c_base: c.base_mut(),
+                a_layout: a.layout(),
+                a_base: a.base(),
+                b_layout: b.layout(),
+                b_base: b.base(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
+
+    fn rearrange<Dst, Src, QA>(
+        &self,
+        dst: &mut Tensor<Dst>,
+        src: &Tensor<Src>,
+        workspace: &mut [ByteOf<Ops::Hardware>],
+        queue_alloc: &QA,
+    ) -> Result<(), LaunchError>
+    where
+        Dst: DerefMut<Target = [ByteOf<Ops::Hardware>]>,
+        Src: Deref<Target = [ByteOf<Ops::Hardware>]>,
+        QA: QueueAlloc<Hardware = Ops::Hardware>,
+    {
+        self.rearrange.launch(
+            &rearrange::Args {
+                dst_layout: dst.layout(),
+                dst_base: dst.base_mut(),
+                src_layout: src.layout(),
+                src_base: src.base(),
+            },
+            workspace,
+            queue_alloc,
+        )
+    }
 }
 
 struct WeightDecorator<W> {
-    weights: W,
     patch_embd_w: Tensor<usize>,
     patch_embd_b: Tensor<usize>,
     pos_embd: Tensor<usize>,
     norm: Tensor<usize>,
+
+    attn_qkv_w: Tensor<usize>,
+    attn_qkv_b: Tensor<usize>,
+    attn_o_w: Tensor<usize>,
+    attn_o_b: Tensor<usize>,
+
+    ffn_up_w: Tensor<usize>,
+    ffn_up_b: Tensor<usize>,
+    ffn_down_w: Tensor<usize>,
+    ffn_down_b: Tensor<usize>,
+
+    weights: W,
 }
 
 impl ClipMeta {
@@ -235,6 +527,16 @@ impl ClipMeta {
             patch_embd_b: self.patch_embd_b(),
             pos_embd: self.pos_embd(),
             norm: self.norm(),
+
+            attn_qkv_w: self.attn_qkv_w(),
+            attn_qkv_b: self.attn_qkv_b(),
+            attn_o_w: self.attn_o_w(),
+            attn_o_b: self.attn_o_b(),
+            ffn_up_w: self.ffn_up_w(),
+            ffn_up_b: self.ffn_up_b(),
+            ffn_down_w: self.ffn_down_w(),
+            ffn_down_b: self.ffn_down_b(),
+
             weights,
         }
     }
@@ -242,7 +544,7 @@ impl WeightDecorator {
     #[inline]
-    pub fn patch_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> [Tensor<W::Weight<'a>>; 2] {
+    pub fn patch_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> [Tensor<W::Memory<'a>>; 2] {
         let [w, b] = self.weights.patch_embd(queue);
         [
             self.patch_embd_w.clone().map(|_| w),
@@ -251,7 +553,7 @@ impl WeightDecorator {
     }
 
     #[inline]
-    pub fn pos_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> Tensor<W::Weight<'a>> {
+    pub fn pos_embd<'a>(&'a self, queue: &'a QueueOf<W::Hardware>) -> Tensor<W::Memory<'a>> {
         let pos_embd = self.weights.pos_embd(queue);
         self.pos_embd.clone().map(|_| pos_embd)
     }
@@ -260,7 +562,7 @@ impl WeightDecorator {
     pub fn pre_norm<'a>(
         &'a self,
         queue: &'a QueueOf<W::Hardware>,
-    ) -> Option<[Tensor<W::Weight<'a>>; 2]> {
+    ) -> Option<[Tensor<W::Memory<'a>>; 2]> {
         self.weights
             .pre_norm(queue)
             .map(|pair| pair.map(|w| self.norm.clone().map(|_| w)))
@@ -270,9 +572,67 @@ impl WeightDecorator {
     pub fn post_norm<'a>(
         &'a self,
         queue: &'a QueueOf<W::Hardware>,
-    ) -> Option<[Tensor<W::Weight<'a>>; 2]> {
+    ) -> Option<[Tensor<W::Memory<'a>>; 2]> {
         self.weights
             .post_norm(queue)
             .map(|pair| pair.map(|w| self.norm.clone().map(|_| w)))
     }
+
+    pub fn attn_norm(
+        &self,
+        iblk: usize,
+        queue: &QueueOf<W::Hardware>,
+    ) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::AttnNorm, iblk, queue);
+        [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
+    }
+
+    pub fn attn_qkv(
+        &self,
+        iblk: usize,
+        queue: &QueueOf<W::Hardware>,
+    ) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::AttnQKV, iblk, queue);
+        [
+            self.attn_qkv_w.clone().map(|_| w),
+            self.attn_qkv_b.clone().map(|_| b),
+        ]
+    }
+
+    pub fn attn_o(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::AttnO, iblk, queue);
+        [
+            self.attn_o_w.clone().map(|_| w),
+            self.attn_o_b.clone().map(|_| b),
+        ]
+    }
+
+    pub fn ffn_norm(
+        &self,
+        iblk: usize,
+        queue: &QueueOf<W::Hardware>,
+    ) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::FfnNorm, iblk, queue);
+        [self.norm.clone().map(|_| w), self.norm.clone().map(|_| b)]
+    }
+
+    pub fn ffn_up(&self, iblk: usize, queue: &QueueOf<W::Hardware>) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::FfnUp, iblk, queue);
+        [
+            self.ffn_up_w.clone().map(|_| w),
+            self.ffn_up_b.clone().map(|_| b),
+        ]
+    }
+
+    pub fn ffn_down(
+        &self,
+        iblk: usize,
+        queue: &QueueOf<W::Hardware>,
+    ) -> [Tensor<W::Memory<'_>>; 2] {
+        let [w, b] = self.weights.load_blk(BlkWeight::FfnDown, iblk, queue);
+        [
+            self.ffn_down_w.clone().map(|_| w),
+            self.ffn_down_b.clone().map(|_| b),
+        ]
+    }
 }
diff --git a/models/clip/common/src/lib.rs b/models/clip/common/src/lib.rs
index d391649..01d785b 100644
--- a/models/clip/common/src/lib.rs
+++ b/models/clip/common/src/lib.rs
@@ -6,9 +6,9 @@ mod storage;
 use gguf::ggml_quants::digit_layout::DigitLayout;
 
 pub use args::Args as ClipArgs;
-pub use compute::{ClipWorker, Operators, WeightLoader};
+pub use compute::{BlkWeight, ClipWorker, Operators, WeightLoader};
 pub use image::{Image, ImageGrid};
-pub use storage::Storage as ClipStorage;
+pub use storage::{BlkStorage as ClipBlkStorage, Storage as ClipStorage};
 pub use tensor::Tensor;
 pub mod ext {
     pub use gguf::{
@@ -23,14 +23,15 @@ pub struct ClipMeta {
     pub minicpmv_version: u8,
 
     pub dt: DigitLayout,
-    pub dt_embd: DigitLayout,
-    pub dt_norm: DigitLayout,
-
-    pub nblk: usize,
     pub d_patch: usize,
     pub d_image: usize,
+
+    pub nblk: usize,
     pub nh: usize,
+    pub nkvh: usize,
     pub d: usize,
+    pub dh: usize,
     pub di: usize,
 
     pub image_mean: [f32; 3],
@@ -79,6 +80,16 @@ impl ClipMeta {
         }
     }
 
+    pub fn embd(&self, np: usize) -> Tensor<usize> {
+        let &Self { dt, d, .. } = self;
+        Tensor::new(dt, &[np, d])
+    }
+
+    pub fn pos_embd(&self) -> Tensor<usize> {
+        let &Self { dt, d, .. } = self;
+        Tensor::new(dt, &[D_POS_EMBD.pow(2), d])
+    }
+
     pub fn patch_embd_w(&self) -> Tensor<usize> {
         let &Self { d, d_patch, .. } = self;
         Tensor::new(self.dt, &[d, 3, d_patch, d_patch])
@@ -89,13 +100,53 @@ impl ClipMeta {
         Tensor::new(self.dt, &[d])
     }
 
-    pub fn pos_embd(&self) -> Tensor<usize> {
+    pub fn norm(&self) -> Tensor<usize> {
         let &Self { d, .. } = self;
-        Tensor::new(self.dt_embd, &[D_POS_EMBD.pow(2), d])
+        Tensor::new(self.dt, &[d])
     }
 
-    pub fn norm(&self) -> Tensor<usize> {
+    pub fn attn_qkv_w(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(3 * d, d)
+    }
+
+    pub fn attn_qkv_b(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(3 * d, 1)
+    }
+
+    pub fn attn_o_w(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(d, d)
+    }
+
+    pub fn attn_o_b(&self) -> Tensor<usize> {
+        let &Self { d, .. } = self;
+        self.mat(d, 1)
+    }
+
+    pub fn ffn_up_w(&self) -> Tensor<usize> {
+        let &Self { d, di, .. } = self;
+        self.mat(di, d)
+    }
+
+    pub fn ffn_up_b(&self) -> Tensor<usize> {
+        let &Self { di, .. } = self;
+        self.mat(di, 1)
+    }
+
+    pub fn ffn_down_w(&self) -> Tensor<usize> {
+        let &Self { d, di, .. } = self;
+        self.mat(d, di)
+    }
+
+    pub fn ffn_down_b(&self) -> Tensor<usize> {
         let &Self { d, .. } = self;
-        Tensor::new(self.dt_norm, &[d])
+        self.mat(d, 1)
+    }
+
+    fn mat(&self, row: usize, col: usize) -> Tensor<usize> {
+        assert_eq!(self.dt.group_size(), 1);
+        Tensor::new(self.dt, &[row, col]).transpose(&[1, 0])
     }
 }
diff --git a/models/clip/common/src/storage.rs b/models/clip/common/src/storage.rs
index 2977aa2..c0ba1f0 100644
--- a/models/clip/common/src/storage.rs
+++ b/models/clip/common/src/storage.rs
@@ -1,5 +1,5 @@
 use crate::{ClipMeta, ProjectorType};
-use gguf::{GGufMetaMapExt, GGufModel};
+use gguf::{meta, GGufMetaMapExt, GGufModel};
 
 #[derive(Clone)]
 pub struct Storage<T> {
@@ -9,13 +9,29 @@ pub struct Storage<T> {
     pub pos_embd: T,
     pub pre_norm: Option<[T; 2]>,
     pub post_norm: Option<[T; 2]>,
+    pub blocks: Box<[BlkStorage<T>]>,
+}
+
+#[derive(Clone, Copy)]
+pub struct BlkStorage<T> {
+    pub attn_norm_w: T,
+    pub attn_norm_b: T,
+    pub attn_qkv_w: T,
+    pub attn_qkv_b: T,
+    pub attn_o_w: T,
+    pub attn_o_b: T,
+
+    pub ffn_norm_w: T,
+    pub ffn_norm_b: T,
+    pub ffn_up_w: T,
+    pub ffn_up_b: T,
+    pub ffn_down_w: T,
+    pub ffn_down_b: T,
 }
 
 impl<'a> Storage<&'a [u8]> {
     pub fn from_gguf(gguf: &GGufModel<'a>) -> Self {
         let pos_embd = &gguf.tensors["v.position_embd.weight"];
-        let patch_embd_w = &gguf.tensors["v.patch_embd.weight"];
-        let patch_embd_b = &gguf.tensors["v.patch_embd.bias"];
 
         let projector = match gguf.get_str("clip.projector_type").unwrap() {
             "mlp" => ProjectorType::Mlp,
@@ -25,31 +41,51 @@ impl<'a> Storage<&'a [u8]> {
             _ => ProjectorType::Unknown,
         };
 
+        let d = meta![gguf => (usize) "clip.vision.embedding_length"];
+        let nh = meta![gguf => (usize) "clip.vision.attention.head_count"];
+
+        #[rustfmt::skip]
         let meta = ClipMeta {
             projector,
-            minicpmv_version: gguf.get_usize("clip.minicpmv_version").unwrap() as _,
+            minicpmv_version: meta![gguf => (usize) "clip.minicpmv_version"] as _,
 
-            dt     : patch_embd_w.ty,
-            dt_embd: pos_embd.ty,
-            dt_norm: gguf.tensors["v.blk.0.ln1.weight"].ty,
+            dt     : pos_embd.ty,
+            d_patch: meta![gguf => (usize) "clip.vision.patch_size"],
+            d_image: meta![gguf => (usize) "clip.vision.image_size"],
 
-            nblk   : gguf.get_usize("clip.vision.block_count"          ).unwrap(),
-            d_patch: gguf.get_usize("clip.vision.patch_size"           ).unwrap(),
-            d_image: gguf.get_usize("clip.vision.image_size"           ).unwrap(),
-            nh     : gguf.get_usize("clip.vision.attention.head_count" ).unwrap(),
-            d      : gguf.get_usize("clip.vision.embedding_length"     ).unwrap(),
-            di     : gguf.get_usize("clip.vision.feed_forward_length"  ).unwrap(),
+            d, nh,
+            nblk: meta![gguf => (usize) "clip.vision.block_count"                 ],
+            nkvh: meta![gguf => (usize) "clip.vision.attention.head_count_kv"; nh ],
+            dh  : meta![gguf => (usize) "clip.vision.rope_dimension_count"; d / nh],
+            di  : meta![gguf => (usize) "clip.vision.feed_forward_length"         ],
 
             image_mean: get_rgb(gguf, "clip.vision.image_mean"),
             image_std : get_rgb(gguf, "clip.vision.image_std" ),
             epsilon   : gguf.get_f32("clip.vision.attention.layer_norm_epsilon").unwrap(),
         };
 
+        #[rustfmt::skip]
+        let blocks = (0..meta.nblk)
+            .map(|i| BlkStorage {
+                attn_norm_w: gguf.tensors[&*format!("v.blk.{i}.ln1.weight"     )].data,
+                attn_norm_b: gguf.tensors[&*format!("v.blk.{i}.ln1.bias"       )].data,
+                attn_qkv_w:  gguf.tensors[&*format!("v.blk.{i}.attn_qkv.weight")].data,
+                attn_qkv_b:  gguf.tensors[&*format!("v.blk.{i}.attn_qkv.bias"  )].data,
+                attn_o_w:    gguf.tensors[&*format!("v.blk.{i}.attn_out.weight")].data,
+                attn_o_b:    gguf.tensors[&*format!("v.blk.{i}.attn_out.bias"  )].data,
+
+                ffn_norm_w:  gguf.tensors[&*format!("v.blk.{i}.ln2.weight"     )].data,
+                ffn_norm_b:  gguf.tensors[&*format!("v.blk.{i}.ln2.bias"       )].data,
+                ffn_up_w:    gguf.tensors[&*format!("v.blk.{i}.ffn_up.weight"  )].data,
+                ffn_up_b:    gguf.tensors[&*format!("v.blk.{i}.ffn_up.bias"    )].data,
+                ffn_down_w:  gguf.tensors[&*format!("v.blk.{i}.ffn_down.weight")].data,
+                ffn_down_b:  gguf.tensors[&*format!("v.blk.{i}.ffn_down.bias"  )].data,
+            })
+            .collect();
+
         Self {
             meta,
-            patch_embd_w: patch_embd_w.data,
-            patch_embd_b: patch_embd_b.data,
+            patch_embd_w: gguf.tensors["v.patch_embd.weight"].data,
+            patch_embd_b: gguf.tensors["v.patch_embd.bias"].data,
             pos_embd: pos_embd.data,
             pre_norm: gguf
                 .tensors
                 .get("v.pre_ln.weight")
                 .map(|w| [w.data, gguf.tensors["v.pre_ln.bias"].data]),
@@ -59,6 +95,7 @@ impl<'a> Storage<&'a [u8]> {
                 .tensors
                 .get("v.post_ln.weight")
                 .map(|w| [w.data, gguf.tensors["v.post_ln.bias"].data]),
+            blocks,
         }
     }
 }
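
Usage note (not part of the patch): a minimal sketch of how the two `meta!` arms added in gguf/src/lib.rs are meant to be invoked. The `vision_dims` helper and its return tuple are hypothetical; the keys and fallback expressions are the ones the patch itself uses in models/clip/common/src/storage.rs.

    // Sketch only: assumes a parsed gguf::GGufModel is available, as in Storage::from_gguf.
    use gguf::{meta, GGufModel};

    // Hypothetical helper for illustration.
    fn vision_dims(gguf: &GGufModel<'_>) -> (usize, usize, usize) {
        // Required keys: the plain `(usize)` arm unwraps and panics if the key is missing.
        let d = meta![gguf => (usize) "clip.vision.embedding_length"];
        let nh = meta![gguf => (usize) "clip.vision.attention.head_count"];
        // Optional key: the `; default` arm falls back to the given expression when the key is absent.
        let dh = meta![gguf => (usize) "clip.vision.rope_dimension_count"; d / nh];
        (d, nh, dh)
    }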