FIX Add missing attributes to MultiheadAttention (#2335)
See initial report here:
#761 (comment).

For MHA to work in all circumstances, for instance in eval mode, we need to expose a couple more attributes of the base layer that we had missed so far. They are added now.
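For illustration, a minimal sketch of the usage pattern that exercises these attributes (assumes torch and peft are installed; the container module MHAModel, its dimensions, and the input shapes are made up for this example, mirroring the test setup with LoraConfig(target_modules=["mha"])):

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class MHAModel(nn.Module):
    # Toy module wrapping nn.MultiheadAttention; name and sizes are illustrative only.
    def __init__(self):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

    def forward(self, x):
        out, _ = self.mha(x, x, x)
        return out

model = MHAModel()
peft_model = get_peft_model(model, LoraConfig(target_modules=["mha"]))

# Eval mode is one of the code paths that reads the attributes exposed in this commit.
peft_model.eval()
with torch.no_grad():
    output = peft_model(torch.rand(4, 5, 8))  # (batch, seq_len, embed_dim) since batch_first=True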
BenjaminBossan authored Jan 20, 2025
1 parent da998c8 commit 8302817
Showing 2 changed files with 67 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/peft/tuners/lora/layer.py
@@ -1405,6 +1405,33 @@ def batch_first(self) -> bool:
    def head_dim(self) -> int:
        return self.get_base_layer().head_dim

    @property
    def in_proj_weight(self) -> nn.Parameter:
        return self.get_base_layer().in_proj_weight

    @property
    def in_proj_bias(self) -> nn.Parameter:
        return self.get_base_layer().in_proj_bias

    @property
    def out_proj(self) -> nn.Module:
        return self.get_base_layer().out_proj.get_base_layer()

    @property
    def bias_k(self) -> Optional[nn.Parameter]:
        return self.get_base_layer().bias_k

    @property
    def bias_v(self) -> Optional[nn.Parameter]:
        return self.get_base_layer().bias_v

    def merge_masks(self, *args, **kwargs) -> tuple[Optional[torch.Tensor], Optional[int]]:
        return self.get_base_layer().merge_masks(*args, **kwargs)

    @property
    def add_zero_attn(self) -> bool:
        return self.get_base_layer().add_zero_attn

    def update_layer(self, *args, **kwargs) -> None:
        super().update_layer(*args, **kwargs)
        # Note: LoRA is applied to both in_proj and out_proj. There is currently no way to only specify one of them.
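The properties above simply forward attribute reads to the wrapped base layer, so code written against a plain nn.MultiheadAttention keeps working. A stripped-down sketch of the same pattern (the class name and the selection of attributes are illustrative, not the PEFT API):

from typing import Optional

import torch.nn as nn

class ReadThroughWrapper(nn.Module):
    # Illustrative wrapper, not PEFT code: attribute reads are delegated to the wrapped module.
    def __init__(self, base: nn.MultiheadAttention):
        super().__init__()
        self.base_layer = base

    def get_base_layer(self) -> nn.MultiheadAttention:
        return self.base_layer

    @property
    def add_zero_attn(self) -> bool:
        return self.get_base_layer().add_zero_attn

    @property
    def bias_k(self) -> Optional[nn.Parameter]:
        return self.get_base_layer().bias_k

    def merge_masks(self, *args, **kwargs):
        # Methods that downstream code expects can be forwarded in the same way.
        return self.get_base_layer().merge_masks(*args, **kwargs)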
40 changes: 40 additions & 0 deletions tests/test_initialization.py
@@ -1145,6 +1145,7 @@ def test_mha_with_dora_raises(self, mha_cls):
            get_peft_model(model, config)

    def test_mha_exposes_attributes(self, mha_cls):
        # MHA requires a bunch of attributes to be exposed, try to check them exhaustively here
        model = mha_cls()
        embed_dim = model.mha.embed_dim
        kdim = model.mha.kdim
@@ -1154,6 +1155,12 @@ def test_mha_exposes_attributes(self, mha_cls):
        dropout = model.mha.dropout
        batch_first = model.mha.batch_first
        head_dim = model.mha.head_dim
        in_proj_weight = model.mha.in_proj_weight
        in_proj_bias = model.mha.in_proj_bias
        out_proj = model.mha.out_proj
        bias_k = model.mha.bias_k
        bias_v = model.mha.bias_v
        add_zero_attn = model.mha.add_zero_attn

        config = LoraConfig(target_modules=["mha"])
        peft_model = get_peft_model(model, config)
@@ -1165,6 +1172,39 @@ def test_mha_exposes_attributes(self, mha_cls):
        assert peft_model.base_model.mha.dropout == dropout
        assert peft_model.base_model.mha.batch_first == batch_first
        assert peft_model.base_model.mha.head_dim == head_dim
        if in_proj_weight is not None:
            assert torch.allclose(peft_model.base_model.mha.in_proj_weight, in_proj_weight)
        else:
            assert peft_model.base_model.mha.in_proj_weight is None
        if in_proj_bias is not None:
            assert torch.allclose(peft_model.base_model.mha.in_proj_bias, in_proj_bias)
        else:
            assert peft_model.base_model.mha.in_proj_bias is None
        assert peft_model.base_model.mha.out_proj is out_proj
        if bias_k is not None:
            assert torch.allclose(peft_model.base_model.mha.bias_k, bias_k)
        else:
            assert peft_model.base_model.mha.bias_k is None
        if bias_v is not None:
            assert torch.allclose(peft_model.base_model.mha.bias_v, bias_v)
        else:
            assert peft_model.base_model.mha.bias_v is None
        assert peft_model.base_model.mha.add_zero_attn == add_zero_attn

    def test_mha_merge_masks_method(self, mha_cls):
        # MHA requires a merge_masks method to be exposed, check that it works
        model = mha_cls()
        config = LoraConfig(target_modules=["mha"])
        peft_model = get_peft_model(model, config)

        attn_mask = torch.randint(0, 2, (10, 10))
        key_padding_mask = torch.randint(0, 2, (10, 10))
        query = torch.rand(10, 10, 10)
        merged_mask0, mask_type0 = model.mha.merge_masks(attn_mask, key_padding_mask, query)
        merged_mask1, mask_type1 = peft_model.base_model.mha.merge_masks(attn_mask, key_padding_mask, query)

        assert torch.allclose(merged_mask0, merged_mask1)
        assert mask_type0 == mask_type1

    def test_lora_with_bias_extra_params(self):
        # lora with lora_bias=True
