diff --git a/models/clip/common-cpu/src/infer.rs b/models/clip/common-cpu/src/infer.rs
index 334e562..ac7ac56 100644
--- a/models/clip/common-cpu/src/infer.rs
+++ b/models/clip/common-cpu/src/infer.rs
@@ -1,6 +1,9 @@
 use crate::{Operators, Weights};
 use clip::{ClipArgs, ClipMeta, ClipStorage, ClipWorker, Image, Tensor, D_POS_EMBD};
-use gguf::{ggml_quants::digit_layout::types as ty, GGufModel};
+use gguf::{
+    ggml_quants::{digit_layout::types as ty, f16},
+    GGufModel,
+};
 use operators::{
     common_cpu::{Cpu, ThisThread},
     Blob,
@@ -53,7 +56,8 @@ fn test_infer() {
         .launch(
             ClipArgs {
                 raw: whole.to_nchw(),
-                pos: pos70(1, whole.shape(), d_patch).map_slice(),
+                pos: pos70(whole.shape(), d_patch).map_slice(),
+                pos_resampler: pos_resampler(3584, whole.shape(), d_patch).map_slice(),
             },
             &mut [],
             &ThisThread,
@@ -61,14 +65,15 @@ fn test_infer() {
         .unwrap();
 
     if let Some(patches) = slices.patches_nchw() {
-        let &[n, 3, h, w] = patches.shape() else {
+        let &[_, 3, h, w] = patches.shape() else {
             unreachable!()
         };
         worker
             .launch(
                 ClipArgs {
                     raw: patches.map_slice(),
-                    pos: pos70(n, [w, h], d_patch).map_slice(),
+                    pos: pos70([w, h], d_patch).map_slice(),
+                    pos_resampler: pos_resampler(3584, [w, h], d_patch).map_slice(),
                 },
                 &mut [],
                 &ThisThread,
@@ -77,7 +82,7 @@ fn test_infer() {
     }
 }
 
-fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
+fn pos70([w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
     let w = w / d_patch;
     let h = h / d_patch;
 
@@ -95,15 +100,15 @@ fn pos70(n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
         data[i] = (y * D_POS_EMBD + x) as _;
     }
 
-    ans.broadcast(0, n)
+    ans
 }
 
-fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
+fn pos_resampler(d: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
     let w = w / d_patch;
     let h = h / d_patch;
 
-    let mut ans = Tensor::new(ty::F32, &[1, h * w, d]).map(Blob::new);
-    let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f32>() }) else {
+    let mut ans = Tensor::new(ty::F16, &[1, h * w, d]).map(Blob::new);
+    let (&mut [], data, &mut []) = (unsafe { ans.get_mut().align_to_mut::<f16>() }) else {
         panic!()
     };
 
@@ -118,15 +123,15 @@ fn pos_resampler(d: usize, n: usize, [w, h]: [usize; 2], d_patch: usize) -> Tensor<Blob> {
         let d = d / 4;
         for i in 0..d {
             let (sin, cos) = cache[c * d + i];
-            data[0 * d..][i] = sin;
-            data[1 * d..][i] = cos;
+            data[0 * d..][i] = f16::from_f32(sin);
+            data[1 * d..][i] = f16::from_f32(cos);
             let (sin, cos) = cache[r * d + i];
-            data[2 * d..][i] = sin;
-            data[3 * d..][i] = cos;
+            data[2 * d..][i] = f16::from_f32(sin);
+            data[3 * d..][i] = f16::from_f32(cos);
         }
     }
 
-    ans.broadcast(0, n)
+    ans
 }
 
 fn sin_cos_cache(max_idx: usize, d: usize, theta: f32) -> Vec<(f32, f32)> {
diff --git a/models/clip/common/src/args.rs b/models/clip/common/src/args.rs
index 7b6159e..bc14107 100644
--- a/models/clip/common/src/args.rs
+++ b/models/clip/common/src/args.rs
@@ -4,6 +4,8 @@ use tensor::Tensor;
 pub struct Args<'a, H: Hardware> {
     /// shape: [n, c, h, w]
     pub raw: Tensor<&'a [H::Byte]>,
-    /// shape: [n, h x w]
+    /// shape: [h x w]
    pub pos: Tensor<&'a [H::Byte]>,
+    /// shape: [h x w, resampler.d]
+    pub pos_resampler: Tensor<&'a [H::Byte]>,
 }
diff --git a/models/clip/common/src/compute.rs b/models/clip/common/src/compute.rs
index ad38c0f..8031d2a 100644
--- a/models/clip/common/src/compute.rs
+++ b/models/clip/common/src/compute.rs
@@ -145,7 +145,11 @@ where
        QA: QueueAlloc,
    {
        let time = Instant::now();
-        let Args { raw, pos } = args;
+        let Args {
+            raw,
+            pos,
+            pos_resampler,
+        } = args;
        let ClipMeta {
            dt,
            dt_norm,
@@ -176,14 +180,16 @@ where
        let mut embd = Tensor::new(embd_.dt(), embd_.shape()).map(|s| queue_alloc.alloc(s));
        self.rearrange(&mut embd, &embd_, workspace, queue_alloc)?;
 
+        let &[batch, size, _] = embd.shape() else {
+            unreachable!()
+        };
+
        {
            let pos_embd = self.weights.pos_embd(queue);
+            let pos = pos.broadcast(0, batch);
            self.add_rows(&mut embd, &pos_embd, &pos, workspace, queue_alloc)?
        }
 
-        let &[batch, size, _] = embd.shape() else {
-            unreachable!()
-        };
        let batch_split = vec![size; batch];
 
        let np = batch * size;
@@ -281,6 +287,7 @@ where
 
        let d0 = self.meta.d;
        let w = self.meta.mat(d, d0).map(|_| weights.resampler_wkv(queue));
+        // (np d0) <- (np d) · (d d0)
        self.mat_mul(&mut v, &x, (w, None), workspace, queue_alloc)?;
 
        let [w, b] = weights.resampler_ln_q(queue);
@@ -292,9 +299,12 @@ where
        let inplace = unsafe { v.map_slice_static() };
        self.layer_norm(&mut v, &inplace, ln_v, workspace, queue_alloc)?;
 
-        let (buf, workspace) = workspace.split_at_mut(*kv.get());
-        let pos_embd = Tensor::new(dt, v.shape()).map(|_| buf);
-        self.add(&mut k, &v, &pos_embd, workspace, queue_alloc)?;
+        {
+            let mut k = k.map_slice_mut().tile(0, &[batch, size]);
+            let v = v.map_slice().tile(0, &[batch, size]);
+            let pos = pos_resampler.broadcast(0, batch);
+            self.add(&mut k, &v, &pos, workspace, queue_alloc)?
+        }
 
        let attn_w = self.meta.mat(d, d);
        let attn_b = self.meta.mat(d, 1);