Skip to content

Commit

Permalink
feat(clip): 实现 resampler pos embd 计算
Browse files Browse the repository at this point in the history
  • Loading branch information
CearX authored and YdrMaster committed Feb 14, 2025
1 parent 81c8e1e commit 7a71a4c
Showing 1 changed file with 51 additions and 9 deletions.
60 changes: 51 additions & 9 deletions models/clip/common-cpu/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,63 @@ fn test_infer() {
}

fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
let pos_w = w / d_patch;
let pos_h = h / d_patch;
let w = w / d_patch;
let h = h / d_patch;

let mut ans = Tensor::new(ty::U32, &[1, pos_w * pos_h])
.broadcast(0, n)
.map(Blob::new);
let mut ans = Tensor::new(ty::U32, &[1, h * w]).map(Blob::new);
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<u32>() }) else {
panic!()
};

for i in 0..pos_h * pos_w {
let y = (i / pos_w) * D_POS_EMBD / pos_h;
let x = (i % pos_w) * D_POS_EMBD / pos_w;
for i in 0..h * w {
let r = i / w;
let c = i % w;

let y = r * D_POS_EMBD / h;
let x = c * D_POS_EMBD / w;
data[i] = (y * D_POS_EMBD + x) as _;
}

ans
ans.broadcast(0, n)
}

fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
let w = w / d_patch;
let h = h / d_patch;

let mut ans = Tensor::new(ty::F32, &[1, h * w, d]).map(Blob::new);
let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f32>() }) else {
panic!()
};

assert!(d % 4 == 0);
let cache = sin_cos_cache(w.max(h), d / 4, 1e4);

for i in 0..h * w {
let r = i / w;
let c = i % w;

let data = &mut data[i * d..][..d];
let d = d / 4;
for i in 0..d {
let (sin, cos) = cache[c * d + i];
data[0 * d..][i] = sin;
data[1 * d..][i] = cos;
let (sin, cos) = cache[r * d + i];
data[2 * d..][i] = sin;
data[3 * d..][i] = cos;
}
}

ans.broadcast(0, n)
}

fn sin_cos_cache(max_idx: usize, d: usize, theta: f32) -> Vec<(f32, f32)> {
(0..max_idx * d)
.map(|i| {
let a = (i / d) as f32;
let b = (i % d) as f32;
(a * theta.powf(-(b / d as f32))).sin_cos()
})
.collect()
}

0 comments on commit 7a71a4c

Please sign in to comment.