diff --git a/Cargo.lock b/Cargo.lock index 4caab576..3f00cd2e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1514,9 +1514,9 @@ dependencies = [ [[package]] name = "safetensors" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d980e6bfb34436fb0a81e42bc41af43f11805bbbca443e7f68e9faaabe669ed" +checksum = "8ced76b22c7fba1162f11a5a75d9d8405264b467a07ae0c9c29be119b9297db9" dependencies = [ "serde", "serde_json", @@ -2076,9 +2076,9 @@ dependencies = [ [[package]] name = "wide" -version = "0.7.15" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89beec544f246e679fc25490e3f8e08003bc4bf612068f325120dad4cea02c1c" +checksum = "81a1851a719f11d1d2fea40e15c72f6c00de8c142d7ac47c1441cc7e4d0d5bc6" dependencies = [ "bytemuck", "safe_arch", diff --git a/Cargo.toml b/Cargo.toml index a2418a38..f089ac90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ serde = "1.0" log = "0.4" tokio = { version = "1.37", features = ["rt-multi-thread", "sync"] } -cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6bec5f3" } -cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6bec5f3" } -nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6bec5f3" } -search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6bec5f3" } +cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d45ab7f" } +cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d45ab7f" } +nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d45ab7f" } +search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d45ab7f" } diff --git a/nvidia/distributed/src/lib.rs b/nvidia/distributed/src/lib.rs index 511783fa..5eac725b 100644 --- a/nvidia/distributed/src/lib.rs +++ b/nvidia/distributed/src/lib.rs @@ -84,27 +84,22 @@ impl transformer::Transformer for Transformer { // token embedding let mut x0 = Tensor::alloc(dt, &[nt, d], |len| malloc_all(&contexts, &streams, len)); - { - let d = d as usize * dt.size(); - let table = self.host.embed_tokens(); - let table = table.as_slice(); - - let iter = requests - .iter() - .flat_map(Request::tokens) - .copied() - .map(|t| t as usize) - .enumerate(); - - for (i, context) in contexts.iter().enumerate() { - context.apply(|ctx| { - let stream = unsafe { ctx.sprout(&streams[i]) }; - let mut dst = unsafe { ctx.sprout(&x0.physical_mut()[i]) }; - for (i, t) in iter.clone() { - stream.memcpy_h2d(&mut dst[i * d..][..d], &table[t * d..][..d]); - } - }); - } + contexts[0].apply(|ctx| { + let stream = unsafe { ctx.sprout(&streams[0]) }; + let kernels = self.kernels[0].on(&stream); + let mut x = unsafe { x0.as_mut().map_physical(|u| ctx.sprout(&u[0])) }; + kernels.gather( + &mut x, + &self.host.embed_tokens(), + requests.iter().flat_map(Request::tokens).copied(), + ); + }); + for (i, comm) in self.comms.call().into_iter().enumerate() { + contexts[i].apply(|ctx| { + let stream = unsafe { ctx.sprout(&streams[i]) }; + let mut dst = unsafe { ctx.sprout(&x0.physical_mut()[i]) }; + comm.broadcast(&mut dst, None, 0, &stream); + }); } let mut x1 = Tensor::alloc(dt, &[nt, d], |len| malloc_all(&contexts, &streams, len)); let LayerBuffer {