From c9c8206e0a81ffc12f6c4dddd6242546786074b0 Mon Sep 17 00:00:00 2001
From: Pablo Monteagudo Lago
Date: Tue, 5 Nov 2024 16:52:27 +0000
Subject: [PATCH] Revert back change to signature

---
 src/brevitas/nn/quant_mha.py                      | 3 +--
 src/brevitas_examples/llm/llm_quant/mha_layers.py | 1 -
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/brevitas/nn/quant_mha.py b/src/brevitas/nn/quant_mha.py
index 0effdf68e..6720fe280 100644
--- a/src/brevitas/nn/quant_mha.py
+++ b/src/brevitas/nn/quant_mha.py
@@ -602,8 +602,7 @@ def forward(
             key_padding_mask: Optional[Tensor] = None,
             need_weights: bool = True,
             attn_mask: Optional[Tensor] = None,
-            average_attn_weights: bool = True,
-            position_ids: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
+            average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
         Args:
             query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``

diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py
index fb4abb3e4..67eeb8738 100644
--- a/src/brevitas_examples/llm/llm_quant/mha_layers.py
+++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -176,7 +176,6 @@ def forward(
             attn_mask=attention_mask,
             need_weights=output_attentions,
             average_attn_weights=False,
-            position_ids=position_ids,
         )
         past_key_value = None
         return attn_output, attn_output_weights, past_key_value
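
For context, a minimal sketch (not part of the patch itself) of how the reverted
forward() signature is exercised after this change. The constructor arguments and
default quantizers shown here are assumptions based on brevitas' public
QuantMultiheadAttention API; since position_ids is no longer a keyword of
forward(), any positional-encoding logic must happen in the caller before the
call, as the second hunk's wrapper no longer forwards it.

    # Hedged usage sketch; assumes QuantMultiheadAttention accepts
    # torch.nn.MultiheadAttention-style constructor arguments with
    # default quantizers.
    import torch

    from brevitas.nn import QuantMultiheadAttention

    mha = QuantMultiheadAttention(embed_dim=64, num_heads=4)

    # (L, N, E_q) layout, matching the docstring quoted in the first hunk
    # (batch_first=False): sequence length 10, batch size 2, embed dim 64.
    query = torch.randn(10, 2, 64)
    key = torch.randn(10, 2, 64)
    value = torch.randn(10, 2, 64)

    # position_ids is no longer accepted here; forward() returns
    # Tuple[Tensor, Optional[Tensor]] as in the signature above.
    attn_output, attn_weights = mha(
        query,
        key,
        value,
        need_weights=True,
        average_attn_weights=False)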