Skip to content

Commit

Permalink
fix(xtask): nvidia 启用时退出需要同步以免释放异常
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 25, 2024
1 parent 0abf47a commit 00778a2
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nvidia/common/src/rms_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ extern "C" __global__ void {folding}(

let (ptx, log) = Ptx::compile(code);
if !log.is_empty() {
println!("{log}");
warn!("{log}");
}
Self {
ptx: ptx.unwrap(),
Expand Down
2 changes: 1 addition & 1 deletion nvidia/transformer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{
};
use transformer::{Kernels, Llama2, Memory};

pub use common_nv::cuda;
pub use common_nv::{cuda, synchronize};

pub struct Transformer {
host: Memory,
Expand Down
1 change: 0 additions & 1 deletion xtask/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ impl InferenceArgs {
#[cfg(feature = "nvidia")]
&[n] => {
use transformer_nv::{cuda, Transformer as M};
cuda::init();
chat!(M; cuda::Device::new(n));
}
#[cfg(feature = "nvidia")]
Expand Down
8 changes: 8 additions & 0 deletions xtask/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@ fn main() {

#[inline]
fn block_on(f: impl Future) {
#[cfg(feature = "nvidia")]
{
transformer_nv::cuda::init();
}
let runtime = tokio::runtime::Runtime::new().unwrap();
runtime.block_on(f);
runtime.shutdown_background();
#[cfg(feature = "nvidia")]
{
transformer_nv::synchronize();
}
}

#[derive(Parser)]
Expand Down

0 comments on commit 00778a2

Please sign in to comment.