Skip to content

Commit

Permalink
fix(service): 改正多轮推理设置 pos 的错误
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 24, 2024
1 parent ec4359b commit 6a20fce
Showing 1 changed file with 47 additions and 8 deletions.
55 changes: 47 additions & 8 deletions service/src/new.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
borrow::Cow,
error,
iter::zip,
mem::take,
mem::{replace, take},
ops::Range,
path::Path,
sync::{Arc, Mutex},
Expand Down Expand Up @@ -96,7 +96,7 @@ where
});
{
let handle = handle.clone();
std::thread::spawn(move || {
tokio::task::spawn_blocking(move || {
// 这个线程的生命周期不小于服务的生命周期,不占用线程池
while let Some(tasks) = Some(handle.batcher.deq()).filter(|t| !t.is_empty()) {
let token_embedded = {
Expand Down Expand Up @@ -145,8 +145,7 @@ where
*token != eos && task.sender.send(*token).is_ok()
})
.for_each(|(mut task, token)| {
task.tokens = vec![token];
task.pos += 1;
task.pos += replace(&mut task.tokens, vec![token]).len() as upos;
handle.batcher.enq(task);
});
});
Expand Down Expand Up @@ -220,7 +219,7 @@ impl<M: CausalLM> Session<M> {
}
prompt = !prompt;
}
self.infer(prefill)
self.infer(prefill, 0)
}

/// 向对话的 `dialog_pos` 处填充 `prompt`,启动推理并返回忙会话。
Expand All @@ -241,12 +240,13 @@ impl<M: CausalLM> Session<M> {
self.tail.extend(tail);
self.dialog.truncate(dialog_pos);
}
let pos = self.pos();
let prompt = self.push_sentence(prompt).to_vec();
Ok(self.infer(prompt))
Ok(self.infer(prompt, pos))
}
}

fn infer(&mut self, tokens: Vec<utok>) -> BusySession<M> {
fn infer(&mut self, tokens: Vec<utok>, pos: upos) -> BusySession<M> {
// 生成推理任务与会话的交互管道
let (sender, receiver) = unbounded_channel();
let cache = Arc::new(Mutex::new(Some(
Expand All @@ -256,7 +256,7 @@ impl<M: CausalLM> Session<M> {
)));
self.component.handle.batcher.enq(Task {
tokens,
pos: self.pos(),
pos,
sample: self.sample.clone(),
cache: cache.clone(),
sender,
Expand Down Expand Up @@ -351,3 +351,42 @@ impl Sentence {
&self.tokens[..self.head_len]
}
}

/// Smoke test: launches one session per prompt, runs all chats concurrently on
/// a current-thread Tokio runtime, and streams every decoded piece to stdout in
/// the color paired with its prompt.
///
/// The test returns early (and therefore passes) when
/// `common::test_model::find()` yields no model directory.
///
/// # Panics
///
/// Panics if any spawned chat task fails to join cleanly (e.g. it panicked),
/// so that a failing session actually fails the test.
#[test]
fn test() {
    use colored::{Color, Colorize};
    use std::{io::Write, iter::zip};
    use tokio::{runtime::Builder, task::JoinSet};

    // Without a model on disk there is nothing to exercise.
    let Some(model_dir) = common::test_model::find() else {
        return;
    };
    println!("model_dir: {}", model_dir.display());

    // The runtime context is entered before the service is constructed and the
    // guard is kept alive for the rest of the test.
    let runtime = Builder::new_current_thread().enable_time().build().unwrap();
    let _rt = runtime.enter();

    let service = Service::<transformer_cpu::Transformer>::new(model_dir);

    let mut set = JoinSet::new();
    let tasks = vec![
        ("Say \"Hi\" to me.", Color::Yellow),
        ("Hi", Color::Red),
        ("Where is the capital of France?", Color::Green),
    ];

    // One independent session per prompt, all launched before any chat starts.
    let sessions = tasks.iter().map(|_| service.launch()).collect::<Vec<_>>();

    for ((prompt, color), mut session) in zip(tasks, sessions) {
        set.spawn(async move {
            let mut busy = session.chat(0, prompt).unwrap();
            while let Some(s) = busy.decode().await {
                print!("{}", s.color(color));
                // Flush after every piece so the output appears as it is decoded.
                std::io::stdout().flush().unwrap();
            }
        });
    }

    runtime.block_on(async {
        // A panic inside a spawned task is reported as `Err(JoinError)` here;
        // surface it instead of silently discarding it, otherwise the test
        // would pass even when a chat task failed.
        while let Some(joined) = set.join_next().await {
            joined.expect("chat task failed");
        }
    });
    runtime.shutdown_background();
}

0 comments on commit 6a20fce

Please sign in to comment.