diff --git a/models/clip/common/src/compute.rs b/models/clip/common/src/compute.rs index 1c96746..ea97f68 100644 --- a/models/clip/common/src/compute.rs +++ b/models/clip/common/src/compute.rs @@ -112,7 +112,7 @@ impl ClipWorker { let qkv = Tensor::new(dt, &[np * (nh + nkvh + nkvh), dh]).take(); let q = Tensor::new(dt, &[np, nh, dh]).take(); - let att = Tensor::new(dt, &[nkvh, np, np]).take(); + let att = Tensor::new(dt, &[nh, np, np]).take(); let up = Tensor::new(dt, &[np, di]).take(); embd + (qkv + q + att).max(up) diff --git a/models/gpt2/common/src/compute.rs b/models/gpt2/common/src/compute.rs index cfacc0b..76b40c0 100644 --- a/models/gpt2/common/src/compute.rs +++ b/models/gpt2/common/src/compute.rs @@ -107,7 +107,7 @@ impl Gpt2Worker { let qkv = Tensor::new(dt, &[nt * (nh + nkvh + nkvh), dh]).take(); let q = Tensor::new(dt, &[max_seq_len, nh, dh]).take(); - let att = Tensor::new(dt, &[nkvh, max_seq_len, max_att_len]).take(); + let att = Tensor::new(dt, &[nh, max_seq_len, max_att_len]).take(); let up = Tensor::new(dt, &[nt, di]).take(); embd + (qkv + q + att).max(up) diff --git a/models/llama/common/src/compute.rs b/models/llama/common/src/compute.rs index 8087524..706e4d8 100644 --- a/models/llama/common/src/compute.rs +++ b/models/llama/common/src/compute.rs @@ -141,7 +141,7 @@ impl LlamaWorker { let qkv = Tensor::new(dt, &[nt * (nh + nkvh + nkvh), dh]).take(); let q = Tensor::new(dt, &[max_seq_len, nh, dh]).take(); - let att = Tensor::new(dt, &[nkvh, max_seq_len, max_att_len]).take(); + let att = Tensor::new(dt, &[nh, max_seq_len, max_att_len]).take(); if self.meta.is_moe() { let routes = Tensor::new(dt, &[nt, nexp]).take();