Skip to content

Commit

Permalink
test(causal-lm): causal-lm 共用一份测试代码
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 28, 2024
1 parent 62e3cc6 commit 39e6d24
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 160 deletions.
50 changes: 50 additions & 0 deletions causal-lm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,53 @@ pub fn pos<'a, S: 'a>(
}
Tensor::new(tensor::DataType::U32, &[ans.len() as _], ans)
}

/// Shared inference smoke test for any [`CausalLM`] implementation.
///
/// Looks up the test model via `common::test_model::find()` and silently
/// skips (returns early) when none is installed on this machine. Otherwise
/// it loads the model with `meta`, then greedily decodes one token per step
/// starting from `prompt`, printing each sampled token, until the model
/// emits exactly the EOS token.
pub fn test_impl<M>(meta: M::Meta, prompt: &[utok])
where
    M: CausalLM,
    M::Error: std::fmt::Debug,
{
    use std::time::Instant;

    // No model available → nothing to test; skip without failing.
    let Some(model_dir) = common::test_model::find() else {
        return;
    };
    println!("model_dir: {}", model_dir.display());

    // Time the model load for diagnostics.
    let t0 = Instant::now();
    let model = M::load(model_dir, meta).unwrap();
    let t1 = Instant::now();
    println!("load {:?}", t1 - t0);

    let mut cache = model.new_cache();

    // `seq` holds the tokens fed in this step; starts as the whole prompt,
    // then becomes the single sampled token of the previous step.
    let mut seq = prompt.to_vec();
    let mut pos = 0;

    loop {
        // Generation is finished once the previous step produced EOS alone.
        if seq == &[model.eos_token()] {
            break;
        }

        let embedded = CausalLM::token_embed(&model, seq.iter().copied());

        // One query covering positions [pos, pos + seq.len()) against the
        // persistent KV cache.
        let hidden = CausalLM::forward(
            &model,
            [QueryContext {
                cache: Some(&mut cache),
                range: pos..pos + seq.len() as upos,
            }],
            embedded,
        );

        // Decode only the last position of the query into logits.
        let logits = CausalLM::decode(
            &model,
            [DecodingMeta {
                num_query: seq.len(),
                num_decode: 1,
            }],
            hidden,
        );

        // Sample a single next token with default sampling arguments.
        let next = CausalLM::sample(
            &model,
            [SampleMeta {
                num_decode: 1,
                args: SampleArgs::default(),
            }],
            logits,
        );

        println!("{:?}", next);
        // Advance past everything just consumed, then feed the new token.
        pos += seq.len() as upos;
        seq = next;
    }
}
66 changes: 11 additions & 55 deletions nvidia/distributed/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use common_nv::{
cuda::{
memcpy_d2h, AsRaw, Context, ContextResource, ContextSpore, DevMemSpore, Device, StreamSpore,
},
f16, slice, split, udim, upos, utok, DataType, Kernels, LocalSplitable, NvidiaKernels,
NvidiaKernelsPtx, SafeTensorsError, Tensor,
f16, slice, split, udim, upos, utok, DataType, FileLoadError, Kernels, LocalSplitable,
NvidiaKernels, NvidiaKernelsPtx, Tensor,
};
use itertools::izip;
use nccl::CommunicatorGroup;
Expand Down Expand Up @@ -40,7 +40,7 @@ pub struct Transformer {

impl Model for Transformer {
type Meta = Vec<Device>;
type Error = SafeTensorsError;
type Error = FileLoadError;

#[inline]
fn load(model_dir: impl AsRef<Path>, meta: Self::Meta) -> Result<Self, Self::Error> {
Expand Down Expand Up @@ -596,58 +596,14 @@ fn malloc_all(contexts: &[Context], len: usize) -> Vec<DevMemSpore> {

#[test]
fn test_infer() {
use std::time::Instant;

let Some(model_dir) = common_nv::test_model::find() else {
return;
};
println!("model_dir: {}", model_dir.display());

cuda::init();
if cuda::Device::count() < 2 {
return;
}

let t0 = Instant::now();
let model = <Transformer as Model>::load(
model_dir,
[0, 1].map(cuda::Device::new).into_iter().collect(),
)
.unwrap();
let t1 = Instant::now();
println!("load {:?}", t1 - t0);

let mut cache = model.new_cache();

let mut prompt: Vec<utok> = vec![
29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592, 21106,
29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
];
let mut pos = 0;

while prompt != &[model.eos_token()] {
let token_embedded = CausalLM::token_embed(&model, prompt.iter().copied());

let queries = [QueryContext {
cache: Some(&mut cache),
range: pos..pos + prompt.len() as upos,
}];
let hidden_state = CausalLM::forward(&model, queries, token_embedded);

let decoding = [DecodingMeta {
num_query: prompt.len(),
num_decode: 1,
}];
let logits = CausalLM::decode(&model, decoding, hidden_state);

let args = [SampleMeta {
num_decode: 1,
args: causal_lm::SampleArgs::default(),
}];
let tokens = CausalLM::sample(&model, args, logits);

println!("{:?}", tokens);
pos += prompt.len() as upos;
prompt = tokens;
if cuda::Device::count() >= 2 {
causal_lm::test_impl::<Transformer>(
[0, 1].map(cuda::Device::new).into_iter().collect(),
&[
29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592,
21106, 29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
],
);
}
}
14 changes: 2 additions & 12 deletions nvidia/distributed/src/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,9 @@ impl Layer<'_> {

#[test]
fn test_load() {
use common_nv::{
cuda::{self, Device},
SafeTensorsError,
};
use common_nv::cuda::{self, Device};
use log::LevelFilter::Trace;
use simple_logger::SimpleLogger;
use std::io::ErrorKind::NotFound;
use transformer::Memory;

let Some(model_dir) = common_nv::test_model::find() else {
Expand All @@ -165,15 +161,9 @@ fn test_load() {
SimpleLogger::new().with_level(Trace).init().unwrap();

let time = Instant::now();
let safetensors = Memory::load_safetensors(model_dir);
let model = Memory::load_safetensors(model_dir).unwrap();
info!("mmap {:?}", time.elapsed());

let model = match safetensors {
Ok(m) => m,
Err(SafeTensorsError::Io(e)) if e.kind() == NotFound => return,
Err(e) => panic!("{e:?}"),
};

let contexts = (0..N as _)
.map(|i| Device::new(i).retain_primary())
.collect::<Vec<_>>();
Expand Down
56 changes: 8 additions & 48 deletions nvidia/transformer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,54 +417,14 @@ impl Drop for Cache {

#[test]
fn test_infer() {
use std::time::Instant;

let Some(model_dir) = common_nv::test_model::find() else {
return;
};
println!("model_dir: {}", model_dir.display());

cuda::init();
let Some(device) = cuda::Device::fetch() else {
return;
if let Some(device) = cuda::Device::fetch() {
causal_lm::test_impl::<Transformer>(
device,
&[
29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592,
21106, 29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
],
);
};

let t0 = Instant::now();
let model = <Transformer as Model>::load(model_dir, device).unwrap();
let t1 = Instant::now();
println!("load {:?}", t1 - t0);

let mut cache = model.new_cache();

let mut prompt: Vec<utok> = vec![
29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592, 21106,
29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
];
let mut pos = 0;

while prompt != &[model.eos_token()] {
let token_embedded = CausalLM::token_embed(&model, prompt.iter().copied());

let queries = [QueryContext {
cache: Some(&mut cache),
range: pos..pos + prompt.len() as upos,
}];
let hidden_state = CausalLM::forward(&model, queries, token_embedded);

let decoding = [DecodingMeta {
num_query: prompt.len(),
num_decode: 1,
}];
let logits = CausalLM::decode(&model, decoding, hidden_state);

let args = [SampleMeta {
num_decode: 1,
args: causal_lm::SampleArgs::default(),
}];
let tokens = CausalLM::sample(&model, args, logits);

println!("{:?}", tokens);
pos += prompt.len() as upos;
prompt = tokens;
}
}
52 changes: 7 additions & 45 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,49 +366,11 @@ impl CausalLM for Transformer {

#[test]
fn test_infer() {
use std::time::Instant;

let Some(model_dir) = common::test_model::find() else {
return;
};
println!("model_dir: {}", model_dir.display());

let t0 = Instant::now();
let model = <Transformer as Model>::load(model_dir, ()).unwrap();
let t1 = Instant::now();
println!("load {:?}", t1 - t0);

let mut cache = model.new_cache();

let mut prompt: Vec<utok> = vec![
29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592, 21106,
29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
];
let mut pos = 0;

while prompt != &[model.eos_token()] {
let token_embedded = CausalLM::token_embed(&model, prompt.iter().copied());

let queries = [QueryContext {
cache: Some(&mut cache),
range: pos..pos + prompt.len() as upos,
}];
let hidden_state = CausalLM::forward(&model, queries, token_embedded);

let decoding = [DecodingMeta {
num_query: prompt.len(),
num_decode: 1,
}];
let logits = CausalLM::decode(&model, decoding, hidden_state);

let args = [SampleMeta {
num_decode: 1,
args: causal_lm::SampleArgs::default(),
}];
let tokens = CausalLM::sample(&model, args, logits);

println!("{:?}", tokens);
pos += prompt.len() as upos;
prompt = tokens;
}
causal_lm::test_impl::<Transformer>(
(),
&[
29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592, 21106,
29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
],
);
}

0 comments on commit 39e6d24

Please sign in to comment.