Skip to content

Commit

Permalink
fix: 改正 tensor slice 和 decode 移动算法
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Mar 27, 2024
1 parent 270b860 commit 2614725
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 83 deletions.
144 changes: 98 additions & 46 deletions tensor/src/slice.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{idim, pattern::Pattern, udim, Affine, Shape, Tensor};
use std::{cmp::Ordering, iter::zip};

impl<Physical> Tensor<Physical> {
pub fn slice(self, dims: &[SliceDim]) -> Self {
Expand All @@ -11,55 +12,13 @@ impl<Physical> Tensor<Physical> {
}
}

#[derive(Clone, Debug)]
pub struct SliceDim {
pub start: udim,
pub step: idim,
pub len: udim,
}

#[macro_export]
macro_rules! slice {
[all] => {
$crate::SliceDim {
start: 0,
step: 1,
len: udim::MAX,
}
};
[from $start:expr, take $len:expr] => {
$crate::SliceDim {
start: $start,
step: 1,
len: $len,
}
};
[$start:expr; $step:expr; $len:expr] => {
$crate::SliceDim {
start: $start,
step: $step,
len: $len,
}
};
}

fn build(meta: &[SliceDim], input: &[udim]) -> (Shape, Affine) {
assert_eq!(input.len(), meta.len());
assert!(meta
.iter()
.zip(input)
.all(|(d, &i)| { (0..i).contains(&d.start) }));

let shape = meta
.iter()
.zip(input)
.map(|(d, &i)| {
let distance = if d.step > 0 { i - d.start } else { d.start };
let step = d.step.unsigned_abs();
((distance + step - 1) / step).min(d.len)
})
.collect::<Shape>();
let meta = zip(meta, input)
.map(|(d, &len)| d.normalize(len))
.collect::<Vec<_>>();

let shape = meta.iter().map(|d| d.len).collect::<Shape>();
let n = meta.len();
let affine = Affine::from_fn(n + 1, n + 1, |r, c| {
if r == n {
Expand All @@ -75,6 +34,7 @@ fn build(meta: &[SliceDim], input: &[udim]) -> (Shape, Affine) {

#[test]
fn test() {
use crate::slice;
let (shape, affine) = build(&[slice![2;1;2], slice![0;1;4], slice![1;2;3]], &[5, 6, 7]);
assert_eq!(shape.as_slice(), &[2, 4, 3]);
assert_eq!(
Expand All @@ -88,3 +48,95 @@ fn test() {
]
);
}

/// Describes how a single tensor dimension is sliced: which index to start
/// at, the stride between selected indices, and how many to take.
///
/// Values are produced by the `slice!` macro and clamped against a concrete
/// dimension size by `SliceDim::normalize` before use.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct SliceDim {
    /// First selected index along the dimension (non-negative).
    pub start: udim,
    /// Signed stride between selected indices; a negative step walks the
    /// dimension backwards from `start`.
    pub step: idim,
    /// Upper bound on the number of elements taken; `normalize` may shrink
    /// this to what the dimension actually holds.
    pub len: udim,
}

impl SliceDim {
#[inline]
pub fn normalize(&self, len: udim) -> Self {
match self.step.cmp(&0) {
Ordering::Greater => {
assert!(self.start < len);
Self {
start: self.start,
step: self.step,
len: {
let step = self.step as udim;
((len - self.start + step - 1) / step).min(self.len)
},
}
}
Ordering::Equal => {
assert!(self.start < len);
Self {
start: self.start,
step: self.step,
len: self.len,
}
}
Ordering::Less => {
let start = self.start.min(len - 1);
Self {
start,
step: self.step,
len: {
let step = self.step.unsigned_abs();
((start + 1 + step - 1) / step).min(self.len)
},
}
}
}
}
}

/// Builds a [`SliceDim`] describing how one tensor dimension is sliced.
///
/// Forms:
/// - `slice![all]` — keep the whole dimension;
/// - `slice![rev]` — the whole dimension, reversed;
/// - `slice![take n]` — the first `n` elements;
/// - `slice![from a, until b]` — the half-open range `a..b`;
/// - `slice![from a, take n]` — `n` elements starting at `a`;
/// - `slice![from a, take n, per s]` — `n` elements from `a` with stride `s`;
/// - `slice![start; step; len]` — the general form every sugar arm desugars to.
///
/// Oversized `start`/`len` values (e.g. `usize::MAX`) are clamped later by
/// `SliceDim::normalize`.
#[macro_export]
macro_rules! slice {
    [all] => {
        // Recurse through `$crate::slice!` so the exported macro also expands
        // correctly in downstream crates that did not import `slice` by name.
        $crate::slice![0; 1; usize::MAX]
    };
    [rev] => {
        $crate::slice![usize::MAX; -1; usize::MAX]
    };
    [take $len:expr] => {
        $crate::slice![0; 1; $len]
    };
    [from $start:expr, until $end:expr] => {
        $crate::slice![$start; 1; $end - $start]
    };
    [from $start:expr, take $len:expr] => {
        $crate::slice![$start; 1; $len]
    };
    [from $start:expr, take $len:expr, per $step:expr] => {
        $crate::slice![$start; $step; $len]
    };
    [$start:expr; $step:expr; $len:expr] => {
        $crate::SliceDim {
            start: $start as _,
            step: $step as _,
            len: $len as _,
        }
    };
}

#[test]
fn test_macro() {
    // The general `start; step; len` form maps straight onto the struct fields.
    let expected = SliceDim {
        start: 5,
        step: -3,
        len: 2,
    };
    assert_eq!(slice![5; -3; 2], expected);

    // Every sugar form desugars to an equivalent general form.
    assert_eq!(slice![all], slice![0; 1; usize::MAX]);
    assert_eq!(slice![rev], slice![usize::MAX; -1; usize::MAX]);
    assert_eq!(slice![take 5], slice![0; 1; 5]);
    assert_eq!(slice![from 3, take 5], slice![3; 1; 5]);
    assert_eq!(slice![from 3, take 5, per 2], slice![3; 2; 5]);

    // `until` is exclusive: 3..5 selects two elements.
    assert_eq!(slice![from 3, until 5], slice![3; 1; 2]);
}
33 changes: 15 additions & 18 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,28 +275,25 @@ impl Transformer {
requests: &[Request<Id>],
mut x0: Tensor<Storage>,
) -> Tensor<Storage> {
let dt = self.0.data_type();
let d = self.0.hidden_size() as udim;
let buf = x0.as_mut_slice();
let len = self.0.hidden_size() * self.0.data_type().size();

let (head, others) = requests.split_first().unwrap();
let tokens = {
let begin = head.seq_len() as usize;
let mut i = begin;
let mut j = begin;
let buf = x0.as_mut_slice();
let len = d as usize * dt.size();
for r in others {
i += r.seq_len() as usize;
j += 1;
if r.decode() && i > j {
buf.copy_within((i - 1) * len..i * len, (j - 1) * len);
let begin = head.seq_len() as usize - 1;

let mut src = begin;
let mut dst = begin;
for r in others {
src += r.seq_len() as usize;
if r.decode() {
dst += 1;
if dst < src {
buf.copy_within(src * len..(src + 1) * len, dst * len);
}
}
let begin = begin as udim - 1;
let len = j as udim - begin;
slice![from begin, take len]
};
x0.slice(&[tokens, slice![all]])
}

x0.slice(&[slice![from begin, until dst + 1], slice![all]])
}

fn logits(&self, mut x: Tensor<Storage>) -> Tensor<Storage> {
Expand Down
36 changes: 17 additions & 19 deletions transformer-nvidia/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ impl<'ctx> Transformer<'ctx> {

requests
.into_iter()
.filter(Request::decode)
.enumerate()
.map(|(i, r)| {
(
Expand Down Expand Up @@ -330,33 +331,30 @@ impl<'ctx> Transformer<'ctx> {
x0: Tensor<Storage<'ctx>>,
compute: &Stream,
) -> Tensor<Storage<'ctx>> {
let dt = self.host.data_type();
let d = self.host.hidden_size() as udim;
let buf = unsafe { x0.physical().as_raw() };
let len = self.host.hidden_size() * self.host.data_type().size();

let (head, others) = requests.split_first().unwrap();
let tokens = {
let begin = head.seq_len() as usize;
let mut i = begin;
let mut j = begin;
let buf = unsafe { x0.physical().as_raw() };
let len = d as usize * dt.size();
for r in others {
i += r.seq_len() as usize;
j += 1;
if r.decode() && i > j {
let begin = head.seq_len() as usize - 1;

let mut src = begin;
let mut dst = begin;
for r in others {
src += r.seq_len() as usize;
if r.decode() {
dst += 1;
if dst < src {
cuda::driver!(cuMemcpyDtoDAsync_v2(
buf + ((j - 1) * len) as CUdeviceptr,
buf + ((i - 1) * len) as CUdeviceptr,
buf + (dst * len) as CUdeviceptr,
buf + (src * len) as CUdeviceptr,
len,
compute.as_raw()
));
}
}
let begin = begin as udim - 1;
let len = j as udim - begin;
slice![from begin, take len]
};
x0.slice(&[tokens, slice![all]])
}

x0.slice(&[slice![from begin, until dst + 1], slice![all]])
}

fn logits(&self, mut x: Tensor<Storage>, compute: &Stream<'ctx>) -> Tensor<Storage<'ctx>> {
Expand Down

0 comments on commit 2614725

Please sign in to comment.