fix: correct the workspace estimation for attention
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Jan 3, 2025
1 parent cf6d8fe commit 87e9391
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion models/clip/common/src/compute.rs
@@ -112,7 +112,7 @@ impl<Ops: Operators, W> ClipWorker<Ops, W> {

let qkv = Tensor::new(dt, &[np * (nh + nkvh + nkvh), dh]).take();
let q = Tensor::new(dt, &[np, nh, dh]).take();
- let att = Tensor::new(dt, &[nkvh, np, np]).take();
+ let att = Tensor::new(dt, &[nh, np, np]).take();

let up = Tensor::new(dt, &[np, di]).take();
embd + (qkv + q + att).max(up)
2 changes: 1 addition & 1 deletion models/gpt2/common/src/compute.rs
@@ -107,7 +107,7 @@ impl<Ops: Operators, W> Gpt2Worker<Ops, W> {

let qkv = Tensor::new(dt, &[nt * (nh + nkvh + nkvh), dh]).take();
let q = Tensor::new(dt, &[max_seq_len, nh, dh]).take();
- let att = Tensor::new(dt, &[nkvh, max_seq_len, max_att_len]).take();
+ let att = Tensor::new(dt, &[nh, max_seq_len, max_att_len]).take();

let up = Tensor::new(dt, &[nt, di]).take();
embd + (qkv + q + att).max(up)
2 changes: 1 addition & 1 deletion models/llama/common/src/compute.rs
@@ -141,7 +141,7 @@ impl<Ops: Operators, W> LlamaWorker<Ops, W> {

let qkv = Tensor::new(dt, &[nt * (nh + nkvh + nkvh), dh]).take();
let q = Tensor::new(dt, &[max_seq_len, nh, dh]).take();
- let att = Tensor::new(dt, &[nkvh, max_seq_len, max_att_len]).take();
+ let att = Tensor::new(dt, &[nh, max_seq_len, max_att_len]).take();

if self.meta.is_moe() {
let routes = Tensor::new(dt, &[nt, nexp]).take();
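The common thread in all three diffs: the attention-score buffer holds one score matrix per *query* head, so its leading dimension must be `nh`, not the (typically smaller) number of KV heads `nkvh` shared under grouped-query attention; sizing it by `nkvh` underestimates the workspace whenever `nh > nkvh`. A minimal standalone sketch of the size difference, using hypothetical head counts and sequence lengths rather than any real model config:

```rust
// Sketch only, not the repository's API: compares the old (nkvh-based)
// and fixed (nh-based) estimates for the attention-score buffer.
fn main() {
    // Hypothetical example values, not taken from any model in the repo.
    let (nh, nkvh) = (32usize, 8usize);           // query heads vs. KV heads (GQA)
    let (max_seq_len, max_att_len) = (512usize, 2048usize);
    let elem = 2usize;                            // bytes per element, e.g. f16

    let old_estimate = nkvh * max_seq_len * max_att_len * elem; // undersized
    let fixed_estimate = nh * max_seq_len * max_att_len * elem; // one matrix per query head
    println!("old = {old_estimate} B, fixed = {fixed_estimate} B");
    // fixed_estimate is nh / nkvh (here 4x) larger than the old estimate.
}
```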
