perf(nvidia): set memory pool threshold, free temporary storage asynchronously
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed May 9, 2024
1 parent 52ccf7e commit 457776f
Showing 4 changed files with 24 additions and 14 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -30,7 +30,7 @@ memmap2 = "0.9"
 rayon = "1.10"
 tokio = { version = "1.37", features = ["rt-multi-thread", "sync"] }

-cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c3e907f" }
-cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c3e907f" }
-nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c3e907f" }
-search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "c3e907f" }
+cuda = { git = "https://github.com/YdrMaster/cuda-driver", rev = "13c037e" }
+cublas = { git = "https://github.com/YdrMaster/cuda-driver", rev = "13c037e" }
+nccl = { git = "https://github.com/YdrMaster/cuda-driver", rev = "13c037e" }
+search-cuda-tools = { git = "https://github.com/YdrMaster/cuda-driver", rev = "13c037e" }
17 changes: 12 additions & 5 deletions models/llama/nvidia-distributed/src/lib.rs
@@ -55,7 +55,13 @@ impl Model for Transformer {
         info!("load host: {:?}", time.elapsed());

         let block_size = meta.iter().map(|dev| dev.max_block_dims().0).min().unwrap();
-        let contexts = meta.iter().map(Device::retain_primary).collect::<Vec<_>>();
+        let contexts = meta
+            .iter()
+            .map(|dev| {
+                dev.set_mempool_threshold(u64::MAX);
+                dev.retain_primary()
+            })
+            .collect::<Vec<_>>();
         let kernels =
             NvidiaKernelsPtx::new(host.config.d as _, host.config.max_seq_len as _, block_size);
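The heart of the change is set_mempool_threshold(u64::MAX): it raises the release threshold of each device's default memory pool, so blocks freed back to the pool stay cached on the GPU instead of being trimmed back to the OS at the next synchronization point. A minimal sketch of what such a wrapper plausibly does, written against the raw CUDA driver API; the extern declarations mirror cuda.h, while the wrapper body is an assumption about (not a copy of) the cuda-driver crate:

use std::os::raw::{c_int, c_void};
use std::ptr::null_mut;

type CUresult = c_int;
type CUdevice = c_int;
type CUmemoryPool = *mut c_void;

// Attribute id from the CUmemPool_attribute enum in cuda.h.
const CU_MEMPOOL_ATTR_RELEASE_THRESHOLD: c_int = 4;

extern "C" {
    fn cuDeviceGetDefaultMemPool(pool: *mut CUmemoryPool, dev: CUdevice) -> CUresult;
    fn cuMemPoolSetAttribute(pool: CUmemoryPool, attr: c_int, value: *mut c_void) -> CUresult;
}

// With the threshold at u64::MAX the pool never trims, so later
// stream-ordered allocations are served from cached blocks.
unsafe fn set_mempool_threshold(dev: CUdevice, mut threshold: u64) {
    let mut pool: CUmemoryPool = null_mut();
    assert_eq!(0, cuDeviceGetDefaultMemPool(&mut pool, dev));
    assert_eq!(
        0,
        cuMemPoolSetAttribute(
            pool,
            CU_MEMPOOL_ATTR_RELEASE_THRESHOLD,
            &mut threshold as *mut u64 as *mut c_void,
        )
    );
}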

@@ -457,10 +463,11 @@ impl CausalLM for Transformer {
         // kill
         for (i, context) in contexts.iter().enumerate() {
             context.apply(|ctx| unsafe {
-                ctx.kill(&mut state_buf.physical_mut()[i]);
-                ctx.kill(&mut q_buf[i]);
-                ctx.kill(&mut att_buf[i]);
-                ctx.kill(&mut pos.physical_mut()[i]);
+                let stream = self.streams[i].sprout(ctx);
+                state_buf.physical_mut()[i].kill_on(&stream);
+                pos.physical_mut()[i].kill_on(&stream);
+                q_buf[i].kill_on(&stream);
+                att_buf[i].kill_on(&stream);
             });
         }
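The switch above from ctx.kill(...) to kill_on(&stream) makes the temporary-buffer frees asynchronous: each deallocation is enqueued on that device's stream, behind the kernels that may still read the buffer, and the host thread moves on immediately. Assuming kill_on lowers to a stream-ordered free, the driver-API shape is roughly the following (kill_on here is an illustrative free function, not the crate's actual method):

use std::os::raw::c_void;

type CUdeviceptr = u64;
type CUstream = *mut c_void;

extern "C" {
    fn cuMemFreeAsync(dptr: CUdeviceptr, stream: CUstream) -> i32;
}

// The free is ordered after all work already submitted to `stream`, so
// in-flight kernels finish with the buffer before the pool reclaims it,
// and the host never blocks on a device synchronization.
unsafe fn kill_on(dptr: CUdeviceptr, stream: CUstream) {
    assert_eq!(0, cuMemFreeAsync(dptr, stream));
}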

5 changes: 4 additions & 1 deletion models/llama/nvidia/src/lib.rs
@@ -73,6 +73,7 @@ impl Model for Transformer {
         info!("load host: {:?}", time.elapsed());
         let load_layers = (load_layers as udim).min(host.config.nlayers);

+        device.set_mempool_threshold(u64::MAX);
         let context = Arc::new(device.retain_primary());
         context.apply(|ctx| {
             let transfer = ctx.stream();
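The single-GPU path applies the same threshold once, before retaining the primary context. The payoff is in the decode loop, where temporaries like q_buf and att_buf are allocated and freed every token; with the pool never trimming, each step reuses the blocks the previous step released. A hedged illustration (decode_steps and the 16 MiB size are hypothetical; the driver entry points are real):

use std::os::raw::c_void;

type CUdeviceptr = u64;
type CUstream = *mut c_void;

extern "C" {
    fn cuMemAllocAsync(dptr: *mut CUdeviceptr, bytesize: usize, stream: CUstream) -> i32;
    fn cuMemFreeAsync(dptr: CUdeviceptr, stream: CUstream) -> i32;
}

unsafe fn decode_steps(stream: CUstream, steps: usize) {
    for _ in 0..steps {
        let mut q_buf: CUdeviceptr = 0;
        // Served from blocks cached by the previous iteration's free,
        // because the raised release threshold kept them on the device.
        cuMemAllocAsync(&mut q_buf, 16 << 20, stream);
        // ... enqueue the kernels that read and write q_buf ...
        cuMemFreeAsync(q_buf, stream); // back to the pool, not to the OS
    }
}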
@@ -339,9 +340,11 @@ impl Drop for Cache {
     fn drop(&mut self) {
         self.context.apply(|ctx| unsafe {
             if let Some(mut stream) = self.stream.take() {
+                self.mem.kill_on(&ctx.sprout(&stream));
                 ctx.kill(&mut stream);
+            } else {
+                ctx.kill(&mut self.mem);
             }
-            ctx.kill(&mut self.mem);
         });
     }
 }
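The Drop change fixes an ordering subtlety: when the cache owns a stream, its memory is now freed on that stream before the stream itself is destroyed, so the free lands behind any still-pending work; only when no stream exists does the synchronous ctx.kill remain. In raw driver terms the shape is roughly this (drop_cache is an illustrative stand-in for the Drop impl above):

use std::os::raw::c_void;

type CUdeviceptr = u64;
type CUstream = *mut c_void;

extern "C" {
    fn cuMemFree(dptr: CUdeviceptr) -> i32;
    fn cuMemFreeAsync(dptr: CUdeviceptr, stream: CUstream) -> i32;
    fn cuStreamDestroy(stream: CUstream) -> i32;
}

unsafe fn drop_cache(mem: CUdeviceptr, stream: Option<CUstream>) {
    if let Some(stream) = stream {
        // Enqueue the free behind whatever still runs on the cache's
        // stream, then release the stream; the free cannot be outrun.
        cuMemFreeAsync(mem, stream);
        cuStreamDestroy(stream);
    } else {
        // No stream to order against: fall back to a blocking free.
        cuMemFree(mem);
    }
}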
