Skip to content

Commit

Permalink
fix(xtask): 检测系统 cuda、nccl 支持
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 26, 2024
1 parent 079d89d commit de94024
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 22 deletions.
38 changes: 24 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions xtask/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ simple_logger = "4.3"
colored = "2.1"
clap = { version = "4.5", features = ["derive"] }

[build-dependencies]
search-cuda-tools.workspace = true

[features]
default = ["nvidia"]
nvidia = ["transformer-nv"]
12 changes: 12 additions & 0 deletions xtask/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
fn main() {
use search_cuda_tools::*;
if !cfg!(feature = "nvidia") {
return;
}
if find_cuda_root().is_some() {
detect_cuda();
}
if find_nccl_root().is_some() {
detect_nccl();
}
}
6 changes: 3 additions & 3 deletions xtask/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ impl crate::InferenceArgs {
use transformer_cpu::Transformer as M;
chat!(M; ());
}
#[cfg(feature = "nvidia")]
#[cfg(detected_cuda)]
&[n] => {
use transformer_nv::{cuda, Transformer as M};
chat!(M; cuda::Device::new(n));
}
#[cfg(feature = "nvidia")]
#[cfg(detected_nccl)]
_distribute => todo!(),
#[cfg(not(feature = "nvidia"))]
#[cfg(not(all(detected_cuda, detected_nccl)))]
_ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
}
}
Expand Down
4 changes: 2 additions & 2 deletions xtask/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ fn main() {

#[inline]
fn block_on(f: impl Future) {
#[cfg(feature = "nvidia")]
#[cfg(detected_cuda)]
{
transformer_nv::cuda::init();
}
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(f);
runtime.shutdown_background();
#[cfg(feature = "nvidia")]
#[cfg(detected_cuda)]
{
transformer_nv::synchronize();
}
Expand Down
6 changes: 3 additions & 3 deletions xtask/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ impl ServiceArgs {
use transformer_cpu::Transformer as M;
serve!(M; ());
}
#[cfg(feature = "nvidia")]
#[cfg(detected_cuda)]
&[n] => {
use transformer_nv::{cuda, Transformer as M};
serve!(M; cuda::Device::new(n));
}
#[cfg(feature = "nvidia")]
#[cfg(detected_nccl)]
_distribute => todo!(),
#[cfg(not(feature = "nvidia"))]
#[cfg(not(all(detected_cuda, detected_nccl)))]
_ => panic!("Set \"nvidia\" feature to enablel nvidia support."),
}
}
Expand Down

0 comments on commit de94024

Please sign in to comment.