From 95bb57f1a5343876be23daf8e017da5aeffed142 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 13 Mar 2024 13:32:33 +0800 Subject: [PATCH] =?UTF-8?q?style(xtask):=20=E5=88=A9=E7=94=A8=E5=8F=82?= =?UTF-8?q?=E6=95=B0=20flatten=20=E7=AE=80=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- xtask/src/chat.rs | 17 +++--------- xtask/src/generate.rs | 18 ++++--------- xtask/src/main.rs | 62 ++++++++++++++++++++++++++----------------- 3 files changed, 47 insertions(+), 50 deletions(-) diff --git a/xtask/src/chat.rs b/xtask/src/chat.rs index c17b38f9..82c7ed66 100644 --- a/xtask/src/chat.rs +++ b/xtask/src/chat.rs @@ -1,26 +1,17 @@ -use crate::{init_logger, service}; +use crate::InferenceArgs; use ::service::{Service, Session}; use colored::Colorize; use std::{collections::HashMap, io::Write}; #[derive(Args, Default)] pub(crate) struct ChatArgs { - /// Model directory. - #[clap(short, long)] - model: String, - /// Log level, may be "off", "trace", "debug", "info" or "error". - #[clap(long)] - log: Option<String>, - - /// Use Nvidia GPU. - #[clap(long)] - nvidia: bool, + #[clap(flatten)] + inference: InferenceArgs, } impl ChatArgs { pub fn invoke(self) { - init_logger(self.log); - let service = service(&self.model, self.nvidia); + let service: Service = self.inference.into(); let mut session = service.launch(); let mut sessions = HashMap::new(); diff --git a/xtask/src/generate.rs b/xtask/src/generate.rs index 3a6a714e..95a91ab0 100644 --- a/xtask/src/generate.rs +++ b/xtask/src/generate.rs @@ -1,27 +1,19 @@ -use crate::{init_logger, service}; +use crate::InferenceArgs; +use service::Service; use std::io::Write; #[derive(Args, Default)] pub(crate) struct GenerateArgs { - /// Model directory. - #[clap(short, long)] - model: String, + #[clap(flatten)] + inference: InferenceArgs, /// Prompt. 
#[clap(short, long)] prompt: String, - /// Log level, may be "off", "trace", "debug", "info" or "error". - #[clap(long)] - log: Option<String>, - - /// Use Nvidia GPU. - #[clap(long)] - nvidia: bool, } impl GenerateArgs { pub fn invoke(self) { - init_logger(self.log); - let service = service(&self.model, self.nvidia); + let service: Service = self.inference.into(); print!("{}", self.prompt); service.launch().generate(&self.prompt, |piece| { diff --git a/xtask/src/main.rs b/xtask/src/main.rs index d5e93300..2067407c 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -35,30 +35,44 @@ enum Commands { Chat(chat::ChatArgs), } -fn init_logger(log: Option<String>) { - use log::LevelFilter; - use simple_logger::SimpleLogger; - let log = log - .as_ref() - .and_then(|log| match log.to_lowercase().as_str() { - "off" | "none" => Some(LevelFilter::Off), - "trace" => Some(LevelFilter::Trace), - "debug" => Some(LevelFilter::Debug), - "info" => Some(LevelFilter::Info), - "error" => Some(LevelFilter::Error), - _ => None, - }) - .unwrap_or(LevelFilter::Warn); - SimpleLogger::new().with_level(log).init().unwrap(); +#[derive(Args, Default)] +struct InferenceArgs { + /// Model directory. + #[clap(short, long)] + model: String, + /// Log level, may be "off", "trace", "debug", "info" or "error". + #[clap(long)] + log: Option<String>, + /// Use Nvidia GPU. 
+ #[clap(long)] + nvidia: bool, } -fn service(model_dir: &str, nvidia: bool) -> Service { - Service::load_model( - model_dir, - if nvidia { - Device::NvidiaGpu(0) - } else { - Device::Cpu - }, - ) +impl From<InferenceArgs> for Service { + fn from(args: InferenceArgs) -> Self { + use log::LevelFilter; + use simple_logger::SimpleLogger; + let log = args + .log + .as_ref() + .and_then(|log| match log.to_lowercase().as_str() { + "off" | "none" => Some(LevelFilter::Off), + "trace" => Some(LevelFilter::Trace), + "debug" => Some(LevelFilter::Debug), + "info" => Some(LevelFilter::Info), + "error" => Some(LevelFilter::Error), + _ => None, + }) + .unwrap_or(LevelFilter::Warn); + SimpleLogger::new().with_level(log).init().unwrap(); + + Service::load_model( + args.model, + if args.nvidia { + Device::NvidiaGpu(0) + } else { + Device::Cpu + }, + ) + } }