From 261472529f2884ac19081a7d0726af9665689b42 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 27 Mar 2024 16:58:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=94=B9=E6=AD=A3=20tensor=20slice=20?= =?UTF-8?q?=E5=92=8C=20decode=20=E7=A7=BB=E5=8A=A8=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- tensor/src/slice.rs | 144 +++++++++++++++++++++++----------- transformer-cpu/src/lib.rs | 33 ++++---- transformer-nvidia/src/lib.rs | 36 ++++----- 3 files changed, 130 insertions(+), 83 deletions(-) diff --git a/tensor/src/slice.rs b/tensor/src/slice.rs index 83a9e36c..31380209 100644 --- a/tensor/src/slice.rs +++ b/tensor/src/slice.rs @@ -1,4 +1,5 @@ use crate::{idim, pattern::Pattern, udim, Affine, Shape, Tensor}; +use std::{cmp::Ordering, iter::zip}; impl Tensor { pub fn slice(self, dims: &[SliceDim]) -> Self { @@ -11,55 +12,13 @@ impl Tensor { } } -#[derive(Clone, Debug)] -pub struct SliceDim { - pub start: udim, - pub step: idim, - pub len: udim, -} - -#[macro_export] -macro_rules! slice { - [all] => { - $crate::SliceDim { - start: 0, - step: 1, - len: udim::MAX, - } - }; - [from $start:expr, take $len:expr] => { - $crate::SliceDim { - start: $start, - step: 1, - len: $len, - } - }; - [$start:expr; $step:expr; $len:expr] => { - $crate::SliceDim { - start: $start, - step: $step, - len: $len, - } - }; -} - fn build(meta: &[SliceDim], input: &[udim]) -> (Shape, Affine) { assert_eq!(input.len(), meta.len()); - assert!(meta - .iter() - .zip(input) - .all(|(d, &i)| { (0..i).contains(&d.start) })); - - let shape = meta - .iter() - .zip(input) - .map(|(d, &i)| { - let distance = if d.step > 0 { i - d.start } else { d.start }; - let step = d.step.unsigned_abs(); - ((distance + step - 1) / step).min(d.len) - }) - .collect::(); + let meta = zip(meta, input) + .map(|(d, &len)| d.normalize(len)) + .collect::>(); + let shape = meta.iter().map(|d| d.len).collect::(); let n = meta.len(); let affine = Affine::from_fn(n + 1, n + 1, |r, c| { if r == n { @@ -75,6 +34,7 @@ fn build(meta: &[SliceDim], input: &[udim]) -> (Shape, Affine) { #[test] fn test() { + use crate::slice; let (shape, affine) = build(&[slice![2;1;2], slice![0;1;4], slice![1;2;3]], &[5, 6, 7]); assert_eq!(shape.as_slice(), &[2, 4, 3]); assert_eq!( @@ -88,3 +48,95 @@ fn test() { ] ); } + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct SliceDim { + pub start: udim, + pub step: idim, + pub len: udim, +} + +impl SliceDim { + #[inline] + pub fn normalize(&self, len: udim) -> Self { + match self.step.cmp(&0) { + Ordering::Greater => { + assert!(self.start < len); + Self { + start: self.start, + step: self.step, + len: { + let step = self.step as udim; + ((len - self.start + step - 1) / step).min(self.len) + }, + } + } + Ordering::Equal => { + assert!(self.start < len); + Self { + start: self.start, + step: self.step, + len: self.len, + } + } + Ordering::Less => { + let start = self.start.min(len - 1); + Self { + start, + step: self.step, + len: { + let step = self.step.unsigned_abs(); + ((start + 1 + step - 1) / step).min(self.len) + }, + } + } + } + } +} + +#[macro_export] +macro_rules! slice { + [all] => { + slice![0; 1; usize::MAX] + }; + [rev] => { + slice![usize::MAX; -1; usize::MAX] + }; + [take $len:expr] => { + slice![0; 1; $len] + }; + [from $start:expr, until $end:expr] => { + slice![$start; 1; $end - $start] + }; + [from $start:expr, take $len:expr] => { + slice![$start; 1; $len] + }; + [from $start:expr, take $len:expr, per $step:expr] => { + slice![$start; $step; $len] + }; + [$start:expr; $step:expr; $len:expr] => { + $crate::SliceDim { + start: $start as _, + step : $step as _, + len : $len as _, + } + }; +} + +#[test] +fn test_macro() { + assert_eq!( + slice![5; -3; 2], + SliceDim { + start: 5, + step: -3, + len: 2, + } + ); + assert_eq!(slice![all], slice![0; 1; usize::MAX]); + assert_eq!(slice![rev], slice![usize::MAX; -1; usize::MAX]); + assert_eq!(slice![take 5], slice![0; 1; 5]); + assert_eq!(slice![from 3, until 5], slice![3; 1; 2]); + assert_eq!(slice![from 3, take 5], slice![3; 1; 5]); + assert_eq!(slice![from 3, take 5, per 2], slice![3; 2; 5]); +} diff --git a/transformer-cpu/src/lib.rs b/transformer-cpu/src/lib.rs index 2870eef1..37b2179b 100644 --- a/transformer-cpu/src/lib.rs +++ b/transformer-cpu/src/lib.rs @@ -275,28 +275,25 @@ impl Transformer { requests: &[Request], mut x0: Tensor, ) -> Tensor { - let dt = self.0.data_type(); - let d = self.0.hidden_size() as udim; + let buf = x0.as_mut_slice(); + let len = self.0.hidden_size() * self.0.data_type().size(); let (head, others) = requests.split_first().unwrap(); - let tokens = { - let begin = head.seq_len() as usize; - let mut i = begin; - let mut j = begin; - let buf = x0.as_mut_slice(); - let len = d as usize * dt.size(); - for r in others { - i += r.seq_len() as usize; - j += 1; - if r.decode() && i > j { - buf.copy_within((i - 1) * len..i * len, (j - 1) * len); + let begin = head.seq_len() as usize - 1; + + let mut src = begin; + let mut dst = begin; + for r in others { + src += r.seq_len() as usize; + if r.decode() { + dst += 1; + if dst < src { + buf.copy_within(src * len..(src + 1) * len, dst * len); } } - let begin = begin as udim - 1; - let len = j as udim - begin; - slice![from begin, take len] - }; - x0.slice(&[tokens, slice![all]]) + } + + x0.slice(&[slice![from begin, until dst + 1], slice![all]]) } fn logits(&self, mut x: Tensor) -> Tensor { diff --git a/transformer-nvidia/src/lib.rs b/transformer-nvidia/src/lib.rs index 95d1d894..af9e4153 100644 --- a/transformer-nvidia/src/lib.rs +++ b/transformer-nvidia/src/lib.rs @@ -126,6 +126,7 @@ impl<'ctx> Transformer<'ctx> { requests .into_iter() + .filter(Request::decode) .enumerate() .map(|(i, r)| { ( @@ -330,33 +331,30 @@ impl<'ctx> Transformer<'ctx> { x0: Tensor>, compute: &Stream, ) -> Tensor> { - let dt = self.host.data_type(); - let d = self.host.hidden_size() as udim; + let buf = unsafe { x0.physical().as_raw() }; + let len = self.host.hidden_size() * self.host.data_type().size(); let (head, others) = requests.split_first().unwrap(); - let tokens = { - let begin = head.seq_len() as usize; - let mut i = begin; - let mut j = begin; - let buf = unsafe { x0.physical().as_raw() }; - let len = d as usize * dt.size(); - for r in others { - i += r.seq_len() as usize; - j += 1; - if r.decode() && i > j { + let begin = head.seq_len() as usize - 1; + + let mut src = begin; + let mut dst = begin; + for r in others { + src += r.seq_len() as usize; + if r.decode() { + dst += 1; + if dst < src { cuda::driver!(cuMemcpyDtoDAsync_v2( - buf + ((j - 1) * len) as CUdeviceptr, - buf + ((i - 1) * len) as CUdeviceptr, + buf + (dst * len) as CUdeviceptr, + buf + (src * len) as CUdeviceptr, len, compute.as_raw() )); } } - let begin = begin as udim - 1; - let len = j as udim - begin; - slice![from begin, take len] - }; - x0.slice(&[tokens, slice![all]]) + } + + x0.slice(&[slice![from begin, until dst + 1], slice![all]]) } fn logits(&self, mut x: Tensor, compute: &Stream<'ctx>) -> Tensor> {