Skip to content

Commit

Permalink
fix(xtask): 计算正确的最大步数
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Feb 27, 2024
1 parent 1b34420 commit 500f879
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion transformer-nvidia/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,21 @@ impl<'a> Transformer<'a> {
// println!("tokens: {tokens:?}");

let mut x0 = tensor(dt, &[seq_len, d], transfer);
let e0 = transfer.record();
let mut x1 = tensor(dt, &[seq_len, d], transfer);
// `seq_len x hidden_size` -reshape-> `seq_len x (num_kv_head x head_group x head_dim)` -transpose(1,2,0,3)-> `num_kv_head x head_group x seq_len x head_dim` -reshape-> `num_kv_head x (head_group x seq_len) x head_dim`
let mut x2 = tensor(dt, &[nkvh, head_group * seq_len, dh], transfer);
let mut qkv = tensor(dt, &[seq_len, d + dkv + dkv], transfer);
let mut q_att = tensor(dt, &[nh, seq_len, dh], transfer);
let mut att = tensor(dt, &[nkvh, head_group * seq_len, att_len], transfer);
let mut gate_up = tensor(dt, &[seq_len, di + di], transfer);
transfer.synchronize();
let e_alloc = transfer.record();

e0.synchronize();
// gather(&mut x0.access_mut(), &self.model.embed_tokens(), tokens);
// println!("gather:\n{}", x0.access());

e_alloc.synchronize();
for layer in 0..self.host.num_hidden_layers() {}
}
}
Expand Down
1 change: 1 addition & 0 deletions xtask/src/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ fn on_nvidia_gpu(
let transformer = Transformer::new(&host, &cpy);
info!("build model host: {:?}", time.elapsed());

let step = step.min(host.max_position_embeddings());
let time = Instant::now();
let prompt_tokens = tokenizer.encode(&prompt.trim().replace(' ', "▁"));
info!("encode prompt ... {:?}", time.elapsed());
Expand Down

0 comments on commit 500f879

Please sign in to comment.