ENH: Improved attribute access for modules_to_save (#2117)
Resolves #2099

So far, if a module was wrapped due to modules_to_save, we handled
access to the weight and bias attributes (albeit incorrectly when
adapters were disabled!). However, a wrapped module can expose other
attributes as well, and accessing any of those raised an error.

Instead of special properties, we now implement a generic __getattr__
method that can deal with any attribute. The implementation is a bit
complex because it has to account for the way that torch.nn.Module
handles __getattr__.
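
To see why the careful handling is needed, here is a minimal, self-contained sketch (illustration only, not PEFT code; NaiveWrapper and SafeWrapper are made-up names): torch.nn.Module keeps submodules in self._modules, and its own __getattr__ only searches _parameters, _buffers, and _modules. A wrapper that naively delegates missing attributes through self.original_module therefore re-enters its own __getattr__ and recurses without end.

import torch.nn as nn

class NaiveWrapper(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.original_module = module  # stored in self._modules by nn.Module.__setattr__

    def __getattr__(self, name: str):
        # BUG: "original_module" is not in self.__dict__, so evaluating
        # self.original_module calls __getattr__ again -> RecursionError.
        return getattr(self.original_module, name)

class SafeWrapper(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.original_module = module

    def __getattr__(self, name: str):
        try:
            # resolves parameters, buffers and submodules the regular PyTorch way
            return super().__getattr__(name)
        except AttributeError:
            pass
        if "_modules" not in self.__dict__:  # e.g. before nn.Module.__init__ has run
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        # delegate everything else to the wrapped module, bypassing our own __getattr__
        return getattr(self.__dict__["_modules"]["original_module"], name)

print(SafeWrapper(nn.Linear(2, 3)).in_features)  # 2, delegated to the wrapped Linear
# NaiveWrapper(nn.Linear(2, 3)).in_features      # would raise RecursionError

The implementation in this commit follows the same pattern: try the regular nn.Module lookup first, then fall back to the module stored in _modules.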
BenjaminBossan authored Oct 2, 2024
1 parent 2a80735 commit ae297f0
Showing 2 changed files with 119 additions and 11 deletions.
34 changes: 23 additions & 11 deletions src/peft/utils/other.py
@@ -227,17 +227,29 @@ def active_adapter(self) -> str:
# use a property to ensure that active_adapter is not set directly, instead use the set_adapter method
return self._active_adapter

@property
def weight(self):
if self.active_adapter not in self.modules_to_save:
return self.original_module.weight
return self.modules_to_save[self.active_adapter].weight

@property
def bias(self):
if self.active_adapter not in self.modules_to_save:
return self.original_module.bias
return self.modules_to_save[self.active_adapter].bias
def __getattr__(self, name: str):
# Note: This whole method may seem overly complex at first but PyTorch messes with __getattr__ in a way that
# requires very careful handling to avoid infinite recursion.
try:
return super().__getattr__(name)
except AttributeError:
pass

if "_modules" not in self.__dict__:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

# Could not find the attribute the PyTorch way. So let's check if it's an attribute on the
# original_module/modules_to_save.
modules = self.__dict__["_modules"]
if self.disable_adapters:
module = modules["original_module"]
elif self.active_adapter in modules["modules_to_save"]:
module = modules["modules_to_save"][self.active_adapter]
else:
# For some reason, there is no module corresponding to the active adapter; this should normally not be
# reached and exists as a failsafe (otherwise, a KeyError would be raised)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
return getattr(module, name)

def update(self, adapter_name):
context_manager = nullcontext()
96 changes: 96 additions & 0 deletions tests/test_other.py
@@ -18,6 +18,7 @@
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

from peft import LoraConfig, get_peft_model
from peft.utils.other import ModulesToSaveWrapper


class ModelWithModuleDict(nn.Module):
@@ -103,3 +104,98 @@ def test_get_peft_model_revision_warning(tmp_path):
overwrite_warning = f"peft config has already set base model revision to {base_revision}, overwriting with revision {overwrite_revision}"
with pytest.warns(UserWarning, match=overwrite_warning):
_ = get_peft_model(base_model, lora_config, revision=overwrite_revision)


class TestModulesToSaveAttributeAccess:
"""Test attribute access on the ModulesToSaveWrapper class.
When we have modules_to_save, the original module is wrapped. As long as only forward was called on this wrapped
module, we were good. However, if, for instance, model parameters were directly accessed by another module, this
would typically fail, as the wrapper does not have this attribute. We had special properties for weight and bias,
but this is not enough. Therefore, attribute access is now transiently delegated to the active adapter (or original
module, if the adapter is disabled).
For one example, see #2099.
"""

@pytest.fixture
def mlp(self):
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(1, 2)
self.lin1 = nn.Linear(3, 4)

return MLP()

def test_transient_attribute_access_default_adapter(self, mlp):
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(mlp, config)
assert model.lin1.weight is model.lin1.modules_to_save["default"].weight
assert model.lin1.bias is model.lin1.modules_to_save["default"].bias

def test_transient_attribute_access_non_default_adapter(self, mlp):
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(mlp, config)
model.add_adapter("other", config)

# at this point, default is still active
assert model.lin1.weight is model.lin1.modules_to_save["default"].weight
assert model.lin1.bias is model.lin1.modules_to_save["default"].bias
assert model.lin1.weight is not model.lin1.modules_to_save["other"].weight
assert model.lin1.bias is not model.lin1.modules_to_save["other"].bias

model.set_adapter("other")
assert model.lin1.weight is not model.lin1.modules_to_save["default"].weight
assert model.lin1.bias is not model.lin1.modules_to_save["default"].bias
assert model.lin1.weight is model.lin1.modules_to_save["other"].weight
assert model.lin1.bias is model.lin1.modules_to_save["other"].bias

def test_transient_attribute_access_disabled_adapter(self, mlp):
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(mlp, config)

# at this point, default is still active
assert model.lin1.weight is model.lin1.modules_to_save["default"].weight
assert model.lin1.bias is model.lin1.modules_to_save["default"].bias
assert model.lin1.weight is not model.lin1.original_module.weight
assert model.lin1.bias is not model.lin1.original_module.bias

with model.disable_adapter():
assert model.lin1.weight is not model.lin1.modules_to_save["default"].weight
assert model.lin1.bias is not model.lin1.modules_to_save["default"].bias
assert model.lin1.weight is model.lin1.original_module.weight
assert model.lin1.bias is model.lin1.original_module.bias

def test_transient_attribute_access_uninitialized_adapter(self, mlp):
# ensure that there is no weird infinite recursion when accessing a non-existing attribute on the class itself
with pytest.raises(AttributeError, match="has no attribute 'original_module'"):
ModulesToSaveWrapper.original_module

def test_transient_attribute_access_attr_does_not_exist_on_modules_to_save(self, mlp):
# ensure that there is no weird infinite recursion when accessing a non-existing attribute on the
# ModulesToSaveWrapper instance
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(mlp, config)

with pytest.raises(AttributeError, match="has no attribute 'foo'"):
model.lin1.foo

def test_transient_attribute_access_attr_does_not_exist_on_original_module(self, mlp):
# ensure that there is no weird infinite recursion when accessing a non-existing attribute on the
# original module of the ModulesToSaveWrapper instance
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(mlp, config)

with pytest.raises(AttributeError, match="has no attribute 'foo'"):
with model.disable_adapter():
model.lin1.foo

def test_transient_attribute_access_non_existing_adapter(self, mlp):
# This should normally never happen, as the active adapter should always exist, but it's a failsafe
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(mlp, config)
model.base_model.model.lin1._active_adapter = "does-not-exist"
with pytest.raises(AttributeError, match="has no attribute 'weight'"):
model.lin1.weight
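
For illustration, a short usage sketch (not part of this diff; it reuses the MLP shape from the tests above): with the generic __getattr__ in place, attributes other than weight and bias now resolve through the wrapper as well.

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(1, 2)
        self.lin1 = nn.Linear(3, 4)

config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(MLP(), config)

# lin1 is wrapped in a ModulesToSaveWrapper; attribute access is delegated to the
# copy belonging to the active adapter ("default" here).
print(model.lin1.in_features)   # 3 -- raised an AttributeError before this change
print(model.lin1.out_features)  # 4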
