
Commit c9c8206

Revert back change to signature
pablomlago committed Nov 5, 2024
1 parent b97babd · commit c9c8206
Showing 2 changed files with 1 addition and 3 deletions.
src/brevitas/nn/quant_mha.py (1 addition, 2 deletions)
@@ -602,8 +602,7 @@ def forward(
 key_padding_mask: Optional[Tensor] = None,
 need_weights: bool = True,
 attn_mask: Optional[Tensor] = None,
-average_attn_weights: bool = True,
-position_ids: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
 r"""
 Args:
     query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
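For context, the sketch below exercises the reverted forward signature, which no longer accepts position_ids. It assumes that brevitas.nn exports QuantMultiheadAttention with a constructor mirroring torch.nn.MultiheadAttention (embed_dim, num_heads, batch_first) and with usable default quantizers; neither assumption is shown in this diff.

    # Minimal sketch: calling the reverted forward signature.
    # Assumptions (not shown in this commit): the constructor mirrors
    # torch.nn.MultiheadAttention and the default quantizers work out of the box.
    import torch
    from brevitas.nn import QuantMultiheadAttention

    mha = QuantMultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
    x = torch.randn(2, 10, 64)  # (batch, sequence, embedding) with batch_first=True

    attn_output, attn_weights = mha(
        x, x, x,
        need_weights=True,
        average_attn_weights=True)
    # After this revert, passing position_ids=... here raises a TypeError.
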
src/brevitas_examples/llm/llm_quant/mha_layers.py (0 additions, 1 deletion)
@@ -176,7 +176,6 @@ def forward(
     attn_mask=attention_mask,
     need_weights=output_attentions,
     average_attn_weights=False,
-    position_ids=position_ids,
 )
 past_key_value = None
 return attn_output, attn_output_weights, past_key_value
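
The call-site change above is required by the signature revert in quant_mha.py: once forward no longer declares position_ids, forwarding it as a keyword fails at call time. A toy illustration (hypothetical function, not the Brevitas API):

    # Toy illustration only: a keyword argument that the callee no longer
    # declares raises a TypeError, which is why both files change together.
    def forward(query, key, value, attn_mask=None, average_attn_weights=True):
        return query

    try:
        forward(1, 2, 3, position_ids=None)
    except TypeError as exc:
        print(exc)  # e.g. "forward() got an unexpected keyword argument 'position_ids'"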
