feat(distributed): use ncclBroadcast to broadcast the token embedding so gather can be reused
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 16, 2024
1 parent 3a7aaaf commit 86ed0fa
Showing 3 changed files with 28 additions and 29 deletions.
12 changes: 8 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -21,7 +21,7 @@ serde = "1.0"
log = "0.4"
tokio = { version = "1.37", features = ["rt-multi-thread", "sync"] }

cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6bec5f3" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6bec5f3" }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6bec5f3" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "6bec5f3" }
cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d45ab7f" }
cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d45ab7f" }
nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d45ab7f" }
search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "d45ab7f" }
37 changes: 16 additions & 21 deletions nvidia/distributed/src/lib.rs
@@ -84,27 +84,22 @@ impl transformer::Transformer for Transformer {

         // token embedding
         let mut x0 = Tensor::alloc(dt, &[nt, d], |len| malloc_all(&contexts, &streams, len));
-        {
-            let d = d as usize * dt.size();
-            let table = self.host.embed_tokens();
-            let table = table.as_slice();
-
-            let iter = requests
-                .iter()
-                .flat_map(Request::tokens)
-                .copied()
-                .map(|t| t as usize)
-                .enumerate();
-
-            for (i, context) in contexts.iter().enumerate() {
-                context.apply(|ctx| {
-                    let stream = unsafe { ctx.sprout(&streams[i]) };
-                    let mut dst = unsafe { ctx.sprout(&x0.physical_mut()[i]) };
-                    for (i, t) in iter.clone() {
-                        stream.memcpy_h2d(&mut dst[i * d..][..d], &table[t * d..][..d]);
-                    }
-                });
-            }
+        contexts[0].apply(|ctx| {
+            let stream = unsafe { ctx.sprout(&streams[0]) };
+            let kernels = self.kernels[0].on(&stream);
+            let mut x = unsafe { x0.as_mut().map_physical(|u| ctx.sprout(&u[0])) };
+            kernels.gather(
+                &mut x,
+                &self.host.embed_tokens(),
+                requests.iter().flat_map(Request::tokens).copied(),
+            );
+        });
+        for (i, comm) in self.comms.call().into_iter().enumerate() {
+            contexts[i].apply(|ctx| {
+                let stream = unsafe { ctx.sprout(&streams[i]) };
+                let mut dst = unsafe { ctx.sprout(&x0.physical_mut()[i]) };
+                comm.broadcast(&mut dst, None, 0, &stream);
+            });
+        }
         let mut x1 = Tensor::alloc(dt, &[nt, d], |len| malloc_all(&contexts, &streams, len));
         let LayerBuffer {
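What the lib.rs hunk changes, in plain terms: previously every device copied its own embedding rows from the host table with per-token memcpy_h2d calls; now device 0 gathers the rows once, reusing the single-GPU gather kernel, and the result is broadcast device-to-device with comm.broadcast(&mut dst, None, 0, &stream). Reading that call as a thin wrapper over ncclBroadcast is an assumption: dst would be the receive buffer, None would mean the root sends in place from that same buffer, and 0 is the root rank. The self-contained sketch below is a back-of-envelope comparison of host-to-device traffic under made-up sizes (nt, d, element width, and GPU count are illustrative, not from the commit), assuming the gather on device 0 performs the equivalent host-to-device transfer once.

fn main() {
    // Hypothetical shapes: 64 prompt tokens, hidden size 4096, f16 weights, 4 GPUs.
    let nt = 64usize;     // tokens in the batch
    let d = 4096usize;    // embedding dimension
    let elem = 2usize;    // bytes per element (f16)
    let n_gpus = 4usize;

    let row = d * elem;   // bytes per embedding row

    // Old scheme: each GPU issues one host-to-device copy per token row.
    let old_h2d_copies = n_gpus * nt;
    let old_h2d_bytes = n_gpus * nt * row;

    // New scheme: only device 0 pulls rows from the host table; the other
    // devices receive the same nt*d block through the NCCL broadcast.
    let new_h2d_bytes = nt * row;
    let broadcast_bytes = nt * row; // payload replicated device-to-device

    println!("old: {old_h2d_copies} H2D copies, {old_h2d_bytes} bytes from host");
    println!("new: {new_h2d_bytes} bytes from host once, then {broadcast_bytes} bytes broadcast");
}

The gain is as much about call count and code reuse as about bytes: the per-token memcpy loop previously ran once per device, while the new path funnels everything through the existing gather kernel and leaves replication to NCCL.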
