Skip to content

Commit

Permalink
feat(clip): 完成 clip 模型的 transformer 部分
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Dec 30, 2024
1 parent 2f059ce commit 256d219
Show file tree
Hide file tree
Showing 6 changed files with 571 additions and 70 deletions.
11 changes: 11 additions & 0 deletions gguf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ mod macros {
Err(e) => panic!("failed to read meta: {e:?}"),
}
};

// Read a GGUF metadata value as `usize`; panics on ANY read error
// (missing key included) via `unwrap`.
($gguf:expr => (usize) $key:expr) => {
$gguf.get_usize($key).unwrap()
};
// Read a GGUF metadata value as `usize`, falling back to `$default`
// when the key does not exist; any other read error still panics
// (mirrors the error handling of the arm visible at the top of this hunk).
($gguf:expr => (usize) $key:expr; $default:expr) => {
match $gguf.get_usize($key) {
Ok(val) => val,
Err(gguf::GGufMetaError::NotExist) => $default,
Err(e) => panic!("failed to read meta: {e:?}"),
}
};
}
#[macro_export]
macro_rules! tensor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn test_infer() {
println!("{meta:#?}");

let &ClipMeta {
dt_embd,
dt,

d_image,
d_patch,
Expand All @@ -42,7 +42,7 @@ fn test_infer() {
let time = Instant::now();
let slices = image
.slice_uhd(9, d_image, d_patch)
.normalize(dt_embd, image_mean, image_std);
.normalize(dt, image_mean, image_std);
println!("slice image {:?}", time.elapsed());

let weights = Weights::new(&storage);
Expand Down
60 changes: 51 additions & 9 deletions models/clip/common-cpu/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use clip::{ClipStorage, WeightLoader};
use operators::{common_cpu::Cpu, conv, QueueOf, TopoNode};
use std::marker::PhantomData;
use clip::{BlkWeight, ClipBlkStorage, ClipStorage, Tensor, WeightLoader};
use operators::{common_cpu::Cpu, conv, ByteOf, QueueOf, TopoNode};
use std::{marker::PhantomData, ops::Deref};

pub struct Operators<N = Cpu>(PhantomData<N>);

Expand All @@ -22,6 +22,18 @@ where
type Conv = conv::common_cpu::ConvIm2Col;
type AddRows = op!(add_rows);
type LayerNorm = op!(layer_norm);
type MatMul = op!(mat_mul);
type Attention = op!(attention);
type Gelu = op!(gelu);
type Add = op!(add);
type Rearrange = op!(rearrange);

// Debug hook for the CPU backend: dump a tensor to stdout via its
// `Display` impl. `T` must deref to the hardware's byte slice so the
// tensor's contents are directly readable on the host.
fn debug<T>(tensor: &Tensor<T>)
where
T: Deref<Target = [ByteOf<Self::Hardware>]>,
{
println!("{tensor}")
}
}

impl<'w> Weights<'w> {
Expand All @@ -32,37 +44,67 @@ impl<'w> Weights<'w> {

impl WeightLoader for Weights<'_> {
type Hardware = Cpu;
type Weight<'s>
type Memory<'s>
= &'s [u8]
where
Self: 's;

/// Returns the `[weight, bias]` pair (`*_w`/`*_b`) for transformer
/// block `iblk`, selected by `which`.
///
/// Weights are borrowed straight out of `self.0.blocks[iblk]`; the
/// queue parameter is unused on this backend (prefixed `_`) —
/// presumably because CPU storage needs no transfer step (TODO confirm).
fn load_blk(
&self,
which: BlkWeight,
iblk: usize,
_queue: &QueueOf<Self::Hardware>,
) -> [Self::Memory<'_>; 2] {
// Destructure the block's storage once so each match arm below can
// simply name the relevant weight/bias pair.
let ClipBlkStorage {
attn_norm_w,
attn_norm_b,
attn_qkv_w,
attn_qkv_b,
attn_o_w,
attn_o_b,
ffn_norm_w,
ffn_norm_b,
ffn_up_w,
ffn_up_b,
ffn_down_w,
ffn_down_b,
} = &self.0.blocks[iblk];
// Exhaustive over `BlkWeight` — a new variant forces a compile error here.
match which {
BlkWeight::AttnNorm => [attn_norm_w, attn_norm_b],
BlkWeight::AttnQKV => [attn_qkv_w, attn_qkv_b],
BlkWeight::AttnO => [attn_o_w, attn_o_b],
BlkWeight::FfnNorm => [ffn_norm_w, ffn_norm_b],
BlkWeight::FfnUp => [ffn_up_w, ffn_up_b],
BlkWeight::FfnDown => [ffn_down_w, ffn_down_b],
}
}

#[inline]
fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Weight<'a>; 2] {
fn patch_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> [Self::Memory<'a>; 2] {
[self.0.patch_embd_w, self.0.patch_embd_b]
}

#[inline]
fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Weight<'a> {
fn pos_embd<'a>(&'a self, _queue: &'a QueueOf<Self::Hardware>) -> Self::Memory<'a> {
self.0.pos_embd
}

#[inline]
fn pre_norm<'a>(
&'a self,
_queue: &'a QueueOf<Self::Hardware>,
) -> Option<[Self::Weight<'a>; 2]> {
) -> Option<[Self::Memory<'a>; 2]> {
self.0.pre_norm
}

#[inline]
fn post_norm<'a>(
&'a self,
_queue: &'a QueueOf<Self::Hardware>,
) -> Option<[Self::Weight<'a>; 2]> {
) -> Option<[Self::Memory<'a>; 2]> {
self.0.post_norm
}
}

#[cfg(test)]
mod test_infer;
mod infer;
Loading

0 comments on commit 256d219

Please sign in to comment.