diff --git a/nvidia/common/src/lib.rs b/nvidia/common/src/lib.rs index 7bb4171e..14460b16 100644 --- a/nvidia/common/src/lib.rs +++ b/nvidia/common/src/lib.rs @@ -10,11 +10,9 @@ mod mat_mul; mod reform; mod rms_norm; mod rotary_embedding; -mod storage; mod swiglu; pub use common::utok; -pub use storage::{tensor, Cache}; pub use tensor::{slice, udim, DataType, LocalSplitable, Tensor}; use cublas::{Cublas, CublasSpore}; diff --git a/nvidia/common/src/storage.rs b/nvidia/common/src/storage.rs deleted file mode 100644 index fb79b587..00000000 --- a/nvidia/common/src/storage.rs +++ /dev/null @@ -1,24 +0,0 @@ -use cuda::{Context, ContextSpore, DevMem, DevMemSpore, Stream}; -use std::sync::Arc; -use tensor::{udim, DataType, LocalSplitable, Tensor}; - -pub struct Cache { - pub context: Arc, - pub mem: DevMemSpore, -} - -impl Drop for Cache { - #[inline] - fn drop(&mut self) { - self.context.apply(|ctx| unsafe { self.mem.kill(ctx) }); - } -} - -#[inline] -pub fn tensor<'ctx>( - dt: DataType, - shape: &[udim], - stream: &Stream<'ctx>, -) -> Tensor>> { - Tensor::alloc(dt, shape, |l| stream.malloc::(l).into()) -} diff --git a/nvidia/distributed/src/gather.rs b/nvidia/distributed/src/gather.rs index cd9815a5..85ab8bc6 100644 --- a/nvidia/distributed/src/gather.rs +++ b/nvidia/distributed/src/gather.rs @@ -1,18 +1,17 @@ use common_nv::{ - cuda::{ContextSpore, DevByte, StreamSpore}, + cuda::{ContextSpore, DevMemSpore, StreamSpore}, utok, Tensor, }; -use std::ops::{Deref, DerefMut}; +use std::ops::Deref; -pub fn gather( - x: &mut [Tensor], - table: &Tensor, +pub fn gather( + x: &mut [Tensor], + table: &Tensor, tokens: I, comms: &nccl::CommunicatorGroup, streams: &[StreamSpore], ) where - T: DerefMut, - U: Deref, + T: Deref, I: IntoIterator, { assert!(!x.is_empty()); @@ -41,12 +40,12 @@ pub fn gather( for (i, comm) in comms.call().iter().enumerate() { comm.device().retain_primary().apply(|ctx| { let stream = unsafe { streams[i].sprout(ctx) }; - let dst = &mut **x[i].physical_mut(); + let mut dst = unsafe { x[i].physical_mut().sprout(ctx) }; for _ in 0..distributed { let Some((i, t)) = iter.next() else { break }; stream.memcpy_h2d(&mut dst[d * i..][..d], &table[d * t..][..d]); } - comm.all_gather(dst, None, &stream); + comm.all_gather(&mut dst, None, &stream); }); } } diff --git a/nvidia/distributed/src/lib.rs b/nvidia/distributed/src/lib.rs index fcaa8aaa..8c5b92d7 100644 --- a/nvidia/distributed/src/lib.rs +++ b/nvidia/distributed/src/lib.rs @@ -1,5 +1,4 @@ #![cfg(detected_nccl)] -#![allow(unused)] mod gather; mod parameters; @@ -10,8 +9,8 @@ extern crate log; pub use common_nv::cuda; use common_nv::{ - cuda::{AsRaw, ContextResource, ContextSpore, CudaDataType, DevMem, Device, StreamSpore}, - tensor, udim, utok, DataType, LocalSplitable, Tensor, + cuda::{AsRaw, ContextResource, ContextSpore, DevMemSpore, Device, StreamSpore}, + udim, utok, Tensor, }; use nccl::CommunicatorGroup; use parameters::ParameterMatrix; @@ -49,8 +48,24 @@ impl transformer::Transformer for Transformer { .map(|ctx| ctx.apply(|c| c.stream().sporulate())) .collect::>(); - for (context, mut stream) in zip(contexts, streams) { - context.apply(|ctx| unsafe { stream.kill(ctx) }); + let x0 = self.token_embed(&requests, &streams); + let x1 = zip(zip(&contexts, &streams), &x0) + .map(|((context, stream), x)| { + context.apply(|ctx| { + let stream = unsafe { stream.sprout(ctx) }; + Tensor::alloc(x.data_type(), x.shape(), |len| { + stream.malloc::(len).sporulate() + }) + }) + }) + .collect::>(); + + for ((context, mut stream), (mut x0, mut x1)) in zip(zip(contexts, streams), zip(x0, x1)) { + context.apply(|ctx| unsafe { + stream.kill(ctx); + x0.physical_mut().kill(ctx); + x1.physical_mut().kill(ctx); + }); } todo!() } @@ -85,11 +100,11 @@ impl Transformer { } } - fn token_embed<'ctx, Id>( + fn token_embed( &self, requests: &[Request], streams: &[StreamSpore], - ) -> Vec>>> { + ) -> Vec> { let dt = self.host.data_type(); let nt = requests.iter().map(Request::seq_len).sum::(); let d = self.host.hidden_size() as udim; @@ -98,7 +113,13 @@ impl Transformer { .comms .contexts() .zip(streams) - .map(|(context, stream)| context.apply(|ctx| tensor(dt, &[nt, d], todo!()))) + .map(|(context, stream)| { + context.apply(|ctx| { + Tensor::alloc(dt, &[nt, d], |len| { + unsafe { stream.sprout(ctx) }.malloc::(len).sporulate() + }) + }) + }) .collect::>(); let tokens = requests.iter().flat_map(Request::tokens).copied(); @@ -122,12 +143,28 @@ impl Drop for Transformer { } } -fn convert(dt: DataType) -> CudaDataType { - match dt { - DataType::F16 => CudaDataType::f16, - DataType::BF16 => CudaDataType::bf16, - DataType::F32 => CudaDataType::f32, - DataType::F64 => CudaDataType::f64, - _ => unreachable!(), +#[test] +fn test() { + use common_nv::cuda::{self, Device}; + use log::LevelFilter::Trace; + use simple_logger::SimpleLogger; + use transformer::Transformer as _; + + const N: usize = 1; + + cuda::init(); + if Device::count() < N { + return; } + + SimpleLogger::new().with_level(Trace).init().unwrap(); + + let time = Instant::now(); + let transformer = Transformer::new( + "../../../TinyLlama-1.1B-Chat-v1.0_F16", + &[Device::fetch().unwrap()], + ); + info!("load {:?}", time.elapsed()); + + transformer.decode(vec![Request::new(0, &[1, 2, 3], &mut [], 0, true)]); } diff --git a/nvidia/distributed/src/parameters.rs b/nvidia/distributed/src/parameters.rs index 448ec0df..19b92f41 100644 --- a/nvidia/distributed/src/parameters.rs +++ b/nvidia/distributed/src/parameters.rs @@ -1,4 +1,6 @@ -use common_nv::{ +#![allow(unused)] + +use common_nv::{ cuda::{Context, ContextGuard, ContextResource, ContextSpore, DevByte, DevMem, DevMemSpore}, udim, Tensor, }; @@ -154,7 +156,7 @@ fn test_load() { SimpleLogger::new().with_level(Trace).init().unwrap(); let time = Instant::now(); - let safetensors = Memory::load_safetensors_from_dir("../../../TinyLlama-1.1B-Chat-v1.0"); + let safetensors = Memory::load_safetensors_from_dir("../../../TinyLlama-1.1B-Chat-v1.0_F16"); info!("mmap {:?}", time.elapsed()); let model = match safetensors { diff --git a/nvidia/transformer/src/lib.rs b/nvidia/transformer/src/lib.rs index d50642cc..9730e7b4 100644 --- a/nvidia/transformer/src/lib.rs +++ b/nvidia/transformer/src/lib.rs @@ -9,8 +9,8 @@ pub use common_nv::cuda; use ::half::f16; use common_nv::{ - cuda::{bindings::CUdeviceptr, memcpy_d2h, DevMem}, - slice, tensor, udim, utok, Cache, DataType, LocalSplitable, NvidiaKernels, Tensor, + cuda::{bindings::CUdeviceptr, memcpy_d2h, DevMem, DevMemSpore}, + slice, udim, utok, DataType, LocalSplitable, NvidiaKernels, Tensor, }; use cuda::{AsRaw, Context, ContextResource, ContextSpore, Device, Stream, StreamSpore}; use parameters::{LayerParameter, LayersParameters, ModelParameters}; @@ -61,15 +61,15 @@ impl transformer::Transformer for Transformer { let compute = ctx.stream(); // 生成词嵌入并预分配空间 let mut x0 = self.token_embed(&requests, &compute); - let mut x1 = tensor(x0.data_type(), x0.shape(), &transfer); + let mut x1 = + Tensor::alloc(x0.data_type(), x0.shape(), |len| transfer.malloc::(len)); let mut buf = LayerBuffer::alloc(&self.host, &requests, |len| { transfer.malloc::(len).into() }); // 生成位置张量 let nt = x0.shape()[0]; // `nt` for number of tokens let pos_ = pos(&requests, nt); - let mut pos = tensor(DataType::U32, &[nt], &transfer); - transfer.memcpy_h2d(pos.physical_mut(), &pos_); + let pos = Tensor::new(DataType::U32, &[nt], transfer.from_host(&pos_)); // 推理 compute.wait_for(&transfer.record()); { @@ -170,13 +170,13 @@ impl Transformer { &self, requests: &[Request], compute: &Stream<'ctx>, - ) -> Tensor> { + ) -> Tensor> { let dt = self.host.data_type(); let nt = requests.iter().map(Request::seq_len).sum::(); let d = self.host.hidden_size() as udim; let kernels = self.kernels.on(compute); - let mut x0 = tensor(dt, &[nt, d], compute); + let mut x0 = Tensor::alloc(dt, &[nt, d], |len| compute.malloc::(len)); let tokens = requests.iter().flat_map(Request::tokens).copied(); kernels.gather(&mut x0, &self.host.embed_tokens(), tokens); // compute.synchronize(); @@ -188,10 +188,10 @@ impl Transformer { fn before_att<'ctx>( &self, params: &LayerParameter, - x0: &Tensor, - x1: &mut Tensor, + x0: &Tensor, + x1: &mut Tensor, qkv: &mut Tensor>, - pos: &Tensor, + pos: &Tensor, compute: &Stream, ) -> ( Tensor>, @@ -240,7 +240,7 @@ impl Transformer { q: Tensor, k: Tensor, v: Tensor, - o: &mut Tensor, + o: &mut Tensor, q_buf: &mut Local, att_buf: &mut Local, compute: &Stream, @@ -324,8 +324,8 @@ impl Transformer { fn after_att( &self, params: &LayerParameter, - x0: &mut Tensor, - x1: &mut Tensor, + x0: &mut Tensor, + x1: &mut Tensor, gate_up: &mut Tensor, compute: &Stream, ) { @@ -366,9 +366,9 @@ impl Transformer { fn move_decode<'ctx, Id>( &self, requests: &[Request], - x0: Tensor>, + x0: Tensor>, compute: &Stream, - ) -> Tensor> { + ) -> Tensor> { let buf = x0.physical().as_ptr() as CUdeviceptr; let len = self.host.hidden_size() * self.host.data_type().size(); @@ -395,7 +395,7 @@ impl Transformer { x0.slice(&[slice![from begin, until dst + 1], slice![all]]) } - fn logits(&self, mut x: Tensor, compute: &Stream) -> Tensor { + fn logits(&self, mut x: Tensor, compute: &Stream) -> Tensor { let dt = self.host.data_type(); let voc = self.host.vocab_size() as udim; @@ -430,6 +430,18 @@ impl Transformer { } } +pub struct Cache { + pub context: Arc, + pub mem: DevMemSpore, +} + +impl Drop for Cache { + #[inline] + fn drop(&mut self) { + self.context.apply(|ctx| unsafe { self.mem.kill(ctx) }); + } +} + #[inline] fn tensor_cache( dt: DataType, diff --git a/tensor/src/split.rs b/tensor/src/split.rs index d845556e..5dcfba1e 100644 --- a/tensor/src/split.rs +++ b/tensor/src/split.rs @@ -21,6 +21,13 @@ impl Splitable for T { #[repr(transparent)] pub struct LocalSplitable(Rc); +impl AsRef for LocalSplitable { + #[inline] + fn as_ref(&self) -> &T { + &self.0 + } +} + impl From for LocalSplitable { #[inline] fn from(t: T) -> Self {