feat(service): support pseudo-infinite-length decoding on CPU
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Mar 13, 2024
1 parent 95bb57f commit 0aad89d
Showing 5 changed files with 74 additions and 31 deletions.
52 changes: 24 additions & 28 deletions service/src/cpu.rs
@@ -1,12 +1,11 @@
 use crate::{argmax, Command};
-use common::{upos, utok};
+use common::utok;
 use half::f16;
 use std::{collections::HashMap, path::Path, time::Instant};
 use tensor::reslice;
-use transformer_cpu::{LayerCache, Llama2, Memory, Request, Transformer};
+use transformer_cpu::{LayerCache, Memory, Request, Transformer};
 
 pub struct CpuTask {
-    eos: utok,
     transformer: Transformer,
     sessions: HashMap<usize, SessionContext>,
 }
@@ -17,22 +16,20 @@ impl CpuTask {
         let model = Box::new(Memory::load_safetensors_from_dir(model_dir).unwrap());
         info!("load model ... {:?}", time.elapsed());
 
-        let eos = model.eos_token_id();
         let time = Instant::now();
         let transformer = Transformer::new(model);
         info!("build transformer ... {:?}", time.elapsed());
 
         let sessions = HashMap::new();
         Self {
-            eos,
             transformer,
             sessions,
         }
     }
 
     pub fn invoke(&mut self, cmd: Command) {
         match cmd {
-            Command::Chat {
+            Command::Infer {
                 id,
                 prompt,
                 responsing,
@@ -42,19 +39,27 @@ impl CpuTask {
                     .entry(id)
                     .or_insert_with_key(|&id| SessionContext::new(&self.transformer, id));
 
+                let max_seq_len = self.transformer.max_seq_len();
+                let eos = self.transformer.eos_token_id();
+
                 let time = Instant::now();
-                let mut logits = self.transformer.decode(vec![ctx.request(&prompt)]).1;
+                let mut logits = self
+                    .transformer
+                    .decode(vec![ctx.request(&prompt, max_seq_len)])
+                    .1;
                 info!("prefill transformer ... {:?}", time.elapsed());
 
-                let max_seq_len = self.transformer.max_seq_len() as upos;
-                while ctx.pos < max_seq_len {
+                loop {
                     let token = argmax(reslice::<u8, f16>(logits.access().as_slice()));
-                    if token == self.eos {
+                    if token == eos {
                         break;
                     }
                     responsing.send(token).unwrap();
 
-                    logits = self.transformer.decode(vec![ctx.request(&[token])]).1;
+                    logits = self
+                        .transformer
+                        .decode(vec![ctx.request(&[token], max_seq_len)])
+                        .1;
                 }
             }
             Command::Drop { id } => {
@@ -64,31 +69,22 @@ impl CpuTask {
     }
 }
 
-struct SessionContext {
-    id: usize,
-    pos: upos,
-    cache: Vec<LayerCache>,
-}
+struct SessionContext(super::SessionContext<LayerCache>);
 
 impl SessionContext {
     #[inline]
     fn new(transformer: &Transformer, id: usize) -> Self {
-        Self {
-            id,
-            pos: 0,
-            cache: transformer.new_cache(),
-        }
+        Self(super::SessionContext::new(transformer.new_cache(), id))
     }
 
     #[inline]
-    fn request<'a>(&'a mut self, tokens: &'a [utok]) -> Request<usize> {
-        let pos = self.pos;
-        self.pos += tokens.len() as upos;
+    fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> Request<usize> {
+        let pos = self.0.request(tokens, max_seq_len);
         Request {
-            id: self.id,
-            tokens,
-            cache: &mut self.cache,
-            pos,
+            id: self.0.id,
+            tokens: &self.0.tokens[pos..],
+            cache: &mut self.0.cache,
+            pos: pos as _,
         }
     }
 }
42 changes: 41 additions & 1 deletion service/src/lib.rs
@@ -75,7 +75,7 @@ impl Service {
 }
 
 enum Command {
-    Chat {
+    Infer {
         id: usize,
         prompt: Vec<utok>,
         responsing: Sender<utok>,
@@ -110,6 +110,46 @@ fn tokenizer(model_dir: impl AsRef<Path>) -> Box<dyn Tokenizer + Send + Sync> {
     panic!("Tokenizer file not found");
 }
 
+struct SessionContext<Cache> {
+    id: usize,
+    tokens: Vec<utok>,
+    cache: Vec<Cache>,
+}
+
+impl<Cache> SessionContext<Cache> {
+    #[inline]
+    fn new(cache: Vec<Cache>, id: usize) -> Self {
+        Self {
+            id,
+            tokens: Vec::new(),
+            cache,
+        }
+    }
+
+    #[inline]
+    fn request(&mut self, tokens: &[utok], max_seq_len: usize) -> usize {
+        if self.tokens.len() + tokens.len() > max_seq_len {
+            let pos = self.tokens.len().min(16);
+            if tokens.len() > max_seq_len / 2 {
+                let tokens = &tokens[tokens.len() - max_seq_len / 2..];
+                self.tokens.truncate(pos);
+                self.tokens.extend_from_slice(tokens);
+            } else {
+                let tail_len = (self.tokens.len() - pos).min(64);
+                let tail = self.tokens.len() - tail_len;
+                self.tokens.copy_within(tail.., pos);
+                self.tokens.truncate(pos + tail_len);
+                self.tokens.extend_from_slice(tokens);
+            }
+            pos
+        } else {
+            let pos = self.tokens.len();
+            self.tokens.extend_from_slice(tokens);
+            pos
+        }
+    }
+}
+
 fn argmax<T: PartialOrd>(logits: &[T]) -> utok {
     logits
         .iter()
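The new `SessionContext::request` above is what makes the decoding "pseudo" infinite: when the accumulated history plus the incoming tokens would exceed `max_seq_len`, it keeps at most the first 16 tokens as a fixed prefix, keeps at most the last 64 history tokens after that prefix, drops the middle, and appends the new tokens (if the new tokens alone exceed `max_seq_len / 2`, only their tail is kept). The returned `pos` tells the caller where decoding resumes; as the cpu.rs change shows, the caller then feeds `tokens[pos..]` at position `pos`, apparently re-running the retained suffix against the kept cache prefix. Below is a minimal, self-contained sketch of that windowing, with the logic copied into a free function; `windowed_request` and the toy values are illustrative and not part of the crate's API.

// Standalone illustration (not the crate's private type): the same windowing
// logic as `SessionContext::request`, applied to a plain Vec of token ids.
fn windowed_request(history: &mut Vec<u32>, tokens: &[u32], max_seq_len: usize) -> usize {
    if history.len() + tokens.len() > max_seq_len {
        // Keep at most the first 16 tokens as a fixed prefix.
        let pos = history.len().min(16);
        if tokens.len() > max_seq_len / 2 {
            // The new tokens alone dominate the window: keep only their tail.
            let tokens = &tokens[tokens.len() - max_seq_len / 2..];
            history.truncate(pos);
            history.extend_from_slice(tokens);
        } else {
            // Keep at most the last 64 history tokens after the prefix, drop the middle.
            let tail_len = (history.len() - pos).min(64);
            let tail = history.len() - tail_len;
            history.copy_within(tail.., pos);
            history.truncate(pos + tail_len);
            history.extend_from_slice(tokens);
        }
        pos
    } else {
        // Everything still fits: append and resume from the previous end.
        let pos = history.len();
        history.extend_from_slice(tokens);
        pos
    }
}

fn main() {
    let max_seq_len = 96; // toy window; a real model would use its own limit
    let mut history = Vec::new();

    // Prefill with an 80-token prompt: it fits, so decoding starts at position 0.
    let prompt: Vec<u32> = (0..80).collect();
    assert_eq!(windowed_request(&mut history, &prompt, max_seq_len), 0);

    // Decode one token at a time far past the window; the history is repeatedly
    // compacted to "prefix + recent tail" and never overflows max_seq_len.
    for t in 80..10_000u32 {
        let pos = windowed_request(&mut history, &[t], max_seq_len);
        assert!(history.len() <= max_seq_len);
        assert!(pos < history.len());
    }
    println!("final history length: {}", history.len());
}

The 16-token prefix and 64-token tail are fixed constants in this commit; presumably the prefix anchors the start of the conversation while the tail keeps the most recent context, at the cost of re-prefilling whatever survives each compaction.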
3 changes: 2 additions & 1 deletion service/src/nvidia.rs
@@ -37,7 +37,7 @@ pub fn task(model_dir: impl AsRef<Path>, receiver: Receiver<Command>, ctx: &Cont
 
     while let Ok(cmd) = receiver.recv() {
         match cmd {
-            Command::Chat {
+            Command::Infer {
                 id,
                 prompt,
                 responsing,
@@ -79,6 +79,7 @@ struct SessionContext<'ctx> {
 }
 
 impl<'ctx> SessionContext<'ctx> {
+    #[inline]
     fn new(transformer: &Transformer, transfer: &'ctx Stream) -> Self {
         Self {
             pos: 0,
2 changes: 1 addition & 1 deletion service/src/session.rs
@@ -41,7 +41,7 @@ impl Session {
         let prompt = self.component.tokenizer.encode(prompt);
 
         let (responsing, receiver) = channel();
-        let chat = Command::Chat {
+        let chat = Command::Infer {
            id: self.id,
            prompt,
            responsing,
6 changes: 6 additions & 0 deletions transformer-cpu/src/lib.rs
@@ -1,6 +1,7 @@
 mod kernel;
 mod storage;
 
+use common::utok;
 use kernel::{gather, mat_mul, rms_norm, rotary_embedding, softmax, swiglu};
 use storage::Storage;
 use tensor::{reslice, slice, udim, DataType, Tensor};
@@ -27,6 +28,11 @@ impl Transformer {
         self.0.max_position_embeddings()
     }
 
+    #[inline]
+    pub fn eos_token_id(&self) -> utok {
+        self.0.eos_token_id()
+    }
+
     pub fn decode<Id>(&mut self, mut requests: Vec<Request<Id>>) -> (Vec<Id>, Tensor<Storage>) {
         requests.sort_unstable_by_key(|t| t.tokens.len());
 
