diff --git a/Cargo.lock b/Cargo.lock
index 58046078..92b2dc2c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -915,9 +915,9 @@ dependencies = [
 
 [[package]]
 name = "flate2"
-version = "1.0.28"
+version = "1.0.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e"
+checksum = "4556222738635b7a3417ae6130d8f52201e45a0c4d1a907f0826383adb5f85e7"
 dependencies = [
  "crc32fast",
  "miniz_oxide",
@@ -2933,6 +2933,7 @@ dependencies = [
 "clap",
 "colored",
 "common",
+ "distributed",
 "log",
 "search-cuda-tools",
 "service",
diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml
index db9b8ab6..b5e37d8a 100644
--- a/xtask/Cargo.toml
+++ b/xtask/Cargo.toml
@@ -13,6 +13,7 @@ causal-lm = { path = "../causal-lm" }
 transformer = { path = "../transformer" }
 transformer-cpu = { path = "../transformer-cpu" }
 transformer-nv = { path = "../nvidia/transformer", optional = true }
+distributed = { path = "../nvidia/distributed", optional = true }
 service = { path = "../service" }
 web-api = { path = "../web-api" }
 log.workspace = true
@@ -26,4 +27,4 @@ search-cuda-tools.workspace = true
 
 [features]
 default = ["nvidia"]
-nvidia = ["transformer-nv"]
+nvidia = ["transformer-nv", "distributed"]
diff --git a/xtask/src/chat.rs b/xtask/src/chat.rs
index 3682fa2f..72b47b36 100644
--- a/xtask/src/chat.rs
+++ b/xtask/src/chat.rs
@@ -1,35 +1,36 @@
-use crate::print_now;
+use crate::{print_now, InferenceArgs, Task};
 use causal_lm::CausalLM;
 use colored::Colorize;
 use service::{Service, Session};
-use std::collections::HashMap;
+use std::{collections::HashMap, fmt::Debug};
 
-impl crate::InferenceArgs {
-    pub async fn chat(self) {
-        macro_rules! chat {
-            ($ty:ty; $meta:expr) => {
-                let (mut service, _handle) = Service::<$ty>::load(&self.model, $meta);
-                service.default_sample = self.sample_args();
-                Chatting::new(service).chat().await
-            };
-        }
+#[derive(Args, Default)]
+pub(crate) struct ChatArgs {
+    #[clap(flatten)]
+    pub inference: InferenceArgs,
+}
 
-        self.init_log();
-        match self.nvidia().as_slice() {
-            [] => {
-                use transformer_cpu::Transformer as M;
-                chat!(M; ());
-            }
-            #[cfg(detected_cuda)]
-            &[n] => {
-                use transformer_nv::{cuda, Transformer as M};
-                chat!(M; cuda::Device::new(n));
-            }
-            #[cfg(detected_nccl)]
-            _distribute => todo!(),
-            #[cfg(not(all(detected_cuda, detected_nccl)))]
-            _ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
+impl Task for ChatArgs {
+    fn inference(&self) -> &InferenceArgs {
+        &self.inference
+    }
+
+    async fn typed<M>(self, meta: M::Meta)
+    where
+        M: CausalLM + Send + Sync + 'static,
+        M::Storage: Send,
+        M::Error: Debug,
+    {
+        let (mut service, _handle) = Service::<M>::load(&self.inference.model, meta);
+        service.default_sample = self.inference.sample_args();
+        Chatting {
+            service,
+            current: 0,
+            next_id: 0,
+            sessions: Default::default(),
         }
+        .chat()
+        .await
     }
 }
 
@@ -60,16 +61,6 @@ fn print_help() {
 }
 
 impl<M: CausalLM> Chatting<M> {
-    #[inline]
-    fn new(service: Service<M>) -> Self {
-        Chatting {
-            service,
-            current: 0,
-            next_id: 0,
-            sessions: Default::default(),
-        }
-    }
-
     async fn chat(mut self) {
         println!(
             "\
diff --git a/xtask/src/generate.rs b/xtask/src/generate.rs
index 9815b348..c2bfd5cc 100644
--- a/xtask/src/generate.rs
+++ b/xtask/src/generate.rs
@@ -1,7 +1,7 @@
-use crate::{print_now, InferenceArgs};
-use causal_lm::{CausalLM, SampleArgs};
+use crate::{print_now, InferenceArgs, Task};
+use causal_lm::CausalLM;
 use service::Service;
-use std::{fmt::Debug, path::Path};
+use std::fmt::Debug;
 
 #[derive(Args, Default)]
 pub(crate) struct GenerateArgs {
@@ -11,68 +11,37 @@ pub(crate) struct GenerateArgs {
     #[clap(long, short)]
     pub prompt: String,
     /// Max number of steps to generate.
-    #[clap(long, short)]
+    #[clap(long)]
     pub max_steps: Option<usize>,
 }
 
-impl GenerateArgs {
-    pub async fn generate(self) {
-        macro_rules! generate {
-            ($ty:ty; $meta:expr) => {
-                generate::<$ty>(
-                    &self.inference.model,
-                    $meta,
-                    &self.prompt,
-                    self.max_steps.unwrap_or(usize::MAX),
-                    self.inference.sample_args(),
-                )
-                .await;
-            };
-        }
+impl Task for GenerateArgs {
+    fn inference(&self) -> &InferenceArgs {
+        &self.inference
+    }
+
+    async fn typed<M>(self, meta: M::Meta)
+    where
+        M: CausalLM + Send + Sync + 'static,
+        M::Storage: Send,
+        M::Error: Debug,
+    {
+        let (service, _handle) = Service::<M>::load(&self.inference.model, meta);
+
+        print_now!("{}", self.prompt);
 
-        self.inference.init_log();
-        match self.inference.nvidia().as_slice() {
-            [] => {
-                use transformer_cpu::Transformer as M;
-                generate!(M; ());
+        let mut steps = self.max_steps.unwrap_or(usize::MAX);
+        let mut generator = service.generate(self.prompt, Some(self.inference.sample_args()));
+        while let Some(s) = generator.decode().await {
+            match &*s {
+                "\\n" => println!(),
+                _ => print_now!("{s}"),
             }
-            #[cfg(detected_cuda)]
-            &[n] => {
-                use transformer_nv::{cuda, Transformer as M};
-                generate!(M; cuda::Device::new(n));
+            steps -= 1;
+            if steps == 0 {
+                break;
             }
-            #[cfg(detected_nccl)]
-            _distribute => todo!(),
-            #[cfg(not(all(detected_cuda, detected_nccl)))]
-            _ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
-        }
-    }
-}
-
-async fn generate<M>(
-    model_dir: impl AsRef<Path>,
-    meta: M::Meta,
-    prompt: impl AsRef<str>,
-    max_steps: usize,
-    sample: SampleArgs,
-) where
-    M: CausalLM + Send + Sync + 'static,
-    M::Storage: Send,
-    M::Error: Debug,
-{
-    let (mut service, _handle) = Service::<M>::load(model_dir, meta);
-    service.default_sample = sample;
-    let mut generator = service.generate(prompt, None);
-    let mut steps = 0;
-    while let Some(s) = generator.decode().await {
-        match &*s {
-            "\\n" => println!(),
-            _ => print_now!("{s}"),
-        }
-        steps += 1;
-        if steps >= max_steps {
-            break;
         }
+        println!();
     }
-    println!();
 }
diff --git a/xtask/src/main.rs b/xtask/src/main.rs
index 59e08125..d5f37c90 100644
--- a/xtask/src/main.rs
+++ b/xtask/src/main.rs
@@ -4,11 +4,11 @@ mod deploy;
 mod generate;
 mod service;
 
-use causal_lm::SampleArgs;
+use causal_lm::{CausalLM, SampleArgs};
 use clap::Parser;
 use deploy::DeployArgs;
 use service::ServiceArgs;
-use std::{ffi::c_int, future::Future};
+use std::{ffi::c_int, fmt};
 
 #[macro_use]
 extern crate clap;
@@ -18,24 +18,9 @@ fn main() {
     match Cli::parse().command {
         Deploy(deploy) => deploy.deploy(),
         Cast(cast) => cast.invode(),
-        Generate(args) => block_on(args.generate()),
-        Chat(chat) => block_on(chat.chat()),
-        Service(service) => block_on(service.serve()),
-    }
-}
-
-#[inline]
-fn block_on(f: impl Future) {
-    #[cfg(detected_cuda)]
-    {
-        transformer_nv::cuda::init();
-    }
-    let runtime = tokio::runtime::Runtime::new().unwrap();
-    runtime.block_on(f);
-    runtime.shutdown_background();
-    #[cfg(detected_cuda)]
-    {
-        transformer_nv::synchronize();
+        Generate(args) => args.run(),
+        Chat(chat) => chat.run(),
+        Service(service) => service.run(),
     }
 }
 
@@ -56,7 +41,7 @@ enum Commands {
     /// Generate following text
     Generate(generate::GenerateArgs),
     /// Chat locally
-    Chat(InferenceArgs),
+    Chat(chat::ChatArgs),
    /// Start the service
     Service(ServiceArgs),
 }
@@ -127,6 +112,51 @@ impl InferenceArgs {
     }
 }
 
+trait Task: Sized {
+    fn inference(&self) -> &InferenceArgs;
+
+    async fn typed<M>(self, meta: M::Meta)
+    where
+        M: CausalLM + Send + Sync + 'static,
+        M::Storage: Send,
+        M::Error: fmt::Debug;
+
+    fn run(self) {
+        #[cfg(detected_cuda)]
+        {
+            transformer_nv::cuda::init();
+        }
+        let runtime = tokio::runtime::Runtime::new().unwrap();
+
+        self.inference().init_log();
+        match self.inference().nvidia().as_slice() {
+            [] => {
+                use transformer_cpu::Transformer as M;
+                runtime.block_on(self.typed::<M>(()));
+            }
+            #[cfg(detected_cuda)]
+            &[n] => {
+                use transformer_nv::{cuda, Transformer as M};
+                runtime.block_on(self.typed::<M>(cuda::Device::new(n)));
+            }
+            #[cfg(detected_nccl)]
+            distribute => {
+                use distributed::{cuda::Device, Transformer as M};
+                let meta = distribute.iter().copied().map(Device::new).collect();
+                runtime.block_on(self.typed::<M>(meta));
+            }
+            #[cfg(not(all(detected_cuda, detected_nccl)))]
+            _ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
+        }
+
+        runtime.shutdown_background();
+        #[cfg(detected_cuda)]
+        {
+            transformer_nv::synchronize();
+        }
+    }
+}
+
 #[macro_export]
 macro_rules! print_now {
     ($($arg:tt)*) => {{
diff --git a/xtask/src/service.rs b/xtask/src/service.rs
index a355bbdb..8c6c1db7 100644
--- a/xtask/src/service.rs
+++ b/xtask/src/service.rs
@@ -1,40 +1,31 @@
-#[derive(Args, Default)]
+use crate::{InferenceArgs, Task};
+use causal_lm::CausalLM;
+use service::Service;
+use std::fmt::Debug;
+use web_api::start_infer_service;
+
+#[derive(Args, Default)]
 pub struct ServiceArgs {
     #[clap(flatten)]
-    pub inference: crate::InferenceArgs,
+    pub inference: InferenceArgs,
     /// Port to bind the service to
     #[clap(short, long)]
     pub port: u16,
 }
 
-impl ServiceArgs {
-    pub async fn serve(self) {
-        use service::Service;
-        use web_api::start_infer_service;
-
-        macro_rules! serve {
-            ($ty:ty; $meta:expr) => {
-                let (mut service, _handle) = Service::<$ty>::load(&self.inference.model, $meta);
-                service.default_sample = self.inference.sample_args();
-                start_infer_service(service, self.port).await.unwrap();
-            };
-        }
+impl Task for ServiceArgs {
+    fn inference(&self) -> &InferenceArgs {
+        &self.inference
+    }
 
-        self.inference.init_log();
-        match self.inference.nvidia().as_slice() {
-            [] => {
-                use transformer_cpu::Transformer as M;
-                serve!(M; ());
-            }
-            #[cfg(detected_cuda)]
-            &[n] => {
-                use transformer_nv::{cuda, Transformer as M};
-                serve!(M; cuda::Device::new(n));
-            }
-            #[cfg(detected_nccl)]
-            _distribute => todo!(),
-            #[cfg(not(all(detected_cuda, detected_nccl)))]
-            _ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
-        }
+    async fn typed<M>(self, meta: M::Meta)
+    where
+        M: CausalLM + Send + Sync + 'static,
+        M::Storage: Send,
+        M::Error: Debug,
+    {
+        let (mut service, _handle) = Service::<M>::load(&self.inference.model, meta);
+        service.default_sample = self.inference.sample_args();
+        start_infer_service(service, self.port).await.unwrap();
    }
 }
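For readers skimming the patch, below is a minimal, self-contained sketch of the dispatch pattern it introduces: each subcommand implements only a backend-generic `typed` body, while a provided `Task::run` owns runtime setup, backend selection, and teardown. The `Backend`, `CpuBackend`, and `EchoTask` names are hypothetical stand-ins, not types from this repository; only the `run`/`typed` split mirrors the diff. It assumes Rust 1.75+ (async fn in traits) and a `tokio` dependency with the "rt" feature.

// Standalone sketch of the `Task` dispatch pattern introduced by this diff.
// `Backend`, `CpuBackend`, and `EchoTask` are hypothetical stand-ins for
// `causal_lm::CausalLM`, `transformer_cpu::Transformer`, and the xtask
// subcommands; only the `run`/`typed` split mirrors the real code.

/// Stand-in for `CausalLM`: a backend is constructed from backend-specific metadata.
trait Backend: Sized {
    type Meta;
    fn load(meta: Self::Meta) -> Self;
    fn infer(&self, prompt: &str) -> String;
}

/// Stand-in for the CPU transformer: needs no metadata.
struct CpuBackend;

impl Backend for CpuBackend {
    type Meta = ();
    fn load(_meta: ()) -> Self {
        CpuBackend
    }
    fn infer(&self, prompt: &str) -> String {
        format!("echo: {prompt}")
    }
}

/// Each subcommand implements only the backend-generic `typed` body;
/// the provided `run` owns runtime setup, backend selection, and teardown.
trait Task: Sized {
    async fn typed<M: Backend>(self, meta: M::Meta);

    fn run(self) {
        // The real `run` additionally does CUDA init, log init, and matches on
        // the `--nvidia` device list to pick cpu / single-GPU / distributed.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        runtime.block_on(self.typed::<CpuBackend>(()));
        runtime.shutdown_background();
    }
}

/// Hypothetical subcommand playing the role of `GenerateArgs` / `ChatArgs`.
struct EchoTask {
    prompt: String,
}

impl Task for EchoTask {
    async fn typed<M: Backend>(self, meta: M::Meta) {
        let backend = M::load(meta);
        println!("{}", backend.infer(&self.prompt));
    }
}

fn main() {
    EchoTask {
        prompt: "hello".into(),
    }
    .run();
}

The design point this sketch tries to make explicit: before the patch, each subcommand duplicated the runtime creation and the cpu/single-GPU/distributed match; after it, that logic lives once in `Task::run`, and adding the `distributed` backend only required extending the single match in main.rs.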