diff --git a/src/ai/fou/fou_vae.rs b/src/ai/fou/fou_vae.rs
index 82ec972..59e26e6 100644
--- a/src/ai/fou/fou_vae.rs
+++ b/src/ai/fou/fou_vae.rs
@@ -1,6 +1,6 @@
 use std::cell::RefCell;
 
-use candle_core::{DType, Device, IndexOp, Result, Tensor};
+use candle_core::{DType, Device, Result, Tensor};
 use candle_nn::{
     layer_norm, linear, linear_no_bias, seq, Activation, Dropout, LayerNorm, LayerNormConfig,
     Linear, Module, Sequential, VarBuilder,
@@ -106,6 +106,7 @@ impl MultiHeadAttention {
         let value = linear(n_state, n_state, vb.pp("v_proj"))?;
         let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
         let out = linear(n_state, n_state, vb.pp("out_proj"))?;
+
         Ok(Self {
             query,
             key,
@@ -150,6 +151,7 @@ impl MultiHeadAttention {
         };
         let wv = self.qkv_attention(&q, &k, &v, mask)?;
         let out = self.out.forward(&wv)?;
+
         Ok(out)
     }
 
@@ -176,7 +178,10 @@ impl MultiHeadAttention {
             q.matmul(&k)?
         };
         if let Some(mask) = mask {
-            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
+            // TODO: check this
+            let mask = mask
+                .unsqueeze(1)?
+                .expand(&[qk.shape().dims()[0], self.n_head, n_ctx, n_ctx])?;
             qk = qk.broadcast_add(&mask)?
         }
         let w = {
@@ -448,9 +453,8 @@ mod tests {
         let latent_dim = 16;
         let dropout_rate = 0.1;
 
-        let device = &Device::Cpu;
         let varmap = VarMap::new();
-        let vs = VarBuilder::from_varmap(&varmap, DType::F64, device);
+        let vs = VarBuilder::from_varmap(&varmap, DType::F64, &Device::Cpu);
 
         let model = TransformerVAE::new(
             input_dim,
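
Note on the mask change in qkv_attention: the old code sliced a 2-D mask with mask.i((0..n_ctx, 0..n_ctx)), while the new code assumes a per-batch additive mask of shape (batch, n_ctx, n_ctx) and lifts it to (batch, n_head, n_ctx, n_ctx) so it lines up with the attention scores qk; the "// TODO: check this" flags exactly that reshape. Below is a minimal, self-contained sketch (not part of the patch) of the same shape flow using candle's unsqueeze/expand/broadcast_add; the apply_mask helper and the toy dimensions are made up for illustration.

    use candle_core::{DType, Device, Result, Tensor};

    // Hypothetical helper mirroring the patched mask handling: lift a
    // (batch, n_ctx, n_ctx) additive mask to (batch, n_head, n_ctx, n_ctx)
    // and add it to the attention scores.
    fn apply_mask(qk: &Tensor, mask: &Tensor, n_head: usize) -> Result<Tensor> {
        let (batch, _, n_ctx, _) = qk.dims4()?;
        let mask = mask
            .unsqueeze(1)?                           // (batch, 1, n_ctx, n_ctx)
            .expand((batch, n_head, n_ctx, n_ctx))?; // (batch, n_head, n_ctx, n_ctx)
        qk.broadcast_add(&mask)
    }

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let (batch, n_head, n_ctx) = (2, 4, 3);
        // Causal mask: 0.0 where attention is allowed, -inf above the diagonal.
        let causal: Vec<f32> = (0..n_ctx)
            .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0.0 }))
            .collect();
        let mask = Tensor::from_vec(causal, (n_ctx, n_ctx), &device)?
            .unsqueeze(0)?
            .expand((batch, n_ctx, n_ctx))?;
        let qk = Tensor::zeros((batch, n_head, n_ctx, n_ctx), DType::F32, &device)?;
        let masked = apply_mask(&qk, &mask, n_head)?;
        println!("{masked}");
        Ok(())
    }

If the mask really is (batch, n_ctx, n_ctx), the expand is arguably redundant: broadcast_add can broadcast the (batch, 1, n_ctx, n_ctx) mask across the head dimension on its own.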