feat(clip): support resampler pos embd; move pos broadcast into the model
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Feb 14, 2025
1 parent 7a71a4c commit 3d8b052
Showing 3 changed files with 39 additions and 22 deletions.
33 changes: 19 additions & 14 deletions models/clip/common-cpu/src/infer.rs
@@ -1,6 +1,9 @@
use crate::{Operators, Weights};
use clip::{ClipArgs, ClipMeta, ClipStorage, ClipWorker, Image, Tensor, D_POS_EMBD};
use gguf::{ggml_quants::digit_layout::types as ty, GGufModel};
use gguf::{
ggml_quants::{digit_layout::types as ty, f16},
GGufModel,
};
use operators::{
common_cpu::{Cpu, ThisThread},
Blob,
@@ -53,22 +56,24 @@ fn test_infer() {
.launch(
ClipArgs {
raw: whole.to_nchw(),
pos: pos70(1, whole.shape(), d_patch).map_slice(),
pos: pos70(whole.shape(), d_patch).map_slice(),
pos_resampler: pos_resampler(3584, whole.shape(), d_patch).map_slice(),
},
&mut [],
&ThisThread,
)
.unwrap();

if let Some(patches) = slices.patches_nchw() {
let &[n, 3, h, w] = patches.shape() else {
let &[_, 3, h, w] = patches.shape() else {
unreachable!()
};
worker
.launch(
ClipArgs {
raw: patches.map_slice(),
pos: pos70(n, [w, h], d_patch).map_slice(),
pos: pos70([w, h], d_patch).map_slice(),
pos_resampler: pos_resampler(3584, [w, h], d_patch).map_slice(),
},
&mut [],
&ThisThread,
@@ -77,7 +82,7 @@ fn test_infer() {
}
}

fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
fn pos70([w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
let w = w / d_patch;
let h = h / d_patch;

@@ -95,15 +100,15 @@ fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
data[i] = (y * D_POS_EMBD + x) as _;
}

ans.broadcast(0, n)
ans
}

fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
fn pos_resampler(d: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
let w = w / d_patch;
let h = h / d_patch;

let mut ans = Tensor::new(ty::F32, &[1, h * w, d]).map(Blob::new);
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f32>() }) else {
let mut ans = Tensor::new(ty::F16, &[1, h * w, d]).map(Blob::new);
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f16>() }) else {
panic!()
};

@@ -118,15 +123,15 @@ fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
let d = d / 4;
for i in 0..d {
let (sin, cos) = cache[c * d + i];
data[0 * d..][i] = sin;
data[1 * d..][i] = cos;
data[0 * d..][i] = f16::from_f32(sin);
data[1 * d..][i] = f16::from_f32(cos);
let (sin, cos) = cache[r * d + i];
data[2 * d..][i] = sin;
data[3 * d..][i] = cos;
data[2 * d..][i] = f16::from_f32(sin);
data[3 * d..][i] = f16::from_f32(cos);
}
}

ans.broadcast(0, n)
ans
}

fn sin_cos_cache(max_idx: usize, d: usize, theta: f32) -> Vec<(f32, f32)> {
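The body of `sin_cos_cache` is cut off by the hunk boundary above. Below is a minimal sketch of what it is assumed to compute, inferred from how `pos_resampler` reads the result (`cache[c * d + i]` with `d = d / 4`): a `max_idx × d` table of `(sin, cos)` pairs with a RoPE-style frequency schedule. The exact frequency formula is an assumption, not taken from the repository.

fn sin_cos_cache(max_idx: usize, d: usize, theta: f32) -> Vec<(f32, f32)> {
    // One (sin, cos) pair per (position, channel) slot, laid out as
    // cache[pos * d + i], matching the `cache[c * d + i]` reads above.
    (0..max_idx)
        .flat_map(|pos| {
            (0..d).map(move |i| {
                // Assumed schedule: angle = pos / theta^(i / d).
                let angle = pos as f32 / theta.powf(i as f32 / d as f32);
                (angle.sin(), angle.cos())
            })
        })
        .collect()
}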
4 changes: 3 additions & 1 deletion models/clip/common/src/args.rs
@@ -4,6 +4,8 @@ use tensor::Tensor;
pub struct Args<'a, H: Hardware> {
/// shape: [n, c, h, w]
pub raw: Tensor<&'a [H::Byte]>,
/// shape: [n, h x w]
/// shape: [h x w]
pub pos: Tensor<&'a [H::Byte]>,
/// shape: [h x w, resampler.d]
pub pos_resampler: Tensor<&'a [H::Byte]>,
}
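As a quick sanity check on the doc comments above: for an `h × w` image split into `d_patch × d_patch` patches, both positional inputs are sized by the patch count `np`, with no leading batch dimension (the model broadcasts over the batch internally). A small helper sketch, not part of the crate; `resampler_d` stands in for the `resampler.d` mentioned in the comment (3584 in the test above):

/// Expected shapes of the two positional inputs:
/// `pos` is `[np]` and `pos_resampler` is `[np, resampler_d]`,
/// where `np = (h / d_patch) * (w / d_patch)`.
fn pos_shapes(h: usize, w: usize, d_patch: usize, resampler_d: usize) -> ([usize; 1], [usize; 2]) {
    let np = (h / d_patch) * (w / d_patch);
    ([np], [np, resampler_d])
}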
24 changes: 17 additions & 7 deletions models/clip/common/src/compute.rs
@@ -145,7 +145,11 @@ where
QA: QueueAlloc<Hardware = Ops::Hardware>,
{
let time = Instant::now();
let Args { raw, pos } = args;
let Args {
raw,
pos,
pos_resampler,
} = args;
let ClipMeta {
dt,
dt_norm,
@@ -176,14 +180,16 @@ where
let mut embd = Tensor::new(embd_.dt(), embd_.shape()).map(|s| queue_alloc.alloc(s));
self.rearrange(&mut embd, &embd_, workspace, queue_alloc)?;

let &[batch, size, _] = embd.shape() else {
unreachable!()
};

{
let pos_embd = self.weights.pos_embd(queue);
let pos = pos.broadcast(0, batch);
self.add_rows(&mut embd, &pos_embd, &pos, workspace, queue_alloc)?
}

let &[batch, size, _] = embd.shape() else {
unreachable!()
};
let batch_split = vec![size; batch];

let np = batch * size;
@@ -281,6 +287,7 @@ where

let d0 = self.meta.d;
let w = self.meta.mat(d, d0).map(|_| weights.resampler_wkv(queue));
// (np d0) <- (np d) · (d d0)
self.mat_mul(&mut v, &x, (w, None), workspace, queue_alloc)?;

let [w, b] = weights.resampler_ln_q(queue);
@@ -292,9 +299,12 @@ where
let inplace = unsafe { v.map_slice_static() };
self.layer_norm(&mut v, &inplace, ln_v, workspace, queue_alloc)?;

let (buf, workspace) = workspace.split_at_mut(*kv.get());
let pos_embd = Tensor::new(dt, v.shape()).map(|_| buf);
self.add(&mut k, &v, &pos_embd, workspace, queue_alloc)?;
{
let mut k = k.map_slice_mut().tile(0, &[batch, size]);
let v = v.map_slice().tile(0, &[batch, size]);
let pos = pos_resampler.broadcast(0, batch);
self.add(&mut k, &v, &pos, workspace, queue_alloc)?
}

let attn_w = self.meta.mat(d, d);
let attn_b = self.meta.mat(d, 1);
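For reference, the `tile`/`broadcast` + `self.add` sequence above amounts to adding the same `[size, d]` resampler position table to every image in the batch. A plain-loop illustration of that semantics (standalone sketch, not the `operators` kernels, which run through `QueueAlloc` on device memory):

/// k[b, t, :] = v[b, t, :] + pos[t, :] — `pos` is reused for every batch index,
/// which is exactly what broadcasting it along dimension 0 achieves.
fn add_pos_broadcast(v: &[f32], pos: &[f32], batch: usize, size: usize, d: usize) -> Vec<f32> {
    assert_eq!(v.len(), batch * size * d);
    assert_eq!(pos.len(), size * d);
    let mut k = vec![0.0; batch * size * d];
    for b in 0..batch {
        for t in 0..size {
            for j in 0..d {
                k[(b * size + t) * d + j] = v[(b * size + t) * d + j] + pos[t * d + j];
            }
        }
    }
    k
}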
