FEAT Add hotswapping functionality (#2120)
The idea of hotswapping an adapter is the following: we can already load multiple adapters, e.g. two LoRAs, at the same time. But sometimes, we want to load one LoRA and then replace its weights in-place with the LoRA weights of another adapter. This is now possible with the hotswap_adapter function. In general, this should be faster than deleting one adapter and loading the adapter in its place, which would be the current way to achieve the same final outcome. Another advantage of hotswapping is that it prevents re-compilation in case the PEFT model is already compiled. This can save quite a lot of time.

There are some caveats for hotswapping:
- It only works for the same PEFT method, so no swapping LoRA and LoHa.
- Right now, only LoRA is properly supported.
- The adapters must be compatible (e.g. same LoRA alpha, same target modules).
- To avoid recompilation, ranks must be identical.

See also huggingface/diffusers#9453
1 parent 0d58942 · commit cff2a45 · 7 changed files with 811 additions and 26 deletions.
@@ -0,0 +1,45 @@
<!--⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Hotswapping adapters

The idea of hotswapping an adapter is the following: we can already load multiple adapters, e.g. two LoRAs, at the same time. But sometimes, we want to load one LoRA and then replace its weights in-place with the LoRA weights of another adapter. This is now possible with the `hotswap_adapter` function.

In general, this should be faster than deleting one adapter and loading the adapter in its place, which would be how to achieve the same final outcome without hotswapping. Another advantage of hotswapping is that it prevents re-compilation in case the PEFT model is already compiled using `torch.compile`. This can save quite a lot of time.

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel
from peft.utils.hotswap import hotswap_adapter

model_id = ...
inputs = ...
device = ...
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# load lora 0
model = PeftModel.from_pretrained(model, <path-adapter-0>)
model = torch.compile(model)  # optionally compile the model
with torch.inference_mode():
    output_adapter_0 = model(inputs).logits

# replace the "default" lora adapter with the new one
hotswap_adapter(model, <path-adapter-1>, adapter_name="default", torch_device=device)
with torch.inference_mode():
    output_adapter_1 = model(inputs).logits
```

Hotswapping works with transformers models and diffusers models. However, there are some caveats:

- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
- Right now, only LoRA is properly supported.
- The adapters must be compatible (e.g. same LoRA alpha, same target modules).
- If you use `torch.compile` and want to avoid recompilation, the LoRA rank must be the same (see the sketch below).
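
To illustrate the last point, here is a minimal sketch of how you could check that a hotswap does not trigger recompilation. The model id and adapter paths are placeholders, and the check relies on setting `torch._dynamo.config.error_on_recompile`, which makes PyTorch raise an error instead of silently recompiling:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from peft.utils.hotswap import hotswap_adapter

# placeholder names: swap in your own base model and two LoRA checkpoints with identical rank
model_id = "base-model-id"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")

model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")
model = PeftModel.from_pretrained(model, "path-adapter-0")
model = torch.compile(model)

# make any recompilation raise an error instead of happening silently
torch._dynamo.config.error_on_recompile = True

with torch.inference_mode():
    output_adapter_0 = model(**inputs).logits

# swap in the second adapter; with identical ranks and targets, this should not recompile
hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device="cuda")
with torch.inference_mode():
    output_adapter_1 = model(**inputs).logits
```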

[[autodoc]] utils.hotswap.hotswap_adapter
    - all

[[autodoc]] utils.hotswap.hotswap_adapter_from_state_dict
    - all
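
If you already have the new adapter's state dict in memory, the low-level `hotswap_adapter_from_state_dict` can be called directly. The sketch below mirrors what `hotswap_adapter` does internally; the checkpoint path is a placeholder, and note that `_insert_adapter_name_into_state_dict` is a private helper:

```python
from peft.utils.hotswap import hotswap_adapter_from_state_dict
from peft.utils.save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights

# `model` is a PeftModel with a LoRA adapter named "default" already loaded (see the example above)
# load the raw adapter weights from a placeholder checkpoint path
state_dict = load_peft_weights("path-adapter-1", device="cuda")

# saved checkpoints do not contain the adapter name in their keys, so it has to be re-inserted
state_dict = _insert_adapter_name_into_state_dict(
    state_dict, adapter_name="default", parameter_prefix="lora_"
)

# swap the weights of the loaded "default" adapter in-place
hotswap_adapter_from_state_dict(model, state_dict, adapter_name="default", parameter_prefix="lora_")
```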
@@ -0,0 +1,225 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from operator import attrgetter

import torch

from peft.config import PeftConfig
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

from .constants import PEFT_TYPE_TO_PREFIX_MAPPING
from .other import infer_device
from .peft_types import PeftType
from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights


# so far only LoRA is supported
CONFIG_KEYS_TO_CHECK = {PeftType.LORA: ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]}

def hotswap_adapter_from_state_dict(model, state_dict, adapter_name, parameter_prefix="lora_"):
    """
    Swap out the adapter weights from the model with the weights from state_dict.

    As of now, only LoRA is supported.

    This is a low-level function that assumes that the adapters have been checked for compatibility and that the
    state_dict has been correctly mapped to work with PEFT. For a high level function that performs this work for
    you, use `hotswap_adapter` instead.

    Args:
        model (`nn.Module`):
            The model with the loaded adapter.
        state_dict (`dict[str, torch.Tensor]`):
            The state dict of the new adapter, which needs to be compatible (targeting the same modules etc.).
        adapter_name (`str`):
            The name of the adapter that should be hot-swapped, e.g. `"default"`. The name will remain the same
            after swapping.
        parameter_prefix (`str`, *optional*, defaults to `"lora_"`):
            The prefix used to identify the adapter's keys in the state dict. For LoRA, this would be `"lora_"` (the
            default).

    Raises:
        RuntimeError
            If the old and the new adapter are not compatible, a RuntimeError is raised.
    """
    # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
    # hot-swapping is not possible.
    is_compiled = hasattr(model, "_orig_mod")
    # TODO: there is probably a more precise way to identify the adapter keys
    missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
    unexpected_keys = set()

    # first: dry run, not swapping anything
    for key, new_val in state_dict.items():
        try:
            old_val = attrgetter(key)(model)
        except AttributeError:
            unexpected_keys.add(key)
            continue

        if is_compiled:
            missing_keys.remove("_orig_mod." + key)
        else:
            missing_keys.remove(key)

    if missing_keys or unexpected_keys:
        msg = "Hot swapping the adapter did not succeed."
        if missing_keys:
            msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
        if unexpected_keys:
            msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
        raise RuntimeError(msg)

    # actual swapping
    for key, new_val in state_dict.items():
        # no need to account for potential _orig_mod in key here, as torch handles that
        old_val = attrgetter(key)(model)
        if is_compiled:
            # Compiled models don't work with swap_tensors because there are weakrefs to the tensors. It is unclear
            # whether this workaround could cause trouble, but the tests indicate that it works.
            old_val.data = new_val.data
        else:
            torch.utils.swap_tensors(old_val, new_val)


def _check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
    """
    Check if two configs are compatible for hot-swapping.

    Only LoRA parameters are checked for now.

    To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be incorrect. E.g. if
    they use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with
    the weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these
    values as well, but that's not implemented yet, and we need to be careful not to trigger re-compilation if the
    model is compiled (so no modification of the dict).
    """
    if config0.peft_type != config1.peft_type:
        msg = f"Incompatible PEFT types found: {config0.peft_type.value} and {config1.peft_type.value}"
        raise ValueError(msg)

    if config0.peft_type not in CONFIG_KEYS_TO_CHECK:
        msg = (
            f"Hotswapping only supports {', '.join(CONFIG_KEYS_TO_CHECK.keys())} but "
            f"{config0.peft_type.value} was passed."
        )
        raise ValueError(msg)
    config_keys_to_check = CONFIG_KEYS_TO_CHECK[config0.peft_type]

    # TODO: This is a very rough check only for LoRA at the moment. Also, there might be some options that don't
    # necessarily require an error.
    config0 = config0.to_dict()
    config1 = config1.to_dict()
    sentinel = object()
    for key in config_keys_to_check:
        val0 = config0.get(key, sentinel)
        val1 = config1.get(key, sentinel)
        if val0 != val1:
            raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")


def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None, **kwargs):
    """Substitute old adapter data with new adapter data, keeping the rest the same.

    As of now, only LoRA is supported.

    This function is useful when you want to replace the loaded adapter with a new adapter. The adapter name will
    remain the same, but the weights and other parameters will be swapped out.

    If the adapters are incompatible, e.g. targeting different layers or having different alpha values, an error
    will be raised.

    Example:

    ```py
    >>> import torch
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModel
    >>> from peft.utils.hotswap import hotswap_adapter

    >>> model_id = ...
    >>> inputs = ...
    >>> device = ...
    >>> model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

    >>> # load lora 0
    >>> model = PeftModel.from_pretrained(model, "path-adapter-0")
    >>> model = torch.compile(model)  # optionally compile the model
    >>> with torch.inference_mode():
    ...     output_adapter_0 = model(inputs).logits

    >>> # replace the "default" lora adapter with the new one
    >>> hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device=device)
    >>> with torch.inference_mode():
    ...     output_adapter_1 = model(inputs).logits
    ```

    Args:
        model ([`~PeftModel`]):
            The PEFT model with the loaded adapter.
        model_name_or_path (`str`):
            The name or path of the model to load the new adapter from.
        adapter_name (`str`):
            The name of the adapter to swap, e.g. `"default"`. The name will stay the same after swapping.
        torch_device (`str`, *optional*, defaults to None):
            The device to load the new adapter onto.
        **kwargs (`optional`):
            Additional keyword arguments used for loading the config and weights.
    """
    if torch_device is None:
        torch_device = infer_device()

    ############################
    # LOAD CONFIG AND VALIDATE #
    ############################

    config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[
        PeftConfig._get_peft_type(
            model_name_or_path,
            subfolder=kwargs.get("subfolder", None),
            revision=kwargs.get("revision", None),
            cache_dir=kwargs.get("cache_dir", None),
            use_auth_token=kwargs.get("use_auth_token", None),
            token=kwargs.get("token", None),
        )
    ]
    config = config_cls.from_pretrained(model_name_or_path, **kwargs)
    # config keys that could affect the model output besides what is determined by the state_dict
    _check_hotswap_configs_compatible(model.active_peft_config, config)

    state_dict = load_peft_weights(model_name_or_path, device=torch_device, **kwargs)

    ###########################
    # LOAD & REMAP STATE_DICT #
    ###########################

    parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
    peft_model_state_dict = _insert_adapter_name_into_state_dict(
        state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix
    )

    hotswap_adapter_from_state_dict(
        model=model,
        state_dict=peft_model_state_dict,
        adapter_name=adapter_name,
        parameter_prefix=parameter_prefix,
    )