
Commit

chore: remove unused vars
dancixx committed Nov 3, 2024
1 parent 8fb7ce2 commit 3e086dc
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/ai/fou/fou_vae.rs
@@ -1,6 +1,6 @@
 use std::cell::RefCell;
 
-use candle_core::{DType, Device, IndexOp, Result, Tensor};
+use candle_core::{DType, Device, Result, Tensor};
 use candle_nn::{
   layer_norm, linear, linear_no_bias, seq, Activation, Dropout, LayerNorm, LayerNormConfig, Linear,
   Module, Sequential, VarBuilder,
@@ -106,6 +106,7 @@ impl MultiHeadAttention
     let value = linear(n_state, n_state, vb.pp("v_proj"))?;
     let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
     let out = linear(n_state, n_state, vb.pp("out_proj"))?;
+
     Ok(Self {
       query,
       key,
@@ -150,6 +151,7 @@ impl MultiHeadAttention
     };
     let wv = self.qkv_attention(&q, &k, &v, mask)?;
     let out = self.out.forward(&wv)?;
+
     Ok(out)
   }
 
@@ -176,7 +178,10 @@ impl MultiHeadAttention
       q.matmul(&k)?
     };
     if let Some(mask) = mask {
-      let mask = mask.i((0..n_ctx, 0..n_ctx))?;
+      // TODO: check this
+      let mask = mask
+        .unsqueeze(1)?
+        .expand(&[qk.shape().dims()[0], self.n_head, n_ctx, n_ctx])?;
       qk = qk.broadcast_add(&mask)?
     }
     let w = {
@@ -448,9 +453,8 @@ mod tests {
     let latent_dim = 16;
     let dropout_rate = 0.1;
 
-    let device = &Device::Cpu;
     let varmap = VarMap::new();
-    let vs = VarBuilder::from_varmap(&varmap, DType::F64, device);
+    let vs = VarBuilder::from_varmap(&varmap, DType::F64, &Device::Cpu);
 
     let model = TransformerVAE::new(
       input_dim,
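Note on the mask change above: the `(0..n_ctx, 0..n_ctx)` slice is replaced by an explicit head-axis unsqueeze plus expand, so the additive mask lines up with attention scores of shape [batch, n_head, n_ctx, n_ctx] before `broadcast_add`. Below is a minimal standalone sketch of that shape alignment, assuming the incoming mask is an additive tensor of shape [batch, n_ctx, n_ctx]; the causal-mask construction and the batch / n_head / n_ctx values are illustrative, not taken from the repository.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
  let device = Device::Cpu;
  let (batch, n_head, n_ctx) = (2usize, 4usize, 8usize);

  // Illustrative causal additive mask: 0.0 on/below the diagonal, -inf above it.
  let data: Vec<f64> = (0..n_ctx * n_ctx)
    .map(|idx| if idx % n_ctx <= idx / n_ctx { 0.0 } else { f64::NEG_INFINITY })
    .collect();
  // Assumed incoming shape: one mask per batch element, [batch, n_ctx, n_ctx].
  let mask = Tensor::from_vec(data, (n_ctx, n_ctx), &device)?
    .unsqueeze(0)?
    .expand((batch, n_ctx, n_ctx))?;

  // Stand-in for the attention scores produced by q.matmul(&k) in qkv_attention.
  let qk = Tensor::zeros((batch, n_head, n_ctx, n_ctx), DType::F64, &device)?;

  // Same alignment as the committed code: insert a head axis, expand to the
  // score shape, then add the mask to the scores.
  let mask = mask
    .unsqueeze(1)?                                           // [batch, 1, n_ctx, n_ctx]
    .expand(&[qk.shape().dims()[0], n_head, n_ctx, n_ctx])?; // [batch, n_head, n_ctx, n_ctx]
  let qk = qk.broadcast_add(&mask)?;

  assert_eq!(qk.dims(), &[batch, n_head, n_ctx, n_ctx]);
  Ok(())
}

The previous two-dimensional slice could not line up with per-head scores once a batch dimension was present, which is presumably what the unsqueeze-and-expand is meant to address (the `// TODO: check this` suggests the exact broadcast is still to be verified).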
