doc fix for mha weights (#1205)
ashdtu authored Jan 31, 2024
1 parent 0ec0fba commit c620176
Showing 1 changed file with 2 additions and 2 deletions.
burn-core/src/nn/attention/mha.rs
@@ -13,7 +13,7 @@ use libm::sqrtf;
 /// Configuration to create a [Multi Head Attention](MultiHeadAttention) layer.
 #[derive(Config)]
 pub struct MultiHeadAttentionConfig {
-    /// The size of the each linear layer.
+    /// The size of each linear layer.
     d_model: usize,
     /// The number of heads.
     n_heads: usize,
@@ -160,7 +160,7 @@ impl<B: Backend> MhaInput<B> {
 /// [Multihead attention](MultiHeadAttention) outputs.
 #[derive(Debug, Clone)]
 pub struct MhaOutput<B: Backend> {
-    /// The attention weights [batch_size, seq_length_1, seq_length_2].
+    /// The attention weights [batch_size, n_heads, seq_length_1, seq_length_2].
     pub weights: Tensor<B, 4>,
     /// The context tensor [batch_size, seq_length_1, d_model].
     pub context: Tensor<B, 3>,
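The corrected doc comment matches the declared rank of the field: weights is a Tensor<B, 4>, so its shape needs four dimensions, one of them indexing the attention heads. A minimal sketch of how the output shapes line up for a self-attention call, where seq_length_1 == seq_length_2 (the dimension values are hypothetical, and the device-taking init/Tensor::random signatures reflect the burn API around this release and may differ in other versions):

use burn::nn::attention::{MhaInput, MultiHeadAttentionConfig};
use burn::tensor::{backend::Backend, Distribution, Tensor};

fn mha_output_shapes<B: Backend>(device: &B::Device) {
    // Hypothetical sizes chosen only to make the shapes visible.
    let (batch_size, n_heads, seq_length, d_model) = (2, 4, 6, 32);
    let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init(device);

    let x = Tensor::<B, 3>::random(
        [batch_size, seq_length, d_model],
        Distribution::Default,
        device,
    );
    let output = mha.forward(MhaInput::self_attn(x));

    // Rank-4 attention weights: one seq_length_1 x seq_length_2 map per head.
    assert_eq!(
        output.weights.dims(),
        [batch_size, n_heads, seq_length, seq_length]
    );
    // The context tensor folds the heads back into d_model.
    assert_eq!(output.context.dims(), [batch_size, seq_length, d_model]);
}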
