Skip to content

Commit

Permalink
refactor(transformer-cpu): decode 需要返回 batching logits 的顺序
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Mar 13, 2024
1 parent b62d2f6 commit 00cf6ad
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 4 additions & 2 deletions service/src/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl CpuTask {
.or_insert_with_key(|&id| SessionContext::new(&self.transformer, id));

let time = Instant::now();
let mut logits = self.transformer.decode(vec![ctx.request(&prompt)]);
let mut logits = self.transformer.decode(vec![ctx.request(&prompt)]).1;
info!("prefill transformer ... {:?}", time.elapsed());

let max_seq_len = self.transformer.max_seq_len() as upos;
Expand All @@ -54,7 +54,7 @@ impl CpuTask {
}
responsing.send(token).unwrap();

logits = self.transformer.decode(vec![ctx.request(&[token])]);
logits = self.transformer.decode(vec![ctx.request(&[token])]).1;
}
}
Command::Drop { id } => {
Expand All @@ -71,6 +71,7 @@ struct SessionContext {
}

impl SessionContext {
#[inline]
fn new(transformer: &Transformer, id: usize) -> Self {
Self {
id,
Expand All @@ -79,6 +80,7 @@ impl SessionContext {
}
}

#[inline]
fn request<'a>(&'a mut self, tokens: &'a [utok]) -> Request<usize> {
let pos = self.pos;
self.pos += tokens.len() as upos;
Expand Down
4 changes: 2 additions & 2 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl Transformer {
self.0.max_position_embeddings()
}

pub fn decode<Id>(&mut self, mut requests: Vec<Request<Id>>) -> Tensor<Storage> {
pub fn decode<Id>(&mut self, mut requests: Vec<Request<Id>>) -> (Vec<Id>, Tensor<Storage>) {
requests.sort_unstable_by_key(|t| t.tokens.len());

// println!("tokens:");
Expand Down Expand Up @@ -221,7 +221,7 @@ impl Transformer {
mat_mul(&mut logits.access_mut(), 0., &x.access(), &lm_head, 1.);
// println!("pos {pos} logits:\n{}", logits.access());

logits
(requests.into_iter().map(|r| r.id).collect(), logits)
}
}

Expand Down

0 comments on commit 00cf6ad

Please sign in to comment.