test(distributed): restore the distributed-partitioning tests
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 30, 2024
1 parent 165b3cd commit 0171f99
Showing 3 changed files with 76 additions and 54 deletions.
45 changes: 22 additions & 23 deletions nvidia/distributed/src/distribute.rs
@@ -204,26 +204,25 @@ impl DistributeScheme {
}
}

// #[test]
// fn test() {
//     use super::Memory;
//     use std::time::Instant;

//     let Some(model_dir) = common::test_model::find() else {
//         return;
//     };
//     println!("model_dir: {}", model_dir.display());

//     let time = Instant::now();
//     let model = Memory::load_safetensors(model_dir).unwrap();
//     println!("mmap {:?}", time.elapsed());

//     let distributer = Distributer::new(&model, 4, 512);
//     let time = Instant::now();
//     for layer in 0..model.num_hidden_layers() {
//         for i in 0..4 {
//             let _ = distributer.distribute(layer, i);
//         }
//     }
//     println!("distribute {:?}", time.elapsed());
// }
#[test]
fn test() {
    use std::time::Instant;

    let Some(model_dir) = common::test_model::find() else {
        return;
    };
    println!("model_dir: {}", model_dir.display());

    let time = Instant::now();
    let model = llama::Storage::load_safetensors(model_dir).unwrap();
    println!("mmap {:?}", time.elapsed());

    let distributer = Distributer::new(&model, 4, 512);
    let time = Instant::now();
    for layer in 0..model.config.nlayers as usize {
        for i in 0..4 {
            let _ = distributer.distribute(layer, i);
        }
    }
    println!("distribute {:?}", time.elapsed());
}
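The restored test keeps the gated shape used across these crates: return early (and pass) when no test model is found, and time each phase with `std::time::Instant`. A minimal self-contained sketch of that shape, with a hypothetical `find_model()` standing in for `common::test_model::find()`:

```rust
use std::path::PathBuf;
use std::time::Instant;

/// Hypothetical stand-in for `common::test_model::find()`:
/// read the model directory from an env var; `None` means "skip".
fn find_model() -> Option<PathBuf> {
    std::env::var_os("TEST_MODEL_DIR").map(PathBuf::from)
}

#[test]
fn gated_timing_test() {
    // No model configured: return early, so the test still passes.
    let Some(model_dir) = find_model() else {
        return;
    };
    println!("model_dir: {}", model_dir.display());

    // Time one phase at a time, as the restored test times
    // mmap and distribution separately.
    let time = Instant::now();
    // ... load the model and call `distribute` here ...
    println!("elapsed {:?}", time.elapsed());
}
```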
58 changes: 29 additions & 29 deletions nvidia/distributed/src/parameters.rs
@@ -133,32 +133,32 @@ impl Layer<'_> {
}
}

// #[test]
// fn test_load() {
//     use common_nv::cuda::{self, Device};
//     use log::LevelFilter::Trace;
//     use simple_logger::SimpleLogger;

//     let Some(model_dir) = common_nv::test_model::find() else {
//         return;
//     };
//     println!("model_dir: {}", model_dir.display());

//     const N: usize = 1;

//     cuda::init();
//     if Device::count() < N {
//         return;
//     }

//     SimpleLogger::new().with_level(Trace).init().unwrap();

//     let time = Instant::now();
//     let model = Memory::load_safetensors(model_dir).unwrap();
//     info!("mmap {:?}", time.elapsed());

//     let contexts = (0..N as _)
//         .map(|i| Device::new(i).retain_primary())
//         .collect::<Vec<_>>();
//     unsafe { ParameterMatrix::load(&model, &contexts).kill(&contexts) };
// }
#[test]
fn test_load() {
    use common_nv::cuda::{self, Device};
    use log::LevelFilter::Trace;
    use simple_logger::SimpleLogger;

    let Some(model_dir) = common::test_model::find() else {
        return;
    };
    println!("model_dir: {}", model_dir.display());

    const N: usize = 1;

    cuda::init();
    if Device::count() < N {
        return;
    }

    SimpleLogger::new().with_level(Trace).init().unwrap();

    let time = Instant::now();
    let model = llama::Storage::load_safetensors(model_dir).unwrap();
    info!("mmap {:?}", time.elapsed());

    let contexts = (0..N as _)
        .map(|i| Device::new(i).retain_primary())
        .collect::<Vec<_>>();
    unsafe { ParameterMatrix::load(&model, &contexts).kill(&contexts) };
}
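Besides the model-directory gate, this test adds a device gate: initialize CUDA, then return early when fewer than `N` GPUs are visible, so machines without GPUs still pass. A condensed sketch of that guard, reusing only the `common_nv::cuda` calls visible in the diff (their exact signatures are assumed from the usage above):

```rust
use common_nv::cuda::{self, Device};

/// Returns true when at least `n` CUDA devices are present.
/// `cuda::init()` and `Device::count()` are used exactly as in
/// the test above; their signatures are assumed, not documented here.
fn gpus_available(n: usize) -> bool {
    cuda::init();
    Device::count() >= n
}

#[test]
fn device_gated() {
    if !gpus_available(1) {
        return; // no GPU on this machine: skip by passing
    }
    // ... acquire contexts via `Device::new(i).retain_primary()`
    // and exercise `ParameterMatrix::load` here ...
}
```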
27 changes: 25 additions & 2 deletions nvidia/transformer/src/lib.rs
@@ -32,10 +32,12 @@ pub struct Transformer {
    transfer: StreamSpore,
    compute: StreamSpore,
    kernels: NvidiaKernels,

    embed_tokens: Tensor<HostMemSpore>,
    layers: Vec<LayerStorage<HostMemSpore>>,
    lm_layernorm: Tensor<DevMemSpore>,
    lm_head: Tensor<DevMemSpore>,

    pool: Mutex<VecDeque<(LayerStorage<DevMemSpore>, EventSpore)>>,
}

@@ -289,8 +291,29 @@ impl Drop for Transformer {
    #[inline]
    fn drop(&mut self) {
        self.context.apply(|ctx| unsafe {
            self.transfer.kill(ctx);
            self.compute.kill(ctx);
            ctx.kill(&mut self.transfer);
            ctx.kill(&mut self.compute);
            ctx.kill(self.embed_tokens.physical_mut());
            ctx.kill(self.lm_layernorm.physical_mut());
            ctx.kill(self.lm_head.physical_mut());
            for layer in self.layers.iter_mut() {
                ctx.kill(layer.att_layernorm.physical_mut());
                ctx.kill(layer.att_qkv.physical_mut());
                ctx.kill(layer.att_o.physical_mut());
                ctx.kill(layer.mlp_layernorm.physical_mut());
                ctx.kill(layer.mlp_gate_up.physical_mut());
                ctx.kill(layer.mlp_down.physical_mut());
            }
            let mut pool = self.pool.lock().unwrap();
            while let Some((mut layer, mut event)) = pool.pop_front() {
                ctx.kill(layer.att_layernorm.physical_mut());
                ctx.kill(layer.att_qkv.physical_mut());
                ctx.kill(layer.att_o.physical_mut());
                ctx.kill(layer.mlp_layernorm.physical_mut());
                ctx.kill(layer.mlp_gate_up.physical_mut());
                ctx.kill(layer.mlp_down.physical_mut());
                ctx.kill(&mut event);
            }
            self.kernels.kill(ctx);
        });
    }
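The rewritten `Drop` reverses the kill direction (`self.transfer.kill(ctx)` becomes `ctx.kill(&mut self.transfer)`) and now frees every tensor, including layers still parked in the reuse pool; without the `while let` drain, pooled device allocations would leak on drop. A minimal sketch of that drain-and-free shape, with hypothetical `Ctx` and `Resource` types standing in for the context and spore types:

```rust
use std::collections::VecDeque;
use std::sync::Mutex;

// Hypothetical stand-ins for the CUDA context and the spore types.
struct Ctx;
struct Resource(&'static str);

impl Ctx {
    // Mirrors `ctx.kill(&mut spore)`: the context frees the resource.
    fn kill(&self, r: &mut Resource) {
        println!("freed {}", r.0);
    }
}

struct Owner {
    held: Vec<Resource>,
    pool: Mutex<VecDeque<Resource>>,
}

impl Drop for Owner {
    fn drop(&mut self) {
        let ctx = Ctx;
        // Free what the struct holds directly...
        for r in self.held.iter_mut() {
            ctx.kill(r);
        }
        // ...then drain the pool so parked resources are freed too,
        // mirroring the `while let Some(...) = pool.pop_front()` loop above.
        let mut pool = self.pool.lock().unwrap();
        while let Some(mut r) = pool.pop_front() {
            ctx.kill(&mut r);
        }
    }
}
```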
