refactor(xtask): remove the hardware-selection macro and unify this code via a trait default implementation
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 26, 2024
1 parent 3c8632e commit 075d261
Showing 6 changed files with 132 additions and 149 deletions.
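The change replaces a per-subcommand `macro_rules!` hardware dispatch (previously duplicated across chat.rs, generate.rs, and service.rs) with a `Task` trait whose default `run` method performs the dispatch once and hands control to an associated `typed::<M>` hook that each subcommand implements. Below is a minimal, self-contained sketch of that pattern under simplified assumptions: `Backend`, `CpuBackend`, `GpuBackend`, and `detect_gpus` are hypothetical stand-ins for the crate's `CausalLM` models and its CUDA/NCCL detection, and the tokio runtime plumbing of the real `Task::run` is omitted.

// Sketch only: `Backend`, `CpuBackend`, `GpuBackend`, and `detect_gpus`
// are hypothetical stand-ins, not the crate's real types.

/// Stand-in for the model/back-end abstraction (`CausalLM` in the real code).
trait Backend {
    type Meta;
    fn load(meta: Self::Meta) -> Self;
    fn name(&self) -> &'static str;
}

struct CpuBackend;
impl Backend for CpuBackend {
    type Meta = ();
    fn load(_: ()) -> Self {
        CpuBackend
    }
    fn name(&self) -> &'static str {
        "cpu"
    }
}

struct GpuBackend(u32);
impl Backend for GpuBackend {
    type Meta = u32; // device index
    fn load(dev: u32) -> Self {
        GpuBackend(dev)
    }
    fn name(&self) -> &'static str {
        "gpu"
    }
}

/// Hypothetical hardware probe; the real code reads the `--nvidia` argument
/// and the `cfg(detected_cuda)` / `cfg(detected_nccl)` build flags.
fn detect_gpus() -> Vec<u32> {
    Vec::new()
}

/// The shared part: each subcommand implements only `typed`,
/// while the hardware dispatch lives once in the default `run`.
trait Task: Sized {
    fn typed<M: Backend>(self, meta: M::Meta);

    fn run(self) {
        match detect_gpus().as_slice() {
            [] => self.typed::<CpuBackend>(()),
            &[n] => self.typed::<GpuBackend>(n),
            _ => unimplemented!("multi-GPU dispatch"),
        }
    }
}

/// One concrete subcommand, analogous to `ChatArgs` / `GenerateArgs`.
struct Generate {
    prompt: String,
}

impl Task for Generate {
    fn typed<M: Backend>(self, meta: M::Meta) {
        let backend = M::load(meta);
        println!("generating {:?} on {}", self.prompt, backend.name());
    }
}

fn main() {
    Generate { prompt: "hi".into() }.run();
}

With this shape, each new subcommand only supplies its arguments and a `typed` body; the hardware-selection logic lives in exactly one place.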
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion xtask/Cargo.toml
@@ -13,6 +13,7 @@ causal-lm = { path = "../causal-lm" }
transformer = { path = "../transformer" }
transformer-cpu = { path = "../transformer-cpu" }
transformer-nv = { path = "../nvidia/transformer", optional = true }
distributed = { path = "../nvidia/distributed", optional = true }
service = { path = "../service" }
web-api = { path = "../web-api" }
log.workspace = true
@@ -26,4 +27,4 @@ search-cuda-tools.workspace = true

[features]
default = ["nvidia"]
nvidia = ["transformer-nv"]
nvidia = ["transformer-nv", "distributed"]
63 changes: 27 additions & 36 deletions xtask/src/chat.rs
@@ -1,35 +1,36 @@
use crate::print_now;
use crate::{print_now, InferenceArgs, Task};
use causal_lm::CausalLM;
use colored::Colorize;
use service::{Service, Session};
use std::collections::HashMap;
use std::{collections::HashMap, fmt::Debug};

impl crate::InferenceArgs {
pub async fn chat(self) {
macro_rules! chat {
($ty:ty; $meta:expr) => {
let (mut service, _handle) = Service::<$ty>::load(&self.model, $meta);
service.default_sample = self.sample_args();
Chatting::new(service).chat().await
};
}
#[derive(Args, Default)]
pub(crate) struct ChatArgs {
#[clap(flatten)]
pub inference: InferenceArgs,
}

self.init_log();
match self.nvidia().as_slice() {
[] => {
use transformer_cpu::Transformer as M;
chat!(M; ());
}
#[cfg(detected_cuda)]
&[n] => {
use transformer_nv::{cuda, Transformer as M};
chat!(M; cuda::Device::new(n));
}
#[cfg(detected_nccl)]
_distribute => todo!(),
#[cfg(not(all(detected_cuda, detected_nccl)))]
_ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
impl Task for ChatArgs {
fn inference(&self) -> &InferenceArgs {
&self.inference
}

async fn typed<M>(self, meta: M::Meta)
where
M: CausalLM + Send + Sync + 'static,
M::Storage: Send,
M::Error: Debug,
{
let (mut service, _handle) = Service::<M>::load(&self.inference.model, meta);
service.default_sample = self.inference.sample_args();
Chatting {
service,
current: 0,
next_id: 0,
sessions: Default::default(),
}
.chat()
.await
}
}

@@ -60,16 +61,6 @@ fn print_help() {
}

impl<M: CausalLM> Chatting<M> {
#[inline]
fn new(service: Service<M>) -> Self {
Chatting {
service,
current: 0,
next_id: 0,
sessions: Default::default(),
}
}

async fn chat(mut self) {
println!(
"\
87 changes: 28 additions & 59 deletions xtask/src/generate.rs
@@ -1,7 +1,7 @@
use crate::{print_now, InferenceArgs};
use causal_lm::{CausalLM, SampleArgs};
use crate::{print_now, InferenceArgs, Task};
use causal_lm::CausalLM;
use service::Service;
use std::{fmt::Debug, path::Path};
use std::fmt::Debug;

#[derive(Args, Default)]
pub(crate) struct GenerateArgs {
@@ -11,68 +11,37 @@ pub(crate) struct GenerateArgs {
#[clap(long, short)]
pub prompt: String,
/// Max number of steps to generate.
#[clap(long, short)]
#[clap(long)]
pub max_steps: Option<usize>,
}

impl GenerateArgs {
pub async fn generate(self) {
macro_rules! generate {
($ty:ty; $meta:expr) => {
generate::<$ty>(
&self.inference.model,
$meta,
&self.prompt,
self.max_steps.unwrap_or(usize::MAX),
self.inference.sample_args(),
)
.await;
};
}
impl Task for GenerateArgs {
fn inference(&self) -> &InferenceArgs {
&self.inference
}

async fn typed<M>(self, meta: M::Meta)
where
M: CausalLM + Send + Sync + 'static,
M::Storage: Send,
M::Error: Debug,
{
let (service, _handle) = Service::<M>::load(&self.inference.model, meta);

print_now!("{}", self.prompt);

self.inference.init_log();
match self.inference.nvidia().as_slice() {
[] => {
use transformer_cpu::Transformer as M;
generate!(M; ());
let mut steps = self.max_steps.unwrap_or(usize::MAX);
let mut generator = service.generate(self.prompt, Some(self.inference.sample_args()));
while let Some(s) = generator.decode().await {
match &*s {
"\\n" => println!(),
_ => print_now!("{s}"),
}
#[cfg(detected_cuda)]
&[n] => {
use transformer_nv::{cuda, Transformer as M};
generate!(M; cuda::Device::new(n));
steps -= 1;
if steps == 0 {
break;
}
#[cfg(detected_nccl)]
_distribute => todo!(),
#[cfg(not(all(detected_cuda, detected_nccl)))]
_ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
}
}
}

async fn generate<M>(
model_dir: impl AsRef<Path>,
meta: M::Meta,
prompt: impl AsRef<str>,
max_steps: usize,
sample: SampleArgs,
) where
M: CausalLM + Send + Sync + 'static,
M::Storage: Send,
M::Error: Debug,
{
let (mut service, _handle) = Service::<M>::load(model_dir, meta);
service.default_sample = sample;
let mut generator = service.generate(prompt, None);
let mut steps = 0;
while let Some(s) = generator.decode().await {
match &*s {
"\\n" => println!(),
_ => print_now!("{s}"),
}
steps += 1;
if steps >= max_steps {
break;
}
println!();
}
println!();
}
72 changes: 51 additions & 21 deletions xtask/src/main.rs
@@ -4,11 +4,11 @@ mod deploy;
mod generate;
mod service;

use causal_lm::SampleArgs;
use causal_lm::{CausalLM, SampleArgs};
use clap::Parser;
use deploy::DeployArgs;
use service::ServiceArgs;
use std::{ffi::c_int, future::Future};
use std::{ffi::c_int, fmt};

#[macro_use]
extern crate clap;
@@ -18,24 +18,9 @@ fn main() {
match Cli::parse().command {
Deploy(deploy) => deploy.deploy(),
Cast(cast) => cast.invode(),
Generate(args) => block_on(args.generate()),
Chat(chat) => block_on(chat.chat()),
Service(service) => block_on(service.serve()),
}
}

#[inline]
fn block_on(f: impl Future) {
#[cfg(detected_cuda)]
{
transformer_nv::cuda::init();
}
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(f);
runtime.shutdown_background();
#[cfg(detected_cuda)]
{
transformer_nv::synchronize();
Generate(args) => args.run(),
Chat(chat) => chat.run(),
Service(service) => service.run(),
}
}

@@ -56,7 +41,7 @@ enum Commands {
/// Generate following text
Generate(generate::GenerateArgs),
/// Chat locally
Chat(InferenceArgs),
Chat(chat::ChatArgs),
/// Start the service
Service(ServiceArgs),
}
@@ -127,6 +112,51 @@ impl InferenceArgs {
}
}

trait Task: Sized {
fn inference(&self) -> &InferenceArgs;

async fn typed<M>(self, meta: M::Meta)
where
M: CausalLM + Send + Sync + 'static,
M::Storage: Send,
M::Error: fmt::Debug;

fn run(self) {
#[cfg(detected_cuda)]
{
transformer_nv::cuda::init();
}
let runtime = tokio::runtime::Runtime::new().unwrap();

self.inference().init_log();
match self.inference().nvidia().as_slice() {
[] => {
use transformer_cpu::Transformer as M;
runtime.block_on(self.typed::<M>(()));
}
#[cfg(detected_cuda)]
&[n] => {
use transformer_nv::{cuda, Transformer as M};
runtime.block_on(self.typed::<M>(cuda::Device::new(n)));
}
#[cfg(detected_nccl)]
distribute => {
use distributed::{cuda::Device, Transformer as M};
let meta = distribute.iter().copied().map(Device::new).collect();
runtime.block_on(self.typed::<M>(meta));
}
#[cfg(not(all(detected_cuda, detected_nccl)))]
_ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
}

runtime.shutdown_background();
#[cfg(detected_cuda)]
{
transformer_nv::synchronize();
}
}
}

#[macro_export]
macro_rules! print_now {
($($arg:tt)*) => {{
51 changes: 21 additions & 30 deletions xtask/src/service.rs
@@ -1,40 +1,31 @@
#[derive(Args, Default)]
use crate::{InferenceArgs, Task};
use causal_lm::CausalLM;
use service::Service;
use std::fmt::Debug;
use web_api::start_infer_service;

#[derive(Args, Default)]
pub struct ServiceArgs {
#[clap(flatten)]
pub inference: crate::InferenceArgs,
pub inference: InferenceArgs,
/// Port to bind the service to
#[clap(short, long)]
pub port: u16,
}

impl ServiceArgs {
pub async fn serve(self) {
use service::Service;
use web_api::start_infer_service;

macro_rules! serve {
($ty:ty; $meta:expr) => {
let (mut service, _handle) = Service::<$ty>::load(&self.inference.model, $meta);
service.default_sample = self.inference.sample_args();
start_infer_service(service, self.port).await.unwrap();
};
}
impl Task for ServiceArgs {
fn inference(&self) -> &InferenceArgs {
&self.inference
}

self.inference.init_log();
match self.inference.nvidia().as_slice() {
[] => {
use transformer_cpu::Transformer as M;
serve!(M; ());
}
#[cfg(detected_cuda)]
&[n] => {
use transformer_nv::{cuda, Transformer as M};
serve!(M; cuda::Device::new(n));
}
#[cfg(detected_nccl)]
_distribute => todo!(),
#[cfg(not(all(detected_cuda, detected_nccl)))]
_ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
}
async fn typed<M>(self, meta: M::Meta)
where
M: CausalLM + Send + Sync + 'static,
M::Storage: Send,
M::Error: Debug,
{
let (mut service, _handle) = Service::<M>::load(&self.inference.model, meta);
service.default_sample = self.inference.sample_args();
start_infer_service(service, self.port).await.unwrap();
}
}
