From 00778a203e63e0bc902a9cf1c9556df266e7c933 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Thu, 25 Apr 2024 16:31:03 +0800 Subject: [PATCH] =?UTF-8?q?fix(xtask):=20nvidia=20=E5=90=AF=E7=94=A8?= =?UTF-8?q?=E6=97=B6=E9=80=80=E5=87=BA=E9=9C=80=E8=A6=81=E5=90=8C=E6=AD=A5?= =?UTF-8?q?=E4=BB=A5=E5=85=8D=E9=87=8A=E6=94=BE=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- nvidia/common/src/rms_norm.rs | 2 +- nvidia/transformer/src/lib.rs | 2 +- xtask/src/chat.rs | 1 - xtask/src/main.rs | 8 ++++++++ 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/nvidia/common/src/rms_norm.rs b/nvidia/common/src/rms_norm.rs index 31be39bc..5a405b44 100644 --- a/nvidia/common/src/rms_norm.rs +++ b/nvidia/common/src/rms_norm.rs @@ -61,7 +61,7 @@ extern "C" __global__ void {folding}( let (ptx, log) = Ptx::compile(code); if !log.is_empty() { - println!("{log}"); + warn!("{log}"); } Self { ptx: ptx.unwrap(), diff --git a/nvidia/transformer/src/lib.rs b/nvidia/transformer/src/lib.rs index d13eadf2..8b734107 100644 --- a/nvidia/transformer/src/lib.rs +++ b/nvidia/transformer/src/lib.rs @@ -20,7 +20,7 @@ use std::{ }; use transformer::{Kernels, Llama2, Memory}; -pub use common_nv::cuda; +pub use common_nv::{cuda, synchronize}; pub struct Transformer { host: Memory, diff --git a/xtask/src/chat.rs b/xtask/src/chat.rs index 93491919..c85c5ddf 100644 --- a/xtask/src/chat.rs +++ b/xtask/src/chat.rs @@ -38,7 +38,6 @@ impl InferenceArgs { #[cfg(feature = "nvidia")] &[n] => { use transformer_nv::{cuda, Transformer as M}; - cuda::init(); chat!(M; cuda::Device::new(n)); } #[cfg(feature = "nvidia")] diff --git a/xtask/src/main.rs b/xtask/src/main.rs index c72bc83a..54632918 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -19,9 +19,17 @@ fn main() { #[inline] fn block_on(f: impl Future) { + #[cfg(feature = "nvidia")] + { + transformer_nv::cuda::init(); + } let runtime = tokio::runtime::Runtime::new().unwrap(); runtime.block_on(f); runtime.shutdown_background(); + #[cfg(feature = "nvidia")] + { + transformer_nv::synchronize(); + } } #[derive(Parser)]