diff --git a/nvidia/common/src/rms_norm.rs b/nvidia/common/src/rms_norm.rs index 31be39bc..5a405b44 100644 --- a/nvidia/common/src/rms_norm.rs +++ b/nvidia/common/src/rms_norm.rs @@ -61,7 +61,7 @@ extern "C" __global__ void {folding}( let (ptx, log) = Ptx::compile(code); if !log.is_empty() { - println!("{log}"); + warn!("{log}"); } Self { ptx: ptx.unwrap(), diff --git a/nvidia/transformer/src/lib.rs b/nvidia/transformer/src/lib.rs index d13eadf2..8b734107 100644 --- a/nvidia/transformer/src/lib.rs +++ b/nvidia/transformer/src/lib.rs @@ -20,7 +20,7 @@ use std::{ }; use transformer::{Kernels, Llama2, Memory}; -pub use common_nv::cuda; +pub use common_nv::{cuda, synchronize}; pub struct Transformer { host: Memory, diff --git a/xtask/src/chat.rs b/xtask/src/chat.rs index 93491919..c85c5ddf 100644 --- a/xtask/src/chat.rs +++ b/xtask/src/chat.rs @@ -38,7 +38,6 @@ impl InferenceArgs { #[cfg(feature = "nvidia")] &[n] => { use transformer_nv::{cuda, Transformer as M}; - cuda::init(); chat!(M; cuda::Device::new(n)); } #[cfg(feature = "nvidia")] diff --git a/xtask/src/main.rs b/xtask/src/main.rs index c72bc83a..54632918 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -19,9 +19,17 @@ fn main() { #[inline] fn block_on(f: impl Future) { + #[cfg(feature = "nvidia")] + { + transformer_nv::cuda::init(); + } let runtime = tokio::runtime::Runtime::new().unwrap(); runtime.block_on(f); runtime.shutdown_background(); + #[cfg(feature = "nvidia")] + { + transformer_nv::synchronize(); + } } #[derive(Parser)]