Skip to content

Commit

Permalink
style(xtask): 利用参数 flatten 简化代码
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Mar 13, 2024
1 parent 00cf6ad commit 95bb57f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 50 deletions.
17 changes: 4 additions & 13 deletions xtask/src/chat.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,17 @@
use crate::{init_logger, service};
use crate::InferenceArgs;
use ::service::{Service, Session};
use colored::Colorize;
use std::{collections::HashMap, io::Write};

#[derive(Args, Default)]
pub(crate) struct ChatArgs {
/// Model directory.
#[clap(short, long)]
model: String,
/// Log level, may be "off", "trace", "debug", "info" or "error".
#[clap(long)]
log: Option<String>,

/// Use Nvidia GPU.
#[clap(long)]
nvidia: bool,
#[clap(flatten)]
inference: InferenceArgs,
}

impl ChatArgs {
pub fn invoke(self) {
init_logger(self.log);
let service = service(&self.model, self.nvidia);
let service: Service = self.inference.into();
let mut session = service.launch();
let mut sessions = HashMap::new();

Expand Down
18 changes: 5 additions & 13 deletions xtask/src/generate.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@
use crate::{init_logger, service};
use crate::InferenceArgs;
use service::Service;
use std::io::Write;

#[derive(Args, Default)]
pub(crate) struct GenerateArgs {
/// Model directory.
#[clap(short, long)]
model: String,
#[clap(flatten)]
inference: InferenceArgs,
/// Prompt.
#[clap(short, long)]
prompt: String,
/// Log level, may be "off", "trace", "debug", "info" or "error".
#[clap(long)]
log: Option<String>,

/// Use Nvidia GPU.
#[clap(long)]
nvidia: bool,
}

impl GenerateArgs {
pub fn invoke(self) {
init_logger(self.log);
let service = service(&self.model, self.nvidia);
let service: Service = self.inference.into();

print!("{}", self.prompt);
service.launch().generate(&self.prompt, |piece| {
Expand Down
62 changes: 38 additions & 24 deletions xtask/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,44 @@ enum Commands {
Chat(chat::ChatArgs),
}

fn init_logger(log: Option<String>) {
use log::LevelFilter;
use simple_logger::SimpleLogger;
let log = log
.as_ref()
.and_then(|log| match log.to_lowercase().as_str() {
"off" | "none" => Some(LevelFilter::Off),
"trace" => Some(LevelFilter::Trace),
"debug" => Some(LevelFilter::Debug),
"info" => Some(LevelFilter::Info),
"error" => Some(LevelFilter::Error),
_ => None,
})
.unwrap_or(LevelFilter::Warn);
SimpleLogger::new().with_level(log).init().unwrap();
#[derive(Args, Default)]
struct InferenceArgs {
/// Model directory.
#[clap(short, long)]
model: String,
/// Log level, may be "off", "trace", "debug", "info" or "error".
#[clap(long)]
log: Option<String>,
/// Use Nvidia GPU.
#[clap(long)]
nvidia: bool,
}

fn service(model_dir: &str, nvidia: bool) -> Service {
Service::load_model(
model_dir,
if nvidia {
Device::NvidiaGpu(0)
} else {
Device::Cpu
},
)
impl From<InferenceArgs> for Service {
fn from(args: InferenceArgs) -> Self {
use log::LevelFilter;
use simple_logger::SimpleLogger;
let log = args
.log
.as_ref()
.and_then(|log| match log.to_lowercase().as_str() {
"off" | "none" => Some(LevelFilter::Off),
"trace" => Some(LevelFilter::Trace),
"debug" => Some(LevelFilter::Debug),
"info" => Some(LevelFilter::Info),
"error" => Some(LevelFilter::Error),
_ => None,
})
.unwrap_or(LevelFilter::Warn);
SimpleLogger::new().with_level(log).init().unwrap();

Service::load_model(
args.model,
if args.nvidia {
Device::NvidiaGpu(0)
} else {
Device::Cpu
},
)
}
}

0 comments on commit 95bb57f

Please sign in to comment.