refactor(transformer): restore support for non-decoding requests, so that batched input can be split at any position
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 21, 2024
1 parent 05596b6 commit 33c3097
Showing 5 changed files with 99 additions and 40 deletions.
14 changes: 7 additions & 7 deletions service/src/cpu.rs
@@ -89,12 +89,12 @@ impl SessionContext {
#[inline]
fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> Request<usize> {
let pos = self.0.request(tokens, max_seq_len);
Request {
id: self.0.id,
tokens: &self.0.cache_map[pos..],
cache: &mut self.0.cache,
pos: pos as _,
decode: true,
}
Request::new(
self.0.id,
&self.0.cache_map[pos..],
&mut self.0.cache,
pos as _,
true,
)
}
}
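
Note (not part of the diff): with the decode flag restored, a prompt can be split into several chunks and batched at any boundary, with logits requested only for the final chunk. A minimal Rust sketch, assuming `Request` and `LayerCache` are reachable from the `transformer` crate root and using a hypothetical `chunk_request` helper:

use common::{upos, utok};
use transformer::{LayerCache, Request};

// Hypothetical helper, not part of this commit: wrap one chunk of a split prompt.
// Intermediate chunks pass `decode = false` (they only fill the KV cache);
// the final chunk passes `decode = true`, so logits are produced for it alone.
fn chunk_request<'a, Id, S>(
    id: Id,
    chunk: &'a [utok],
    cache: &'a mut [LayerCache<S>],
    pos: upos,
    is_last_chunk: bool,
) -> Request<'a, Id, S> {
    Request::new(id, chunk, cache, pos, is_last_chunk)
}
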
14 changes: 7 additions & 7 deletions service/src/nvidia.rs
@@ -105,12 +105,12 @@ impl<'a> SessionContext<'a> {
#[inline]
fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> Request<'_, 'a, usize> {
let pos = self.0.request(tokens, max_seq_len);
Request {
id: self.0.id,
tokens: &self.0.cache_map[pos..],
cache: &mut self.0.cache,
pos: pos as _,
decode: true,
}
Request::new(
self.0.id,
&self.0.cache_map[pos..],
&mut self.0.cache,
pos as _,
true,
)
}
}
25 changes: 15 additions & 10 deletions transformer-cpu/src/lib.rs
@@ -40,7 +40,8 @@ impl Transformer {
mut requests: Vec<Request<Id>>,
sample: &SampleArgs,
) -> Vec<(Id, utok)> {
requests.sort_unstable_by_key(|t| t.tokens.len());
// Move all purely-decoding requests to the front to reduce the copy overhead of decode batching
requests.sort_unstable_by_key(Request::purely_decode);

// println!("tokens:");
// for request in requests.iter() {
@@ -75,7 +76,7 @@ impl Transformer {
let theta = self.0.rope_theta();
let mut pos = Vec::<u32>::with_capacity(nt as usize);
for request in requests.iter() {
pos.extend(request.pos..request.att_len());
pos.extend(request.pos()..request.att_len());
}
let pos = Tensor::new(DataType::U32, &[nt], reslice(&pos));

@@ -86,7 +87,7 @@ impl Transformer {
let mut att_buf = Storage::new((nh * max_seq_len * max_att_len) as usize * dt.size());
let mut gate_up = tensor(dt, &[nt, di + di]);

let tokens = requests.iter().flat_map(|r| r.tokens).copied();
let tokens = requests.iter().flat_map(Request::tokens).copied();
gather(&mut x0, &self.0.embed_tokens(), tokens);
// println!("gather:\n{x0}");

@@ -121,7 +122,7 @@ impl Transformer {

let mut req = 0;
for r in requests.iter_mut() {
let pos = r.pos;
let pos = r.pos();
let seq_len = r.seq_len();
let att_len = r.att_len();

@@ -137,7 +138,7 @@ impl Transformer {
let mut o = unsafe { o.map_physical(|u| &mut ***u) };

let mut q_att = Tensor::new(dt, &[nh, seq_len, dh], &mut q_buf[..]);
let (k_cache, v_cache) = r.cache[layer].get();
let (k_cache, v_cache) = r.cache(layer);
let k_cat = k_cache.as_mut().slice(cat_slice);
let v_cat = v_cache.as_mut().slice(cat_slice);
let mut k_cat = unsafe { k_cat.map_physical(|u| &mut **u) };
@@ -193,17 +194,21 @@ impl Transformer {
// println!("layer {layer} down:\n{x0}");
}

let (head, others) = requests.split_first().unwrap();
if !head.decode() {
return vec![];
}

let tokens = {
let (head, others) = requests.split_first().unwrap();
let begin = head.tokens.len();
let begin = head.seq_len() as usize;
let mut i = begin;
let mut j = begin;
let buf = x0.as_mut_slice();
let len = d as usize * dt.size();
for r in others {
i += r.tokens.len();
i += r.seq_len() as usize;
j += 1;
if i > j {
if r.decode() && i > j {
buf.copy_within((i - 1) * len..i * len, (j - 1) * len);
}
}
@@ -238,7 +243,7 @@ impl Transformer {
requests
.into_iter()
.enumerate()
.map(|(i, r)| (r.id, sample.random(&kernel::slice!(logits; voc; [i]))))
.map(|(i, r)| (r.id(), sample.random(&kernel::slice!(logits; voc; [i]))))
.collect()
}};
}
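Note on the copy loop above (the CUDA backend below reuses the same index math with an asynchronous device-to-device copy): after the last layer, only the final hidden-state row of each decoding request is fed to the lm-head, so those rows are gathered into a contiguous run starting at the head request's last row. A standalone sketch of the index computation; `compaction_plan`, `seq_lens`, and `decodes` are illustrative names, not the crate's API:

// Returns (source_row, destination_row) pairs for the rows that must be moved so
// that request k keeps its last hidden-state row at index `seq_lens[0] - 1 + k`.
// Mirrors the guard above by assuming the first (head) request decodes; its row
// is already in place and never moves.
fn compaction_plan(seq_lens: &[usize], decodes: &[bool]) -> Vec<(usize, usize)> {
    let mut plan = Vec::new();
    let mut i = seq_lens[0]; // rows consumed so far on the source side
    let mut j = seq_lens[0]; // rows accounted for on the destination side
    for (&seq_len, &decode) in seq_lens.iter().zip(decodes).skip(1) {
        i += seq_len; // this request's last row currently sits at i - 1
        j += 1;       // and should end up at j - 1
        if decode && i > j {
            plan.push((i - 1, j - 1));
        }
    }
    plan
}
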
25 changes: 15 additions & 10 deletions transformer-nvidia/src/lib.rs
@@ -65,7 +66,8 @@ impl<'ctx> Transformer<'ctx> {
compute: &Stream<'ctx>,
transfer: &Stream<'ctx>,
) -> Vec<(Id, utok)> {
requests.sort_unstable_by_key(|t| t.tokens.len());
// Move all purely-decoding requests to the front to reduce the copy overhead of decode batching
requests.sort_unstable_by_key(Request::purely_decode);

// println!("tokens:");
// for request in requests.iter() {
@@ -101,7 +102,7 @@ impl<'ctx> Transformer<'ctx> {
let mut pos = tensor(DataType::U32, &[nt], transfer);
let mut pos_ = Vec::<u32>::with_capacity(nt as usize);
for request in requests.iter() {
pos_.extend(request.pos..request.att_len());
pos_.extend(request.pos()..request.att_len());
}
pos.physical_mut().copy_in_async(&pos_, transfer);

Expand All @@ -116,7 +117,7 @@ impl<'ctx> Transformer<'ctx> {
let mut gate_up = tensor(dt, &[nt, di + di], transfer);
let e_alloc = transfer.record();

let tokens = requests.iter().flat_map(|r| r.tokens).copied();
let tokens = requests.iter().flat_map(Request::tokens).copied();
gather(&mut x0, &self.host.embed_tokens(), tokens, compute);
// compute.synchronize();
// println!("gather:\n{}", map_tensor(&x0));
@@ -159,7 +160,7 @@ impl<'ctx> Transformer<'ctx> {

let mut req = 0;
for r in requests.iter_mut() {
let pos = r.pos;
let pos = r.pos();
let seq_len = r.seq_len();
let att_len = r.att_len();

Expand All @@ -175,7 +176,7 @@ impl<'ctx> Transformer<'ctx> {
let mut o = unsafe { o.map_physical(|u| &mut ***u) };

let mut q_att = Tensor::new(dt, &[nh, seq_len, dh], &mut *q_buf);
let (k_cache, v_cache) = r.cache[layer].get();
let (k_cache, v_cache) = r.cache(layer);
let k_cat = k_cache.as_mut().slice(cat_slice);
let v_cat = v_cache.as_mut().slice(cat_slice);
let mut k_cat = unsafe { k_cat.map_physical(|u| &mut **u) };
@@ -240,17 +241,21 @@ impl<'ctx> Transformer<'ctx> {
// println!("layer {layer} down:\n{}", map_tensor(&x0));
}

let (head, others) = requests.split_first().unwrap();
if !head.decode() {
return vec![];
}

let tokens = {
let (head, others) = requests.split_first().unwrap();
let begin = head.tokens.len();
let begin = head.seq_len() as usize;
let mut i = begin;
let mut j = begin;
let buf = unsafe { x0.physical().as_raw() };
let len = d as usize * dt.size();
for r in others {
i += r.tokens.len();
i += r.seq_len() as usize;
j += 1;
if i > j {
if r.decode() && i > j {
cuda::driver!(cuMemcpyDtoDAsync_v2(
buf + ((j - 1) * len) as CUdeviceptr,
buf + ((i - 1) * len) as CUdeviceptr,
@@ -294,7 +299,7 @@ impl<'ctx> Transformer<'ctx> {
.enumerate()
.map(|(i, r)| {
(
r.id,
r.id(),
sample.random(&mut logits[i * voc as usize..][..voc as usize]),
)
})
61 changes: 55 additions & 6 deletions transformer/src/request.rs
@@ -1,21 +1,60 @@
use crate::LayerCache;
use common::{upos, utok};
use tensor::udim;
use tensor::{udim, Tensor};

pub struct Request<'a, Id, Storage> {
/// Identifier of this task.
pub id: Id,
id: Id,
/// Prompt of this request.
pub tokens: &'a [utok],
tokens: &'a [utok],
/// Context cache of this request.
pub cache: &'a mut [LayerCache<Storage>],
cache: &'a mut [LayerCache<Storage>],
/// Position of `prompt` in context.
pub pos: upos,
pos: upos,
/// Whether to decode the output.
pub decode: bool,
decode: bool,
}

impl<'a, Id, S> Request<'a, Id, S> {
#[inline]
pub fn new(
id: Id,
tokens: &'a [utok],
cache: &'a mut [LayerCache<S>],
pos: upos,
decode: bool,
) -> Self {
Self {
id,
tokens,
cache,
pos,
decode,
}
}
}

impl<T, U> Request<'_, T, U> {
#[inline]
pub fn id(self) -> T {
self.id
}

#[inline]
pub fn tokens(&self) -> &[utok] {
self.tokens
}

#[inline]
pub fn cache(&mut self, layer: usize) -> (&mut Tensor<U>, &mut Tensor<U>) {
self.cache[layer].get()
}

#[inline]
pub fn pos(&self) -> upos {
self.pos
}

#[inline]
pub const fn seq_len(&self) -> udim {
self.tokens.len() as _
@@ -25,4 +64,14 @@ impl<T, U> Request<'_, T, U> {
pub const fn att_len(&self) -> udim {
self.pos + self.seq_len()
}

#[inline]
pub fn decode(&self) -> bool {
self.decode
}

#[inline]
pub const fn purely_decode(&self) -> bool {
self.decode && self.tokens.len() == 1
}
}
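
A small usage illustration of the new accessors (the function and its name are hypothetical, and `Request` is assumed to be reachable from the crate root):

use transformer::Request;

// Illustrative only: classify a request by the restored `decode` flag.
// `purely_decode` is the key the schedulers above sort batches by.
fn classify<Id, S>(r: &Request<'_, Id, S>) -> &'static str {
    if r.purely_decode() {
        "single-token decode step"       // decode && exactly one new token
    } else if r.decode() {
        "prompt chunk that also decodes" // e.g. the final chunk of a split prompt
    } else {
        "prefill-only chunk"             // fills the KV cache, produces no logits
    }
}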
