diff --git a/transformer-cpu/src/kernel.rs b/transformer-cpu/src/kernel.rs
deleted file mode 100644
index bea5ac13..00000000
--- a/transformer-cpu/src/kernel.rs
+++ /dev/null
@@ -1,334 +0,0 @@
-use common::utok;
-use gemm::{f16, gemm};
-use std::{
-    iter::zip,
-    ops::{Deref, DerefMut, Mul, MulAssign},
-};
-use tensor::{
-    expand_indices, idim, idx_strides, reslice, reslice_mut, udim, DVector, DataType, Tensor,
-};
-
-macro_rules! slice {
-    ($blob:expr; $width:expr; [$line:expr]) => {
-        $blob[$line as usize * $width..][..$width]
-    };
-}
-
-pub(super) fn gather<T, U>(x: &mut Tensor<T>, table: &Tensor<U>, tokens: &[utok])
-where
-    T: DerefMut<Target = [u8]>,
-    U: Deref<Target = [u8]>,
-{
-    debug_assert_eq!(x.data_type(), table.data_type());
-    debug_assert_eq!(x.shape().last(), table.shape().last());
-
-    let x = x.as_mut_slice();
-    let table = table.as_slice();
-    debug_assert_eq!(x.len() % tokens.len(), 0);
-
-    let d = x.len() / tokens.len();
-    for (i, &t) in tokens.iter().enumerate() {
-        slice!(x; d; [i]).copy_from_slice(&slice!(table; d; [t]))
-    }
-}
-
-pub(super) fn rms_norm<T, U, V>(o: &mut Tensor<T>, x: &Tensor<U>, w: &Tensor<V>, epsilon: f32)
-where
-    T: DerefMut<Target = [u8]>,
-    U: Deref<Target = [u8]>,
-    V: Deref<Target = [u8]>,
-{
-    let dt = o.data_type();
-    debug_assert_eq!(x.data_type(), dt);
-    debug_assert_eq!(w.data_type(), dt);
-    debug_assert_eq!(o.shape(), x.shape());
-    debug_assert_eq!(&[*o.shape().last().unwrap()], w.shape());
-
-    let o = o.as_mut_slice();
-    let x = x.as_slice();
-    let w = w.as_slice();
-
-    match dt {
-        DataType::F16 => rms_norm_op(o, x, w, |x| {
-            f16::from_f32(rms_norm_reduce(x.iter().copied().map(f16::to_f32), epsilon))
-        }),
-        DataType::F32 => rms_norm_op(o, x, w, |x| rms_norm_reduce(x.iter().copied(), epsilon)),
-        _ => unreachable!("unsupported data type \"{dt:?}\""),
-    }
-}
-
-pub(super) fn rms_norm_inplace<T, U>(o: &mut Tensor<T>, w: &Tensor<U>, epsilon: f32)
-where
-    T: DerefMut<Target = [u8]>,
-    U: Deref<Target = [u8]>,
-{
-    let dt = o.data_type();
-    debug_assert_eq!(w.data_type(), dt);
-    debug_assert_eq!(&[*o.shape().last().unwrap()], w.shape());
-
-    let o: &mut [u8] = o.as_mut_slice();
-    let x = unsafe { std::slice::from_raw_parts(o.as_ptr(), o.len()) };
-    let w = w.as_slice();
-
-    match dt {
-        DataType::F16 => rms_norm_op(o, x, w, |x| {
-            f16::from_f32(rms_norm_reduce(x.iter().copied().map(f16::to_f32), epsilon))
-        }),
-        DataType::F32 => rms_norm_op(o, x, w, |x| rms_norm_reduce(x.iter().copied(), epsilon)),
-        _ => unreachable!("unsupported data type \"{dt:?}\""),
-    }
-}
-
-fn rms_norm_op<T: Copy + Mul<Output = T> + MulAssign>(
-    o: &mut [u8],
-    x: &[u8],
-    w: &[u8],
-    reduce: impl Fn(&[T]) -> T,
-) {
-    let o: &mut [T] = reslice_mut(o);
-    let x: &[T] = reslice(x);
-    let w: &[T] = reslice(w);
-    let d = w.len();
-
-    for i in 0..x.len() / w.len() {
-        let o = &mut slice!(o; d; [i]);
-        let x = &slice!(x; d; [i]);
-        let k = reduce(x);
-        zip(o, zip(x, w)).for_each(|(o, (x, w))| *o = *w * (k * *x));
-    }
-}
-
-#[inline]
-fn rms_norm_reduce(x: impl Iterator<Item = f32>, epsilon: f32) -> f32 {
-    // (Σx^2 / n + δ)^(-1/2)
-    let mut len = 0usize;
-    let mut sum = 0.0f32;
-    for x in x {
-        len += 1;
-        sum += x * x;
-    }
-    (sum / (len as f32) + epsilon).sqrt().recip()
-}
-
-/// c = a x b
-///
-/// - c: [N0, N1, ... , N_, m, n]
-/// - a: [N0, N1, ... , N_, m, k]
-/// - b: [N0, N1, ... , N_, k, n]
-pub(super) fn mat_mul<T, U, V>(
-    c: &mut Tensor<T>,
-    beta: f32,
-    a: &Tensor<U>,
-    b: &Tensor<V>,
-    alpha: f32,
-) where
-    T: DerefMut<Target = [u8]>,
-    U: Deref<Target = [u8]>,
-    V: Deref<Target = [u8]>,
-{
-    let dt = c.data_type();
-    assert_eq!(a.data_type(), dt);
-    assert_eq!(b.data_type(), dt);
-
-    let rank = c.shape().len();
-    assert_eq!(a.shape().len(), rank);
-    assert_eq!(b.shape().len(), rank);
-
-    let (batch, mn) = c.shape().split_at(rank - 2);
-    let (a_batch, mk) = a.shape().split_at(rank - 2);
-    let (b_batch, kn) = b.shape().split_at(rank - 2);
-    assert_eq!(a_batch, batch);
-    assert_eq!(b_batch, batch);
-
-    let m = mn[0];
-    let n = mn[1];
-    let k = mk[1];
-    assert_eq!(mk, &[m, k]);
-    assert_eq!(kn, &[k, n]);
-
-    let dst_strides = &c.strides()[rank - 2..];
-    let dst_cs = dst_strides[1] as isize;
-    let dst_rs = dst_strides[0] as isize;
-
-    let lhs_strides = &a.strides()[rank - 2..];
-    let lhs_cs = lhs_strides[1] as isize;
-    let lhs_rs = lhs_strides[0] as isize;
-
-    let rhs_strides = &b.strides()[rank - 2..];
-    let rhs_cs = rhs_strides[1] as isize;
-    let rhs_rs = rhs_strides[0] as isize;
-
-    let (batch, idx_strides) = idx_strides(batch);
-    for i in 0..batch {
-        let indices = expand_indices(i, &idx_strides, &[0, 0, 1]);
-        let indices = indices.as_view();
-        let dst = c.locate_mut(&indices).unwrap();
-        let lhs = a.locate(&indices).unwrap();
-        let rhs = b.locate(&indices).unwrap();
-        match dt {
-            DataType::F32 => unsafe {
-                gemm(
-                    m as _,
-                    n as _,
-                    k as _,
-                    dst.cast::<f32>(),
-                    dst_cs,
-                    dst_rs,
-                    beta != 0.,
-                    lhs.cast::<f32>(),
-                    lhs_cs,
-                    lhs_rs,
-                    rhs.cast::<f32>(),
-                    rhs_cs,
-                    rhs_rs,
-                    beta,
-                    alpha,
-                    false,
-                    false,
-                    false,
-                    gemm::Parallelism::Rayon(0),
-                )
-            },
-            DataType::F16 => unsafe {
-                gemm(
-                    m as _,
-                    n as _,
-                    k as _,
-                    dst.cast::<f16>(),
-                    dst_cs,
-                    dst_rs,
-                    beta != 0.,
-                    lhs.cast::<f16>(),
-                    lhs_cs,
-                    lhs_rs,
-                    rhs.cast::<f16>(),
-                    rhs_cs,
-                    rhs_rs,
-                    f16::from_f32(beta),
-                    f16::from_f32(alpha),
-                    false,
-                    false,
-                    false,
-                    gemm::Parallelism::Rayon(0),
-                )
-            },
-            _ => unreachable!(),
-        }
-    }
-}
-
-/// - t: [N0, N1, ... , N_, num_head, head_dim]
-/// - pos: [N0, N1, ... , N_]
-pub(super) fn rotary_embedding<T, U>(t: &mut Tensor<T>, pos: &Tensor<U>, theta: f32)
-where
-    T: DerefMut<Target = [u8]>,
-    U: Deref<Target = [u8]>,
-{
-    assert!(t.contiguous_len() >= 2);
-    let [batch @ .., nh, dh] = t.shape() else {
-        panic!()
-    };
-    assert_eq!(pos.shape(), batch);
-    let nh = *nh as usize;
-    let dh = *dh as usize / 2;
-
-    let (n, idx_strides) = idx_strides(batch);
-    for i in 0..n {
-        let pos = pos
-            .locate(&expand_indices(i, &idx_strides, &[1]).as_view())
-            .unwrap()
-            .cast::<udim>();
-        let pos = unsafe { *pos } as f32;
-        let ptr = t
-            .locate_mut(&expand_indices(i, &idx_strides, &[0, 0, 1]).as_view())
-            .unwrap()
-            .cast::<(f16, f16)>();
-        let slice = unsafe { std::slice::from_raw_parts_mut(ptr, nh * dh) };
-        for j in 0..nh {
-            for (k, slice) in slice!(slice; dh ; [j]).iter_mut().enumerate() {
-                let freq = pos / theta.powf(k as f32 / dh as f32);
-                let (sin, cos) = freq.sin_cos();
-                let (a, b) = slice;
-                let a_ = a.to_f32();
-                let b_ = b.to_f32();
-                *a = f16::from_f32(a_ * cos - b_ * sin);
-                *b = f16::from_f32(a_ * sin + b_ * cos);
-            }
-        }
-    }
-}
-
-/// - x: [N0, N1, ... , N_, seq_len, att_len]
-pub(super) fn softmax<T>(x: &mut Tensor<T>)
-where
-    T: DerefMut<Target = [u8]>,
-{
-    assert!(x.contiguous_len() >= 2);
-    let (batch, dim) = x.shape().split_at(x.shape().len() - 2);
-    let seq_len = dim[0] as usize;
-    let att_len = dim[1] as usize;
-
-    let (n, idx_strides) = idx_strides(batch);
-    for i in 0..n {
-        let ptr = x
-            .locate_mut(&expand_indices(i, &idx_strides, &[0, 0, 1]).as_view())
-            .unwrap()
-            .cast::<f16>();
-        let slice = unsafe { std::slice::from_raw_parts_mut(ptr, seq_len * att_len) };
-        for r in 0..seq_len {
-            let slice = &mut slice!(slice; att_len; [r]);
-            let (att, tail) = slice.split_at_mut(att_len - seq_len + r + 1);
-
-            let max = att
-                .iter()
-                .max_by(|a, b| a.partial_cmp(b).unwrap())
-                .unwrap()
-                .to_f32();
-            let sum = att
-                .iter_mut()
-                .map(|x| {
-                    let exp = (x.to_f32() - max).exp();
-                    *x = f16::from_f32(exp);
-                    exp
-                })
-                .sum::<f32>();
-            let sum = f16::from_f32(sum);
-            att.iter_mut().for_each(|x| *x /= sum);
-
-            tail.fill(f16::ZERO);
-        }
-    }
-}
-
-pub(super) fn swiglu<T, U>(gate: &mut Tensor<T>, up: &Tensor<U>)
-where
-    T: DerefMut<Target = [u8]>,
-    U: Deref<Target = [u8]>,
-{
-    let &[seq_len, di] = gate.shape() else {
-        panic!("gate shape: {:?}", gate.shape());
-    };
-    assert_eq!(gate.data_type(), up.data_type());
-    assert_eq!(up.shape(), &[seq_len, di]);
-    assert!(gate.contiguous_len() >= 1);
-    assert!(up.contiguous_len() >= 1);
-
-    for i in 0..seq_len {
-        let indices = DVector::from_vec(vec![i as idim, 0, 1]);
-        let gate = gate.locate_mut(&indices.as_view()).unwrap();
-        let gate = unsafe { std::slice::from_raw_parts_mut(gate.cast::<f16>(), di as usize) };
-        let up = up.locate(&indices.as_view()).unwrap();
-        let up = unsafe { std::slice::from_raw_parts(up.cast::<f16>(), di as usize) };
-        for (gate, up) in gate.iter_mut().zip(up) {
-            let x = gate.to_f32();
-            let y = up.to_f32();
-
-            #[inline(always)]
-            fn sigmoid(x: f32) -> f32 {
-                1. / (1. + (-x).exp())
-            }
-
-            *gate = f16::from_f32(x * sigmoid(x) * y);
-        }
-    }
-}
diff --git a/transformer-cpu/src/kernel/fused_softmax.rs b/transformer-cpu/src/kernel/fused_softmax.rs
new file mode 100644
index 00000000..3b317f98
--- /dev/null
+++ b/transformer-cpu/src/kernel/fused_softmax.rs
@@ -0,0 +1,46 @@
+use super::slice;
+use gemm::f16;
+use std::ops::DerefMut;
+use tensor::{expand_indices, idx_strides, Tensor};
+
+/// - x: [N0, N1, ... , N_, seq_len, att_len]
+pub fn softmax<T>(x: &mut Tensor<T>)
+where
+    T: DerefMut<Target = [u8]>,
+{
+    assert!(x.contiguous_len() >= 2);
+    let (batch, dim) = x.shape().split_at(x.shape().len() - 2);
+    let seq_len = dim[0] as usize;
+    let att_len = dim[1] as usize;
+
+    let (n, idx_strides) = idx_strides(batch);
+    for i in 0..n {
+        let ptr = x
+            .locate_mut(&expand_indices(i, &idx_strides, &[0, 0, 1]).as_view())
+            .unwrap()
+            .cast::<f16>();
+        let slice = unsafe { std::slice::from_raw_parts_mut(ptr, seq_len * att_len) };
+        for r in 0..seq_len {
+            let slice = &mut slice!(slice; att_len; [r]);
+            let (att, tail) = slice.split_at_mut(att_len - seq_len + r + 1);
+
+            let max = att
+                .iter()
+                .max_by(|a, b| a.partial_cmp(b).unwrap())
+                .unwrap()
+                .to_f32();
+            let sum = att
+                .iter_mut()
+                .map(|x| {
+                    let exp = (x.to_f32() - max).exp();
+                    *x = f16::from_f32(exp);
+                    exp
+                })
+                .sum::<f32>();
+            let sum = f16::from_f32(sum);
+            att.iter_mut().for_each(|x| *x /= sum);
+
+            tail.fill(f16::ZERO);
+        }
+    }
+}
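Aside (not part of the patch): in the fused softmax above, row r of the [seq_len, att_len] score matrix keeps only its causal prefix of att_len - seq_len + r + 1 scores and zeroes the rest, so masking and normalization happen in one pass. A minimal f32 reference of the same row-wise computation, for comparison:

// Reference causal softmax over one [seq_len, att_len] score matrix.
// Mirrors the kernel: row r normalizes att_len - seq_len + r + 1 scores.
fn softmax_ref(scores: &mut [f32], seq_len: usize, att_len: usize) {
    for r in 0..seq_len {
        let row = &mut scores[r * att_len..][..att_len];
        let (att, tail) = row.split_at_mut(att_len - seq_len + r + 1);
        let max = att.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0;
        for x in att.iter_mut() {
            *x = (*x - max).exp();
            sum += *x;
        }
        att.iter_mut().for_each(|x| *x /= sum);
        tail.fill(0.0); // future positions stay masked
    }
}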
diff --git a/transformer-cpu/src/kernel/gather.rs b/transformer-cpu/src/kernel/gather.rs
new file mode 100644
index 00000000..6a221b64
--- /dev/null
+++ b/transformer-cpu/src/kernel/gather.rs
@@ -0,0 +1,29 @@
+use super::slice;
+use common::utok;
+use std::ops::{Deref, DerefMut};
+use tensor::Tensor;
+
+pub fn gather<T, U>(x: &mut Tensor<T>, table: &Tensor<U>, tokens: &[&[utok]])
+where
+    T: DerefMut<Target = [u8]>,
+    U: Deref<Target = [u8]>,
+{
+    let &[num_token, d] = x.shape() else { panic!() };
+
+    debug_assert_eq!(x.data_type(), table.data_type());
+    debug_assert_eq!(table.shape().len(), 2);
+    debug_assert_eq!(table.shape()[1], d);
+    debug_assert_eq!(
+        tokens.iter().map(|s| s.len()).sum::<usize>(),
+        num_token as usize
+    );
+    debug_assert!(x.is_contiguous());
+    debug_assert!(table.is_contiguous());
+    let d = d as usize * x.data_type().size();
+
+    let x = x.as_mut_slice();
+    let table = table.as_slice();
+    for (i, &t) in tokens.iter().flat_map(|s| s.iter()).enumerate() {
+        slice!(x; d; [i]).copy_from_slice(&slice!(table; d; [t]))
+    }
+}
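Aside (not part of the patch): the new gather signature takes one token slice per sequence, and flat_map concatenates them, so embeddings land row by row in one contiguous [num_token, d] buffer. The same indexing with plain slices standing in for Tensor, as a sketch:

// Rows of `table` ([vocab, d]) are copied into `x` ([num_token, d])
// in flattened batch order, exactly as gather() does byte-wise.
fn gather_ref(x: &mut [f32], table: &[f32], d: usize, tokens: &[&[u32]]) {
    for (i, &t) in tokens.iter().flat_map(|s| s.iter()).enumerate() {
        x[i * d..][..d].copy_from_slice(&table[t as usize * d..][..d]);
    }
}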
diff --git a/transformer-cpu/src/kernel/mat_mul.rs b/transformer-cpu/src/kernel/mat_mul.rs
new file mode 100644
index 00000000..8b74ad81
--- /dev/null
+++ b/transformer-cpu/src/kernel/mat_mul.rs
@@ -0,0 +1,105 @@
+use gemm::{f16, gemm};
+use std::ops::{Deref, DerefMut};
+use tensor::{expand_indices, idx_strides, DataType, Tensor};
+
+/// c = a x b
+///
+/// - c: [N0, N1, ... , N_, m, n]
+/// - a: [N0, N1, ... , N_, m, k]
+/// - b: [N0, N1, ... , N_, k, n]
+pub fn mat_mul<T, U, V>(c: &mut Tensor<T>, beta: f32, a: &Tensor<U>, b: &Tensor<V>, alpha: f32)
+where
+    T: DerefMut<Target = [u8]>,
+    U: Deref<Target = [u8]>,
+    V: Deref<Target = [u8]>,
+{
+    let dt = c.data_type();
+    assert_eq!(a.data_type(), dt);
+    assert_eq!(b.data_type(), dt);
+
+    let rank = c.shape().len();
+    assert_eq!(a.shape().len(), rank);
+    assert_eq!(b.shape().len(), rank);
+
+    let (batch, mn) = c.shape().split_at(rank - 2);
+    let (a_batch, mk) = a.shape().split_at(rank - 2);
+    let (b_batch, kn) = b.shape().split_at(rank - 2);
+    assert_eq!(a_batch, batch);
+    assert_eq!(b_batch, batch);
+
+    let m = mn[0];
+    let n = mn[1];
+    let k = mk[1];
+    assert_eq!(mk, &[m, k]);
+    assert_eq!(kn, &[k, n]);
+
+    let dst_strides = &c.strides()[rank - 2..];
+    let dst_cs = dst_strides[1] as isize;
+    let dst_rs = dst_strides[0] as isize;
+
+    let lhs_strides = &a.strides()[rank - 2..];
+    let lhs_cs = lhs_strides[1] as isize;
+    let lhs_rs = lhs_strides[0] as isize;
+
+    let rhs_strides = &b.strides()[rank - 2..];
+    let rhs_cs = rhs_strides[1] as isize;
+    let rhs_rs = rhs_strides[0] as isize;
+
+    let (batch, idx_strides) = idx_strides(batch);
+    for i in 0..batch {
+        let indices = expand_indices(i, &idx_strides, &[0, 0, 1]);
+        let indices = indices.as_view();
+        let dst = c.locate_mut(&indices).unwrap();
+        let lhs = a.locate(&indices).unwrap();
+        let rhs = b.locate(&indices).unwrap();
+        match dt {
+            DataType::F32 => unsafe {
+                gemm(
+                    m as _,
+                    n as _,
+                    k as _,
+                    dst.cast::<f32>(),
+                    dst_cs,
+                    dst_rs,
+                    beta != 0.,
+                    lhs.cast::<f32>(),
+                    lhs_cs,
+                    lhs_rs,
+                    rhs.cast::<f32>(),
+                    rhs_cs,
+                    rhs_rs,
+                    beta,
+                    alpha,
+                    false,
+                    false,
+                    false,
+                    gemm::Parallelism::Rayon(0),
+                )
+            },
+            DataType::F16 => unsafe {
+                gemm(
+                    m as _,
+                    n as _,
+                    k as _,
+                    dst.cast::<f16>(),
+                    dst_cs,
+                    dst_rs,
+                    beta != 0.,
+                    lhs.cast::<f16>(),
+                    lhs_cs,
+                    lhs_rs,
+                    rhs.cast::<f16>(),
+                    rhs_cs,
+                    rhs_rs,
+                    f16::from_f32(beta),
+                    f16::from_f32(alpha),
+                    false,
+                    false,
+                    false,
+                    gemm::Parallelism::Rayon(0),
+                )
+            },
+            _ => unreachable!(),
+        }
+    }
+}
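Aside (not part of the patch): mat_mul walks the leading batch dimensions via idx_strides/expand_indices and issues one gemm call per matrix pair; passing row and column strides separately is what lets it operate on transposed or otherwise non-contiguous views without copying. A self-contained f32 call with the same argument order, assuming (as the call sites above imply) that the gemm crate's first scalar scales the existing dst and the second scales lhs·rhs:

use gemm::{gemm, Parallelism};

fn demo() {
    let (m, n, k) = (2usize, 2usize, 3usize);
    let a: Vec<f32> = vec![1., 2., 3., 4., 5., 6.]; // 2x3, row-major
    let b: Vec<f32> = vec![7., 8., 9., 10., 11., 12.]; // 3x2, row-major
    let mut c = vec![0.0f32; m * n]; // 2x2
    unsafe {
        gemm(
            m, n, k,
            c.as_mut_ptr(), 1, n as isize, // dst: cs = 1, rs = n (row-major)
            false,                         // don't read dst (the beta == 0 case)
            a.as_ptr(), 1, k as isize,
            b.as_ptr(), 1, n as isize,
            0.0, 1.0,                      // c = 0 * c + 1 * (a x b)
            false, false, false,
            Parallelism::None,
        );
    }
    assert_eq!(c, [58., 64., 139., 154.]);
}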
diff --git a/transformer-cpu/src/kernel/mod.rs b/transformer-cpu/src/kernel/mod.rs
new file mode 100644
index 00000000..8cf7a7dc
--- /dev/null
+++ b/transformer-cpu/src/kernel/mod.rs
@@ -0,0 +1,21 @@
+mod fused_softmax;
+mod gather;
+mod mat_mul;
+mod rms_norm;
+mod rotary_embedding;
+mod swiglu;
+
+pub(super) use fused_softmax::softmax;
+pub(super) use gather::gather;
+pub(super) use mat_mul::mat_mul;
+pub(super) use rms_norm::{rms_norm, rms_norm_inplace};
+pub(super) use rotary_embedding::rotary_embedding;
+pub(super) use swiglu::swiglu;
+
+macro_rules! slice {
+    ($blob:expr; $width:expr; [$line:expr]) => {
+        $blob[$line as usize * $width..][..$width]
+    };
+}
+
+use slice;
diff --git a/transformer-cpu/src/kernel/rms_norm.rs b/transformer-cpu/src/kernel/rms_norm.rs
new file mode 100644
index 00000000..17d8e543
--- /dev/null
+++ b/transformer-cpu/src/kernel/rms_norm.rs
@@ -0,0 +1,85 @@
+use super::slice;
+use gemm::f16;
+use std::{
+    iter::zip,
+    ops::{Deref, DerefMut, Mul, MulAssign},
+};
+use tensor::{reslice, reslice_mut, DataType, Tensor};
+
+pub fn rms_norm<T, U, V>(o: &mut Tensor<T>, x: &Tensor<U>, w: &Tensor<V>, epsilon: f32)
+where
+    T: DerefMut<Target = [u8]>,
+    U: Deref<Target = [u8]>,
+    V: Deref<Target = [u8]>,
+{
+    let dt = o.data_type();
+    debug_assert_eq!(x.data_type(), dt);
+    debug_assert_eq!(w.data_type(), dt);
+    debug_assert_eq!(o.shape(), x.shape());
+    debug_assert_eq!(&[*o.shape().last().unwrap()], w.shape());
+
+    let o = o.as_mut_slice();
+    let x = x.as_slice();
+    let w = w.as_slice();
+
+    match dt {
+        DataType::F16 => rms_norm_op(o, x, w, |x| {
+            f16::from_f32(rms_norm_reduce(x.iter().copied().map(f16::to_f32), epsilon))
+        }),
+        DataType::F32 => rms_norm_op(o, x, w, |x| rms_norm_reduce(x.iter().copied(), epsilon)),
+        _ => unreachable!("unsupported data type \"{dt:?}\""),
+    }
+}
+
+pub fn rms_norm_inplace<T, U>(o: &mut Tensor<T>, w: &Tensor<U>, epsilon: f32)
+where
+    T: DerefMut<Target = [u8]>,
+    U: Deref<Target = [u8]>,
+{
+    let dt = o.data_type();
+    debug_assert_eq!(w.data_type(), dt);
+    debug_assert_eq!(&[*o.shape().last().unwrap()], w.shape());
+
+    let o: &mut [u8] = o.as_mut_slice();
+    let x = unsafe { std::slice::from_raw_parts(o.as_ptr(), o.len()) };
+    let w = w.as_slice();
+
+    match dt {
+        DataType::F16 => rms_norm_op(o, x, w, |x| {
+            f16::from_f32(rms_norm_reduce(x.iter().copied().map(f16::to_f32), epsilon))
+        }),
+        DataType::F32 => rms_norm_op(o, x, w, |x| rms_norm_reduce(x.iter().copied(), epsilon)),
+        _ => unreachable!("unsupported data type \"{dt:?}\""),
+    }
+}
+
+fn rms_norm_op<T: Copy + Mul<Output = T> + MulAssign>(
+    o: &mut [u8],
+    x: &[u8],
+    w: &[u8],
+    reduce: impl Fn(&[T]) -> T,
+) {
+    let o: &mut [T] = reslice_mut(o);
+    let x: &[T] = reslice(x);
+    let w: &[T] = reslice(w);
+    let d = w.len();
+
+    for i in 0..x.len() / w.len() {
+        let o = &mut slice!(o; d; [i]);
+        let x = &slice!(x; d; [i]);
+        let k = reduce(x);
+        zip(o, zip(x, w)).for_each(|(o, (x, w))| *o = *w * (k * *x));
+    }
+}
+
+#[inline]
+fn rms_norm_reduce(x: impl Iterator<Item = f32>, epsilon: f32) -> f32 {
+    // (Σx^2 / n + δ)^(-1/2)
+    let mut len = 0usize;
+    let mut sum = 0.0f32;
+    for x in x {
+        len += 1;
+        sum += x * x;
+    }
+    (sum / (len as f32) + epsilon).sqrt().recip()
+}
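Aside (not part of the patch): rms_norm_reduce computes the scale k = (Σx²/n + ε)^(-1/2) once per row, and every element then becomes w·(k·x). The same row transform in plain f32, as a sketch:

// One-row RMSNorm: o[i] = w[i] * x[i] / sqrt(mean(x^2) + eps).
fn rms_norm_row(o: &mut [f32], x: &[f32], w: &[f32], eps: f32) {
    let k = (x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32 + eps)
        .sqrt()
        .recip();
    for ((o, &x), &w) in o.iter_mut().zip(x).zip(w) {
        *o = w * (k * x);
    }
}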
diff --git a/transformer-cpu/src/kernel/rotary_embedding.rs b/transformer-cpu/src/kernel/rotary_embedding.rs
new file mode 100644
index 00000000..b4133896
--- /dev/null
+++ b/transformer-cpu/src/kernel/rotary_embedding.rs
@@ -0,0 +1,45 @@
+use super::slice;
+use gemm::f16;
+use std::ops::{Deref, DerefMut};
+use tensor::{expand_indices, idx_strides, udim, Tensor};
+
+/// - t: [N0, N1, ... , N_, num_head, head_dim]
+/// - pos: [N0, N1, ... , N_]
+pub fn rotary_embedding<T, U>(t: &mut Tensor<T>, pos: &Tensor<U>, theta: f32)
+where
+    T: DerefMut<Target = [u8]>,
+    U: Deref<Target = [u8]>,
+{
+    assert!(t.contiguous_len() >= 2);
+    let [batch @ .., nh, dh] = t.shape() else {
+        panic!()
+    };
+    assert_eq!(pos.shape(), batch);
+    let nh = *nh as usize;
+    let dh = *dh as usize / 2;
+
+    let (n, idx_strides) = idx_strides(batch);
+    for i in 0..n {
+        let pos = pos
+            .locate(&expand_indices(i, &idx_strides, &[1]).as_view())
+            .unwrap()
+            .cast::<udim>();
+        let pos = unsafe { *pos } as f32;
+        let ptr = t
+            .locate_mut(&expand_indices(i, &idx_strides, &[0, 0, 1]).as_view())
+            .unwrap()
+            .cast::<(f16, f16)>();
+        let slice = unsafe { std::slice::from_raw_parts_mut(ptr, nh * dh) };
+        for j in 0..nh {
+            for (k, slice) in slice!(slice; dh ; [j]).iter_mut().enumerate() {
+                let freq = pos / theta.powf(k as f32 / dh as f32);
+                let (sin, cos) = freq.sin_cos();
+                let (a, b) = slice;
+                let a_ = a.to_f32();
+                let b_ = b.to_f32();
+                *a = f16::from_f32(a_ * cos - b_ * sin);
+                *b = f16::from_f32(a_ * sin + b_ * cos);
+            }
+        }
+    }
+}
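Aside (not part of the patch): the kernel reinterprets each head as adjacent (f16, f16) pairs and rotates pair k by the angle pos / theta^(k / (head_dim/2)); note the code halves dh up front, so its `dh` is already the pair count. The rotation on one head in plain f32, as a sketch:

// Rotates one head of (a, b) pairs in place; pairs.len() is head_dim / 2.
fn rope_head(pairs: &mut [(f32, f32)], pos: f32, theta: f32) {
    let dh = pairs.len();
    for (k, p) in pairs.iter_mut().enumerate() {
        let freq = pos / theta.powf(k as f32 / dh as f32);
        let (sin, cos) = freq.sin_cos();
        let (a, b) = *p;
        *p = (a * cos - b * sin, a * sin + b * cos);
    }
}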
diff --git a/transformer-cpu/src/kernel/swiglu.rs b/transformer-cpu/src/kernel/swiglu.rs
new file mode 100644
index 00000000..d27669e0
--- /dev/null
+++ b/transformer-cpu/src/kernel/swiglu.rs
@@ -0,0 +1,36 @@
+use gemm::f16;
+use std::ops::{Deref, DerefMut};
+use tensor::{idim, DVector, Tensor};
+
+pub fn swiglu<T, U>(gate: &mut Tensor<T>, up: &Tensor<U>)
+where
+    T: DerefMut<Target = [u8]>,
+    U: Deref<Target = [u8]>,
+{
+    let &[seq_len, di] = gate.shape() else {
+        panic!("gate shape: {:?}", gate.shape());
+    };
+    assert_eq!(gate.data_type(), up.data_type());
+    assert_eq!(up.shape(), &[seq_len, di]);
+    assert!(gate.contiguous_len() >= 1);
+    assert!(up.contiguous_len() >= 1);
+
+    for i in 0..seq_len {
+        let indices = DVector::from_vec(vec![i as idim, 0, 1]);
+        let gate = gate.locate_mut(&indices.as_view()).unwrap();
+        let gate = unsafe { std::slice::from_raw_parts_mut(gate.cast::<f16>(), di as usize) };
+        let up = up.locate(&indices.as_view()).unwrap();
+        let up = unsafe { std::slice::from_raw_parts(up.cast::<f16>(), di as usize) };
+        for (gate, up) in gate.iter_mut().zip(up) {
+            let x = gate.to_f32();
+            let y = up.to_f32();
+
+            #[inline(always)]
+            fn sigmoid(x: f32) -> f32 {
+                1. / (1. + (-x).exp())
+            }
+
+            *gate = f16::from_f32(x * sigmoid(x) * y);
+        }
+    }
+}
diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs
index d6d6a091..392b6fb0 100644
--- a/transformer-cpu/src/lib.rs
+++ b/transformer-cpu/src/lib.rs
@@ -39,8 +39,13 @@ impl Transformer {
         self.model.max_position_embeddings()
     }
 
-    pub fn update(&self, tokens: &[utok], cache: &mut [LayerCache], pos: upos) -> Tensor<Storage> {
-        let seq_len = tokens.len() as udim;
+    pub fn update(
+        &self,
+        tokens: &[&[utok]],
+        cache: &mut [LayerCache],
+        pos: upos,
+    ) -> Tensor<Storage> {
+        let seq_len = tokens.iter().map(|s| s.len()).sum::<usize>() as udim;
         let d = self.model.hidden_size() as udim;
         let nh = self.model.num_attention_heads() as udim;
         let nkvh = self.model.num_key_value_heads() as udim;
@@ -178,7 +183,7 @@ impl Transformer {
     }
 
     pub fn forward(&mut self, token: utok, cache: &mut [LayerCache], pos: upos) -> &[f32] {
-        let mut x = self.update(&[token], cache, pos);
+        let mut x = self.update(&[&[token]], cache, pos);
 
         let model_norm = self.model.model_norm();
         rms_norm_inplace(&mut x.access_mut(), &model_norm, self.model.rms_norm_eps());
diff --git a/xtask/src/generate.rs b/xtask/src/generate.rs
index f8636eef..2c14a0e0 100644
--- a/xtask/src/generate.rs
+++ b/xtask/src/generate.rs
@@ -136,7 +136,7 @@ fn on_host(
     let time = Instant::now();
     let (last, tokens) = prompt_tokens.split_last().expect("prompt is empty");
     if !tokens.is_empty() {
-        transformer.update(tokens, &mut kv_cache, 0);
+        transformer.update(&[tokens], &mut kv_cache, 0);
     }
     info!("prefill transformer ... {:?}", time.elapsed());
 
diff --git a/xtask/src/service/cpu.rs b/xtask/src/service/cpu.rs
index b3f21d46..9750e764 100644
--- a/xtask/src/service/cpu.rs
+++ b/xtask/src/service/cpu.rs
@@ -52,7 +52,7 @@ pub(super) fn run(
             });
 
             if !tokens.is_empty() {
-                transformer.update(tokens, &mut session.kv_cache, session.pos as _);
+                transformer.update(&[tokens], &mut session.kv_cache, session.pos as _);
                 session.pos += tokens.len() as upos;
            }
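Aside (not part of the patch): all three call sites adapt by wrapping their single token slice, since update() now takes one slice per sequence and derives seq_len from the total across sequences. The counting rule the new signature encodes, as a tiny sketch:

// seq_len is summed over sequences, matching
// `tokens.iter().map(|s| s.len()).sum::<usize>()` in update().
fn total_seq_len(tokens: &[&[u32]]) -> usize {
    tokens.iter().map(|s| s.len()).sum()
}

// A single-token decode step becomes update(&[&[token]], ...):
// one sequence containing one token, so total_seq_len(&[&[42u32]]) == 1.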