Skip to content

Commit

Permalink
test(transformer-cpu): 添加连续推理测试
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Apr 24, 2024
1 parent 4780ebe commit ec4359b
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions transformer-cpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,29 +616,37 @@ fn test_infer() {
let t1 = Instant::now();
println!("load {:?}", t1 - t0);

const PROMPT: [utok; 5] = [9038, 2501, 263, 931, 29892];

let mut cache = model.new_cache();

let token_embedded = CausalLM::token_embed(&model, PROMPT);

let queries = [QueryContext {
cache: Some(&mut cache),
range: 0..PROMPT.len() as _,
}];
let hidden_state = CausalLM::forward(&model, queries, token_embedded);

let decoding = [DecodingMeta {
num_query: PROMPT.len(),
num_decode: 1,
}];
let logits = CausalLM::decode(&model, decoding, hidden_state);

let args = [SampleMeta {
num_decode: 1,
args: causal_lm::SampleArgs::default(),
}];
let tokens = CausalLM::sample(&model, args, logits);

println!("{:?}", tokens);
let mut prompt: Vec<utok> = vec![
29966, 29989, 1792, 29989, 29958, 13, 29903, 388, 376, 18567, 29908, 304, 592, 21106,
29879, 5299, 29989, 465, 22137, 29989, 29958, 13,
];
let mut pos = 0;

while prompt != &[model.eos_token()] {
let token_embedded = CausalLM::token_embed(&model, prompt.iter().copied());

let queries = [QueryContext {
cache: Some(&mut cache),
range: pos..pos + prompt.len() as upos,
}];
let hidden_state = CausalLM::forward(&model, queries, token_embedded);

let decoding = [DecodingMeta {
num_query: prompt.len(),
num_decode: 1,
}];
let logits = CausalLM::decode(&model, decoding, hidden_state);

let args = [SampleMeta {
num_decode: 1,
args: causal_lm::SampleArgs::default(),
}];
let tokens = CausalLM::sample(&model, args, logits);

println!("{:?}", tokens);
pos += prompt.len() as upos;
prompt = tokens;
}
}

0 comments on commit ec4359b

Please sign in to comment.