Skip to content

Commit

Permalink
refactor(xtask): 简化构造服务的参数,添加一个简单的测试
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Mar 4, 2024
1 parent 724b64c commit c7a1c54
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 32 deletions.
55 changes: 46 additions & 9 deletions xtask/src/service/channel.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub(super) trait Channel {
use std::collections::VecDeque;

/// A query/response transport between the service loop and a client.
pub(super) trait Channel {
    /// Fetches the next [`Query`]; fails with [`ReceiveError`] when none can be read.
    fn receive(&mut self) -> Result<Query, ReceiveError>;
    /// Delivers a [`Response`] back to the client.
    fn send(&mut self, response: Response) -> Result<(), SendError>;
}
Expand All @@ -24,9 +26,22 @@ pub(super) enum SendError {
Json(serde_json::Error),
}

/// A single inbound request exchanged over a [`Channel`], JSON-encoded via serde.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub(super) struct Query {
    // Session identifier; the service echoes it back in the matching `Response`.
    pub id: usize,
    // The user's prompt text.
    pub prompt: String,
}

/// A single outbound reply exchanged over a [`Channel`], JSON-encoded via serde.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub(super) struct Response {
    // Identifier of the query this response answers.
    pub id: usize,
    // NOTE(review): despite the name, the cpu.rs call site fills this with the
    // generated text (`out`), not the original prompt — consider renaming.
    pub prompt: String,
}

/// A [`Channel`] backed by the process's stdin/stdout: queries are read from
/// stdin and responses printed to stdout, one JSON document per line.
pub(super) struct StdioChannel;

impl Channel for StdioChannel {
#[inline]
fn receive(&mut self) -> Result<Query, ReceiveError> {
let mut buf = String::new();
std::io::stdin()
Expand All @@ -35,21 +50,43 @@ impl Channel for StdioChannel {
serde_json::from_str(&buf).map_err(ReceiveError::Json)
}

#[inline]
fn send(&mut self, response: Response) -> Result<(), SendError> {
let buf = serde_json::to_string(&response).map_err(SendError::Json)?;
println!("{buf}");
Ok(())
}
}

#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub(super) struct Query {
pub id: usize,
pub prompt: String,
/// An in-memory [`Channel`] preloaded with queries; every response is kept
/// for later inspection. Used by the `cpu_service` test in cpu.rs.
pub(super) struct PrefilledChannel {
    // FIFO of queries handed out by `receive`, front first.
    queries: VecDeque<Query>,
    // Everything passed to `send`, in arrival order.
    responses: Vec<Response>,
}

#[derive(serde::Serialize, serde::Deserialize, Debug)]
pub(super) struct Response {
pub id: usize,
pub prompt: String,
impl Channel for PrefilledChannel {
    #[inline]
    fn receive(&mut self) -> Result<Query, ReceiveError> {
        // Hand out queued queries front-to-back; an exhausted queue signals
        // the end of the prefilled session.
        match self.queries.pop_front() {
            Some(query) => Ok(query),
            None => Err(ReceiveError::NoQuery),
        }
    }

    #[inline]
    fn send(&mut self, response: Response) -> Result<(), SendError> {
        // Sending in-memory cannot fail; just record the response for `take`.
        self.responses.push(response);
        Ok(())
    }
}

impl PrefilledChannel {
    /// Builds a channel holding exactly one query (id 0) for `prompt`.
    #[inline]
    pub fn from_prompt(prompt: String) -> Self {
        let query = Query { id: 0, prompt };
        Self {
            queries: std::iter::once(query).collect(),
            responses: Vec::new(),
        }
    }

