diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index d4d346919c..fe66a3d6c4 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -125,5 +125,7 @@ title: Model merge - local: package_reference/helpers title: Helpers + - local: package_reference/hotswap + title: Hotswapping adapters title: Utilities title: API reference diff --git a/docs/source/package_reference/hotswap.md b/docs/source/package_reference/hotswap.md new file mode 100644 index 0000000000..d2f36bc2b6 --- /dev/null +++ b/docs/source/package_reference/hotswap.md @@ -0,0 +1,45 @@ + + +# Hotswapping adapters + +The idea of hotswapping an adapter is the following: We can already load multiple adapters, e.g. two LoRAs, at the same time. But sometimes we want to load one LoRA and then replace its weights in-place with the LoRA weights of another adapter. This is now possible with the `hotswap_adapter` function. + +In general, this should be faster than deleting one adapter and loading a new adapter in its place, which is how you would otherwise achieve the same final outcome without hotswapping. Another advantage of hotswapping is that it prevents re-compilation if the PEFT model is already compiled using `torch.compile`. This can save quite a lot of time. + +```python +import torch +from transformers import AutoModelForCausalLM +from peft import PeftModel +from peft.utils.hotswap import hotswap_adapter + +model_id = ... +inputs = ... +device = ... +model = AutoModelForCausalLM.from_pretrained(model_id).to(device) + +# load lora 0 +model = PeftModel.from_pretrained(model, "path-adapter-0") +model = torch.compile(model) # optionally compile the model +with torch.inference_mode(): + output_adapter_0 = model(inputs).logits + +# replace the "default" lora adapter with the new one +hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device=device) +with torch.inference_mode(): + output_adapter_1 = model(inputs).logits +``` + +Hotswapping works with transformers models and diffusers models. However, there are some caveats: + +- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example. +- Right now, only LoRA is properly supported. +- The adapters must be compatible (e.g. same LoRA alpha, same target modules). +- If you use `torch.compile` and want to avoid recompilation, the LoRA rank must be the same. + +[[autodoc]] utils.hotswap.hotswap_adapter + - all + +[[autodoc]] utils.hotswap.hotswap_adapter_from_state_dict + - all diff --git a/src/peft/utils/hotswap.py b/src/peft/utils/hotswap.py new file mode 100644 index 0000000000..3ff7caacce --- /dev/null +++ b/src/peft/utils/hotswap.py @@ -0,0 +1,225 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
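+# Overview of the hot-swapping flow implemented below (a sketch; the example key names are illustrative):
+#   1. load the new adapter's config and check it is compatible with the currently loaded adapter
+#      (`_check_hotswap_configs_compatible`)
+#   2. load the new adapter's state_dict and insert the adapter name into its keys, e.g.
+#      "base_model.model.lin0.lora_A.weight" -> "base_model.model.lin0.lora_A.default.weight"
+#   3. swap the loaded adapter's tensors in place (`hotswap_adapter_from_state_dict`), using
+#      `torch.utils.swap_tensors`, or plain `.data` assignment for compiled models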
+from __future__ import annotations + +from operator import attrgetter + +import torch + +from peft.config import PeftConfig +from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING + +from .constants import PEFT_TYPE_TO_PREFIX_MAPPING +from .other import infer_device +from .peft_types import PeftType +from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights + + +# so far only LoRA is supported +CONFIG_KEYS_TO_CHECK = {PeftType.LORA: ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]} + + +def hotswap_adapter_from_state_dict(model, state_dict, adapter_name, parameter_prefix="lora_"): + """ + Swap out the adapter weights from the model with the weights from state_dict. + + As of now, only LoRA is supported. + + This is a low-level function that assumes that the adapters have been checked for compatibility and that the + state_dict has been correctly mapped to work with PEFT. For a high-level function that performs this work for you, + use `hotswap_adapter` instead. + + Args: + model (`nn.Module`): + The model with the loaded adapter. + state_dict (`dict[str, torch.Tensor]`): + The state dict of the new adapter, which needs to be compatible (targeting the same modules etc.). + adapter_name (`str`): + The name of the adapter that should be hot-swapped, e.g. `"default"`. The name will remain the same after + swapping. + parameter_prefix (`str`, *optional*, defaults to `"lora_"`): + The prefix used to identify the adapter's keys in the state dict. For LoRA, this would be `"lora_"` (the + default). + + Raises: + RuntimeError + If the old and the new adapter are not compatible, a RuntimeError is raised. + + """ + # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise + # hot-swapping is not possible + + is_compiled = hasattr(model, "_orig_mod") + # TODO: there is probably a more precise way to identify the adapter keys + missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)} + unexpected_keys = set() + + # first: dry run, not swapping anything + for key, new_val in state_dict.items(): + try: + old_val = attrgetter(key)(model) + except AttributeError: + unexpected_keys.add(key) + continue + + if is_compiled: + missing_keys.remove("_orig_mod." + key) + else: + missing_keys.remove(key) + + if missing_keys or unexpected_keys: + msg = "Hot swapping the adapter did not succeed." + if missing_keys: + msg += f" Missing keys: {', '.join(sorted(missing_keys))}." + if unexpected_keys: + msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}." + raise RuntimeError(msg) + + # actual swapping + for key, new_val in state_dict.items(): + # no need to account for potential _orig_mod in key here, as torch handles that + old_val = attrgetter(key)(model) + if is_compiled: + # Compiled models don't work with swap_tensors because there are weakrefs to the tensor. It is unclear + # whether this workaround could cause trouble, but the tests indicate that it works. + old_val.data = new_val.data + else: + torch.utils.swap_tensors(old_val, new_val) + + +def _check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None: + """ + Check if two configs are compatible for hot-swapping. + + Only LoRA parameters are checked for now. + + To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be wrong. E.g.
if they use + different alpha values, after hot-swapping, the alphas from the first adapter would still be used with the weights + from the second adapter, which would result in incorrect behavior. There is probably a way to swap these values as + well, but that's not implemented yet, and we need to be careful not to trigger re-compilation if the model is + compiled (so no modification of the dict). + + """ + + if config0.peft_type != config1.peft_type: + msg = f"Incompatible PEFT types found: {config0.peft_type.value} and {config1.peft_type.value}" + raise ValueError(msg) + + if config0.peft_type not in CONFIG_KEYS_TO_CHECK: + msg = ( + f"Hotswapping only supports {', '.join(CONFIG_KEYS_TO_CHECK.keys())} but " + f"{config0.peft_type.value} was passed." + ) + raise ValueError(msg) + config_keys_to_check = CONFIG_KEYS_TO_CHECK[config0.peft_type] + + # TODO: This is a very rough check only for LoRA at the moment. Also, there might be some options that don't + # necessarily require an error. + config0 = config0.to_dict() + config1 = config1.to_dict() + sentinel = object() + for key in config_keys_to_check: + val0 = config0.get(key, sentinel) + val1 = config1.get(key, sentinel) + if val0 != val1: + raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}") + + +def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None, **kwargs): + """Substitute old adapter data with new adapter data, keeping the rest the same. + + As of now, only LoRA is supported. + + This function is useful when you want to replace the loaded adapter with a new adapter. The adapter name will + remain the same, but the weights and other parameters will be swapped out. + + If the adapters are incompatible, e.g. targeting different layers or having different alpha values, an error will + be raised. + + Example: + + ```py + >>> import torch + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + >>> from peft.utils.hotswap import hotswap_adapter + + >>> model_id = ... + >>> inputs = ... + >>> device = ... + >>> model = AutoModelForCausalLM.from_pretrained(model_id).to(device) + + >>> # load lora 0 + >>> model = PeftModel.from_pretrained(model, "path-adapter-0") + >>> model = torch.compile(model) # optionally compile the model + >>> with torch.inference_mode(): + ... output_adapter_0 = model(inputs).logits + + >>> # replace the "default" lora adapter with the new one + >>> hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device=device) + >>> with torch.inference_mode(): + ... output_adapter_1 = model(inputs).logits + ``` + + Args: + model ([`~PeftModel`]): + The PEFT model with the loaded adapter. + model_name_or_path (`str`): + The name or path of the model to load the new adapter from. + adapter_name (`str`): + The name of the adapter to swap, e.g. `"default"`. The name will stay the same after swapping. + torch_device (`str`, *optional*, defaults to None): + The device to load the new adapter onto. + **kwargs (`optional`): + Additional keyword arguments used for loading the config and weights.
+ + """ + if torch_device is None: + torch_device = infer_device() + + ############################ + # LOAD CONFIG AND VALIDATE # + ############################ + + config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[ + PeftConfig._get_peft_type( + model_name_or_path, + subfolder=kwargs.get("subfolder", None), + revision=kwargs.get("revision", None), + cache_dir=kwargs.get("cache_dir", None), + use_auth_token=kwargs.get("use_auth_token", None), + token=kwargs.get("token", None), + ) + ] + config = config_cls.from_pretrained(model_name_or_path, **kwargs) + # config keys that could affect the model output besides what is determined by the state_dict + _check_hotswap_configs_compatible(model.active_peft_config, config) + + state_dict = load_peft_weights(model_name_or_path, device=torch_device, **kwargs) + + ########################### + # LOAD & REMAP STATE_DICT # + ########################### + + parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type] + peft_model_state_dict = _insert_adapter_name_into_state_dict( + state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix + ) + + hotswap_adapter_from_state_dict( + model=model, + state_dict=peft_model_state_dict, + adapter_name=adapter_name, + parameter_prefix=parameter_prefix, + ) diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index ae210877ef..478cfaa219 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -305,6 +305,25 @@ def _find_mismatched_keys( return peft_model_state_dict, mismatched +def _insert_adapter_name_into_state_dict( + state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str +) -> dict[str, torch.Tensor]: + """Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name.""" + peft_model_state_dict = {} + for key, val in state_dict.items(): + if parameter_prefix in key: + suffix = key.split(parameter_prefix)[1] + if "." in suffix: + suffix_to_replace = ".".join(suffix.split(".")[1:]) + key = key.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") + else: + key = f"{key}.{adapter_name}" + peft_model_state_dict[key] = val + else: + peft_model_state_dict[key] = val + return peft_model_state_dict + + def set_peft_model_state_dict( model, peft_model_state_dict, @@ -342,21 +361,7 @@ def set_peft_model_state_dict( else: state_dict = peft_model_state_dict - if config.peft_type in ( - PeftType.LORA, - PeftType.LOHA, - PeftType.LOKR, - PeftType.ADALORA, - PeftType.IA3, - PeftType.OFT, - PeftType.POLY, - PeftType.LN_TUNING, - PeftType.BOFT, - PeftType.VERA, - PeftType.FOURIERFT, - PeftType.HRA, - PeftType.VBLORA, - ): + if config.peft_type in PEFT_TYPE_TO_PREFIX_MAPPING: peft_model_state_dict = {} parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type] if config.peft_type == PeftType.VBLORA and config.save_only_topk_weights: @@ -386,17 +391,10 @@ def set_peft_model_state_dict( # delete the topk_indices and topk_weights from the state_dict del state_dict[k] del state_dict[k.replace("_topk_indices", "_topk_weights")] - for k, v in state_dict.items(): - if parameter_prefix in k: - suffix = k.split(parameter_prefix)[1] - if "." 
in suffix: - suffix_to_replace = ".".join(suffix.split(".")[1:]) - k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}") - else: - k = f"{k}.{adapter_name}" - peft_model_state_dict[k] = v - else: - peft_model_state_dict[k] = v + + peft_model_state_dict = _insert_adapter_name_into_state_dict( + state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix + ) if config.peft_type == PeftType.ADALORA: rank_pattern = config.rank_pattern diff --git a/tests/run_compiled_diffusion_model_hotswap.py b/tests/run_compiled_diffusion_model_hotswap.py new file mode 100644 index 0000000000..80787a0fab --- /dev/null +++ b/tests/run_compiled_diffusion_model_hotswap.py @@ -0,0 +1,143 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This is a standalone script that checks that we can hotswap a LoRA adapter on a compiled model. + +By itself, this script is not super interesting, but when we collect the compile logs, we can check that hotswapping +does not trigger recompilation. This is done in the TestHotSwapping class in test_initialization.py. + +Running this script with `check_hotswap(False)` will load the LoRA adapter without hotswapping, which will result in +recompilation. + +There is an equivalent test in diffusers, see https://github.com/huggingface/diffusers/pull/9453.
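+
+The test invokes this script roughly as follows (the TORCH_LOGS value is what TestHotSwapping sets; the positional
+argument toggles hotswapping, "1" for on, "0" for off):
+
+    TORCH_LOGS="guards,recompiles" python tests/run_compiled_diffusion_model_hotswap.py 1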
+ +""" + +import os +import sys +import tempfile + +import torch +from diffusers import StableDiffusionPipeline, UNet2DConditionModel +from diffusers.utils.testing_utils import floats_tensor + +from peft import LoraConfig, get_peft_model_state_dict +from peft.tuners.tuners_utils import BaseTunerLayer + + +torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + +def get_small_unet(): + # from diffusers UNet2DConditionModelTests + # TODO: This appears not to work yet in full pipeline context, see: + # https://github.com/huggingface/diffusers/pull/9453#issuecomment-2418508871 + torch.manual_seed(0) + init_dict = { + "block_out_channels": (4, 8), + "norm_num_groups": 4, + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 8, + "attention_head_dim": 2, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 16, + } + model = UNet2DConditionModel(**init_dict) + return model.to(torch_device) + + +def get_unet_lora_config(): + # from diffusers test_models_unet_2d_condition.py + rank = 4 + unet_lora_config = LoraConfig( + r=rank, + lora_alpha=rank, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + return unet_lora_config + + +def get_dummy_input(): + # from UNet2DConditionModelTests + batch_size = 4 + num_channels = 4 + sizes = (16, 16) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + +def get_lora_state_dicts(modules_to_save): + state_dicts = {} + for module_name, module in modules_to_save.items(): + if module is not None: + state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) + return state_dicts + + +def set_lora_device(model, adapter_names, device): + # copied from diffusers LoraBaseMixin.set_lora_device + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + for adapter_name in adapter_names: + module.lora_A[adapter_name].to(device) + module.lora_B[adapter_name].to(device) + # this is a param, not a module, so device placement is not in-place -> re-assign + if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None: + if adapter_name in module.lora_magnitude_vector: + module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[adapter_name].to( + device + ) + + +def check_hotswap(do_hotswap): + dummy_input = get_dummy_input() + unet = get_small_unet() + lora_config = get_unet_lora_config() + unet.add_adapter(lora_config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + lora_state_dicts = get_lora_state_dicts({"unet": unet}) + StableDiffusionPipeline.save_lora_weights( + save_directory=tmp_dirname, safe_serialization=True, **lora_state_dicts + ) + del unet + + unet = get_small_unet() + file_name = os.path.join(tmp_dirname, "pytorch_lora_weights.safetensors") + unet.load_attn_procs(file_name) + unet = torch.compile(unet, mode="reduce-overhead") + unet(**dummy_input)["sample"] + + if do_hotswap: + unet.load_attn_procs(file_name, adapter_name="default_0", hotswap=True) + else: + # offloading the old and loading the new adapter will result in recompilation + set_lora_device(unet, adapter_names=["default_0"], device="cpu") + unet.load_attn_procs(file_name, 
adapter_name="other_name", hotswap=False) + + # we need to call forward to potentially trigger recompilation + unet(**dummy_input)["sample"] + + +if __name__ == "__main__": + # check_hotswap(False) will trigger recompilation + check_hotswap(do_hotswap=sys.argv[1] == "1") diff --git a/tests/run_compiled_model_hotswap.py b/tests/run_compiled_model_hotswap.py new file mode 100644 index 0000000000..1ec41b456e --- /dev/null +++ b/tests/run_compiled_model_hotswap.py @@ -0,0 +1,69 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This is a standalone script that checks that we can hotswap a LoRA adapter on a compiled model + +By itself, this script is not super interesting but when we collect the compile logs, we can check that hotswapping +does not trigger recompilation. This is done in the TestLoraHotSwapping class in test_pipelines.py. + +Running this script with `check_hotswap(False)` will load the LoRA adapter without hotswapping, which will result in +recompilation. + +""" + +import os +import sys +import tempfile + +import torch +from transformers import AutoModelForCausalLM + +from peft import LoraConfig, PeftModel, get_peft_model +from peft.utils import infer_device +from peft.utils.hotswap import hotswap_adapter + + +torch_device = infer_device() + + +def check_hotswap(do_hotswap=True): + torch.manual_seed(0) + inputs = torch.arange(10).view(-1, 1).to(torch_device) + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM").to(torch_device) + config = LoraConfig(init_lora_weights=False) + model = get_peft_model(model, config, adapter_name="adapter0").eval() + model.add_adapter("adapter1", config) + + with tempfile.TemporaryDirectory() as tmp_dirname: + model.save_pretrained(tmp_dirname) + del model + + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM").to(torch_device) + model = PeftModel.from_pretrained(model, os.path.join(tmp_dirname, "adapter0")).eval() + model = torch.compile(model, mode="reduce-overhead") + model(inputs).logits + + # swap and check that we get the output from adapter1 + if do_hotswap: + hotswap_adapter(model, os.path.join(tmp_dirname, "adapter1"), adapter_name="default") + else: + model.load_adapter(os.path.join(tmp_dirname, "adapter1"), adapter_name="other") + model.set_adapter("other") + + # we need to call forward to potentially trigger recompilation + model(inputs).logits + + +if __name__ == "__main__": + # check_hotswap(False) will trigger recompilation + check_hotswap(do_hotswap=sys.argv[1] == "1") diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 7284acbb98..1b1cc2f1f3 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os +import platform import re +import subprocess +import sys from contextlib import contextmanager from copy import deepcopy from unittest.mock import patch @@ -27,6 +31,7 @@ from peft import ( AdaLoraConfig, + IA3Config, LoraConfig, PeftMixedModel, PeftModel, @@ -44,6 +49,7 @@ set_peft_model_state_dict, ) from peft.utils import infer_device +from peft.utils.hotswap import hotswap_adapter class TestLoraInitialization: @@ -1544,3 +1550,300 @@ def new_state_dict(): model = PeftModel.from_pretrained(model, tmp_path) assert any(msg in str(w.message) for w in recwarn.list) assert any(missing_key in str(w.message) for w in recwarn.list) + + +@pytest.mark.skipif( + platform.system() != "Linux", reason="Out of the box, torch.compile does not work on Windows or MacOS" +) +class TestHotSwapping: + """Tests for the hotswapping function""" + + torch_device = infer_device() + + def compile(self, model, do_compile): + if not do_compile: + return model + return torch.compile(model) + + def get_model(self): + class MLP(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.lin0 = nn.Linear(10, 20, bias=True) + self.relu = nn.ReLU() + self.lin1 = nn.Linear(20, 5, bias=False) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + return X + + torch.manual_seed(0) + return MLP().to(self.torch_device) + + # this works with all adapters except prompt learning, but we don't test all + # as it is unnecessary and would be slow + @pytest.mark.parametrize( + "config", + [ + LoraConfig(init_lora_weights=0, target_modules=["lin0"]), + LoraConfig(init_lora_weights=0, target_modules=["lin0", "lin1"]), + ], + ) + @pytest.mark.parametrize("do_compile", [False, True]) + def test_hotswap_works(self, config, do_compile, tmp_path): + # Load 2 different adapters and check that we can hotswap between them, with the model optionally being + # compiled. 
+ atol, rtol = 1e-4, 1e-4 + inputs = torch.rand(3, 10).to(self.torch_device) + + # create adapter 0 + model = self.get_model() + torch.manual_seed(0) + model = get_peft_model(model, config) + model = self.compile(model, do_compile=do_compile) + model.eval() + with torch.inference_mode(): + output0 = model(inputs) + model.save_pretrained(tmp_path / "adapter0") + + del model + + # create adapter 1 + model = self.get_model() + torch.manual_seed(1) + model = get_peft_model(model, config) + model = self.compile(model, do_compile=do_compile) + model.eval() + with torch.inference_mode(): + output1 = model(inputs) + model.save_pretrained(tmp_path / "adapter1") + + # sanity check: they're not the same + assert not torch.allclose(output0, output1, atol=atol, rtol=rtol) + + del model + + # load adapter 0 + model = self.get_model() + model = PeftModel.from_pretrained(model, tmp_path / "adapter0") + model = self.compile(model, do_compile=do_compile) + with torch.inference_mode(): + output_loaded0 = model(inputs) + + # sanity check: same output after loading for adapter 0 + assert torch.allclose(output0, output_loaded0, atol=atol, rtol=rtol) + + # hotswap with adapter 1 + hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default") + with torch.inference_mode(): + output_loaded1 = model(inputs) + + # real check: model now behaves like adapter 1 + assert torch.allclose(output1, output_loaded1, atol=atol, rtol=rtol) + + # hotswap back to adapter 0 + hotswap_adapter(model, tmp_path / "adapter0", adapter_name="default") + with torch.inference_mode(): + output_loaded_back0 = model(inputs) + + # real check: model now behaves again like adapter 0 + assert torch.allclose(output0, output_loaded_back0, atol=atol, rtol=rtol) + + def test_hotswap_incompatible_config_params_raises(self, tmp_path): + # When the configs of the two adapters are incompatible, an error is raised + config0 = LoraConfig(target_modules=["lin0"], lora_alpha=1.0) + config1 = LoraConfig(target_modules=["lin0"], lora_alpha=2.0) + + model = self.get_model() + model = get_peft_model(model, config0) + model.save_pretrained(tmp_path / "adapter0") + del model + + model = self.get_model() + model = get_peft_model(model, config1) + model.save_pretrained(tmp_path / "adapter1") + del model + + # load adapter 0 + model = self.get_model() + model = PeftModel.from_pretrained(model, tmp_path / "adapter0") + + msg = r"Configs are incompatible: for lora_alpha, 1.0 != 2.0" + with pytest.raises(ValueError, match=msg): + hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default") + + def test_hotswap_different_peft_types_raises(self, tmp_path): + # When the configs of the two adapters are different PEFT methods, raise + config0 = LoraConfig(target_modules=["lin0"]) + config1 = IA3Config(target_modules=["lin0"], feedforward_modules=[]) + + model = self.get_model() + model = get_peft_model(model, config0) + model.save_pretrained(tmp_path / "adapter0") + del model + + model = self.get_model() + model = get_peft_model(model, config1) + model.save_pretrained(tmp_path / "adapter1") + del model + + # load adapter 0 + model = self.get_model() + model = PeftModel.from_pretrained(model, tmp_path / "adapter0") + + msg = r"Incompatible PEFT types found: LORA and IA3" + with pytest.raises(ValueError, match=msg): + hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default") + + def test_hotswap_wrong_peft_types_raises(self, tmp_path): + # Only LoRA is supported at the moment + config0 = IA3Config(target_modules=["lin0"], feedforward_modules=[]) + 
config1 = IA3Config(target_modules=["lin0"], feedforward_modules=[]) + + model = self.get_model() + model = get_peft_model(model, config0) + model.save_pretrained(tmp_path / "adapter0") + del model + + model = self.get_model() + model = get_peft_model(model, config1) + model.save_pretrained(tmp_path / "adapter1") + del model + + # load adapter 0 + model = self.get_model() + model = PeftModel.from_pretrained(model, tmp_path / "adapter0") + + msg = r"Hotswapping only supports LORA but IA3 was passed" + with pytest.raises(ValueError, match=msg): + hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default") + + def test_hotswap_missing_key_raises(self, tmp_path): + # When a key is missing, raise + config = LoraConfig(target_modules=["lin0", "lin1"]) + + model = self.get_model() + model = get_peft_model(model, config) + model.save_pretrained(tmp_path / "adapter0") + del model + + model = self.get_model() + model = get_peft_model(model, config) + + # remove one key from the state_dict + key = "base_model.model.lin1.lora_A.default.weight" + state_dict = model.state_dict() + del state_dict[key] + model.state_dict = lambda: state_dict + model.save_pretrained(tmp_path / "adapter1") + del model + + # load adapter 0 + model = self.get_model() + model = PeftModel.from_pretrained(model, tmp_path / "adapter0") + + msg = f"Hot swapping the adapter did not succeed. Missing keys: {key}" + with pytest.raises(RuntimeError, match=msg): + hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default") + + def test_hotswap_extra_key_raises(self, tmp_path): + # When there is an extra key, raise + config = LoraConfig(target_modules=["lin0"]) + + model = self.get_model() + model = get_peft_model(model, config) + model.save_pretrained(tmp_path / "adapter0") + del model + + model = self.get_model() + model = get_peft_model(model, config) + + # add an unexpected key + state_dict = model.state_dict() + new_key = "base_model.model.lin1.lora_A.default.weight" + state_dict[new_key] = torch.zeros(8, 20) + model.state_dict = lambda: state_dict + model.save_pretrained(tmp_path / "adapter1") + del model + + # load adapter 0 + model = self.get_model() + model = PeftModel.from_pretrained(model, tmp_path / "adapter0") + + msg = f"Hot swapping the adapter did not succeed. 
Unexpected keys: {new_key}" + with pytest.raises(RuntimeError, match=msg): + hotswap_adapter(model, tmp_path / "adapter1", adapter_name="default") + + def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): + env = os.environ.copy() + env["TORCH_LOGS"] = "guards,recompiles" + here = os.path.dirname(__file__) + file_name = os.path.join(here, "run_compiled_model_hotswap.py") + + process = subprocess.Popen( + [sys.executable, file_name, "1"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + # Communicate will read the output and error streams, preventing deadlock + stdout, stderr = process.communicate() + exit_code = process.returncode + + # sanity check: + assert exit_code == 0 + + # check that the recompilation message is not present + assert "__recompiles" not in stderr.decode() + + # contingency check: without hotswapping, we *do* get recompilation + process = subprocess.Popen( + [sys.executable, file_name, "0"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + # Communicate will read the output and error streams, preventing deadlock + stdout, stderr = process.communicate() + exit_code = process.returncode + + # sanity check: + assert exit_code == 0 + + # check that the recompilation message is present + assert "__recompiles" in stderr.decode() + + @pytest.mark.xfail(strict=True, reason="Requires hotswap to be implemented in diffusers") + def test_hotswapping_compiled_diffusion_model_does_not_trigger_recompilation(self): + env = os.environ.copy() + env["TORCH_LOGS"] = "guards,recompiles" + here = os.path.dirname(__file__) + file_name = os.path.join(here, "run_compiled_diffusion_model_hotswap.py") + + process = subprocess.Popen( + [sys.executable, file_name, "1"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + # Communicate will read the output and error streams, preventing deadlock + stdout, stderr = process.communicate() + exit_code = process.returncode + + # sanity check: + assert exit_code == 0 + + # check that the recompilation message is not present + assert "__recompiles" not in stderr.decode() + + # contingency check: without hotswapping, we *do* get recompilation + process = subprocess.Popen( + [sys.executable, file_name, "0"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + # Communicate will read the output and error streams, preventing deadlock + stdout, stderr = process.communicate() + exit_code = process.returncode + + # sanity check: + assert exit_code == 0 + + # check that the recompilation message is present + assert "__recompiles" in stderr.decode()