    /// Consumes the channel, yielding every response recorded so far.
    #[inline]
    pub fn take(self) -> Vec<Response> {
        self.responses
    }
}
59 changes: 38 additions & 21 deletions xtask/src/service/cpu.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,23 @@
use super::{channel::channel, chat::apply_chat, ServiceArgs};
use super::{chat::apply_chat, ServiceParts};
use crate::{
common::{argmax, tokenizer},
common::argmax,
service::channel::{Query, Response},
Template,
};
use common::upos;
use std::{collections::HashMap, path::Path, time::Instant};
use std::{collections::HashMap, io::Write, time::Instant};
use transformer_cpu::{
model_parameters::{Llama2, Memory},
LayerCache, Transformer,
};

pub(super) fn run(args: ServiceArgs) {
let template = if args.model.to_ascii_lowercase().contains("tinyllama") {
Template::ChatTinyLlama
} else {
Template::Chat9G
};
let model_dir = Path::new(&args.model);

let time = Instant::now();
let tokenizer = tokenizer(args.tokenizer, &model_dir);
info!("build tokenizer ... {:?}", time.elapsed());

pub(super) fn run(
ServiceParts {
model_dir,
template,
tokenizer,
mut channel,
}: ServiceParts,
) {
let time = Instant::now();
let model = Box::new(Memory::load_safetensors_from_dir(&model_dir).unwrap());
info!("load model ... {:?}", time.elapsed());
Expand All @@ -34,10 +29,6 @@ pub(super) fn run(args: ServiceArgs) {
let mut transformer = Transformer::new(model);
info!("build transformer ... {:?}", time.elapsed());

let time = Instant::now();
let mut channel = channel(args.channel);
info!("build channel ... {:?}", time.elapsed());

struct SessionContext {
pos: upos,
kv_cache: Vec<LayerCache>,
Expand Down Expand Up @@ -72,10 +63,36 @@ pub(super) fn run(args: ServiceArgs) {
break;
}

out.push_str(&tokenizer.decode(token).replace('▁', " "));
let text = tokenizer.decode(token).replace('▁', " ");
out.push_str(&text);
session.pos += 1;

debug!("decode for {id}: {token:>5} {text:?}");
}

std::io::stdout().flush().unwrap();
channel.send(Response { id, prompt: out }).unwrap();
}
}

#[test]
fn cpu_service() {
    use super::channel::PrefilledChannel;
    use crate::{common::tokenizer, Template};
    use std::path::PathBuf;

    // Smoke test only: it is a silent no-op unless the TinyLlama weights are
    // checked out two directories up from the workspace.
    let model_dir = PathBuf::from("../../TinyLlama-1.1B-Chat-v1.0_F16");
    if !model_dir.is_dir() {
        return;
    }

    crate::common::logger_init(&Some("trace".into()));

    // Drive the cpu backend through an in-memory channel holding one prompt.
    let tokenizer = tokenizer(None, &model_dir);
    let channel = Box::new(PrefilledChannel::from_prompt("Once upon a time, ".into()));
    run(ServiceParts {
        tokenizer,
        model_dir,
        template: Template::ChatTinyLlama,
        channel,
    });
}
43 changes: 41 additions & 2 deletions xtask/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ mod cpu;
#[cfg(detected_cuda)]
mod nvidia;

use crate::{common::tokenizer, Template};
use channel::channel;
use std::{path::PathBuf, time::Instant};
use tokenizer::Tokenizer;

#[derive(Args, Default)]
pub(crate) struct ServiceArgs {
/// Model directory.
Expand All @@ -26,7 +31,6 @@ pub(crate) struct ServiceArgs {

impl ServiceArgs {
pub fn launch(self) {
crate::common::logger_init(&self.log);
if self.nvidia {
#[cfg(detected_cuda)]
{
Expand All @@ -37,7 +41,42 @@ impl ServiceArgs {
panic!("Nvidia GPU is not available");
}
} else {
cpu::run(self);
cpu::run(self.into());
}
}
}

/// Everything a backend `run` function needs, prebuilt from [`ServiceArgs`].
struct ServiceParts {
    // Directory containing the model weights (loaded via safetensors in cpu.rs).
    model_dir: PathBuf,
    // Prompt template matching the model family.
    template: Template,
    // Tokenizer built from the model directory.
    tokenizer: Box<dyn Tokenizer>,
    // Transport for queries/responses (stdio in production, prefilled in tests).
    channel: Box<dyn channel::Channel>,
}

impl From<ServiceArgs> for ServiceParts {
fn from(args: ServiceArgs) -> Self {
crate::common::logger_init(&args.log);

let template = if args.model.to_ascii_lowercase().contains("tinyllama") {
Template::ChatTinyLlama
} else {
Template::Chat9G
};
let model_dir = PathBuf::from(&args.model);

let time = Instant::now();
let tokenizer = tokenizer(args.tokenizer, &model_dir);
info!("build tokenizer ... {:?}", time.elapsed());

let time = Instant::now();
let channel = channel(args.channel);
info!("build channel ... {:?}", time.elapsed());

Self {
model_dir,
template,
tokenizer,
channel,
}
}
}

0 comments on commit c7a1c54

Please sign in to comment.