From 5cda3a883c9d5de2381aebca62b87eee8be9f20a Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Fri, 1 Nov 2024 10:48:00 +0100
Subject: [PATCH] FIX: Prefix tuning with model on multiple devices (#2189)

See #2134

After introducing the usage of DynamicCache for prefix tuning, a bug could now
occur if the model is dispatched to different devices: the key and value cache
for each layer must be moved to that layer's respective device.

The new code mostly consists of code copied from transformers, to stay
consistent with how transformers solves this.
---
 src/peft/peft_model.py         |  2 ++
 src/peft/utils/__init__.py     |  1 +
 src/peft/utils/integrations.py | 50 +++++++++++++++++++++++++++
 tests/test_gpu_examples.py     | 63 +++++++++++++++++++++++++++++++++-
 4 files changed, 115 insertions(+), 1 deletion(-)
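Note: to make the failure mode concrete, here is a rough sketch of the kind of
script that triggered the bug (adapted from the new tests below; the tiny model
id and the two-GPU device map are illustrative, not taken verbatim from issue
2134):

    # Minimal repro sketch; assumes two CUDA devices are available.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PrefixTuningConfig, get_peft_model

    model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Dispatch the two decoder layers to different devices.
    device_map = {
        "model.embed_tokens": 0,
        "model.layers.0": 0,
        "model.layers.1": 1,
        "model.norm": 1,
        "lm_head": 1,
    }
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
    peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
    model = get_peft_model(model, peft_config)

    inputs = tokenizer("A list of colors: red, blue", return_tensors="pt").to(0)
    # Without this patch, the DynamicCache holding the prefix lives on a single
    # device, so attention in layers placed on the other GPU fails with a
    # device-mismatch RuntimeError during generate().
    model.generate(**inputs)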
in f"{layer}.": + layer_device_map[idx] = execution_device_map[layer] + break + for idx in range(model.config.num_hidden_layers): + if idx not in layer_device_map: + raise RuntimeError(f"layer {idx} has not been mapped to a device.") + return layer_device_map + + +# adapted from: +# https://github.com/huggingface/transformers/blob/eab6c491d439e83d5e31c660df6f7e36592eb0a2/src/transformers/cache_utils.py#L1159-L1179 +def map_cache_to_layer_device_map(model, cache) -> None: + """ + Ensure that the key and value cache of the model are on the same device as their corresponding layers. + """ + if not (isinstance(cache, transformers.Cache) and hasattr(model, "hf_device_map")): + return + + if isinstance(cache, transformers.EncoderDecoderCache): + map_cache_to_layer_device_map(model, cache.self_attention_cache) + return + + layer_device_map = get_layer_device_map(model) + for idx in range(model.config.num_hidden_layers): + layer_device = layer_device_map[idx] + cache.key_cache[idx] = cache.key_cache[idx].to(layer_device) + cache.value_cache[idx] = cache.value_cache[idx].to(layer_device) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index 2e485654e7..2a8530e4ae 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team.# +# Copyright 2023-present the HuggingFace Inc. team. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -53,6 +53,7 @@ LoftQConfig, LoraConfig, PeftModel, + PrefixTuningConfig, PromptEncoderConfig, TaskType, VeraConfig, @@ -3888,3 +3889,63 @@ def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, devi assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem) assert {p.device.type for p in model.parameters()} == {device_model} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU") +@pytest.mark.multi_gpu_tests +class TestPrefixTuning: + def test_prefix_tuning_multiple_devices_decoder_model(self): + # See issue 2134 + model_id = "hf-internal-testing/tiny-random-MistralForCausalLM" + tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") + inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda") + + device_map = { + "model.embed_tokens": 0, + "model.layers.0": 0, + "model.layers.1": 1, + "model.norm": 1, + "lm_head": 1, + } + model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map) + # sanity check, as the test passes trivially for a single device + assert len({p.device for p in model.parameters()}) > 1 + # sanity check: this should work without peft + model.generate(**inputs) # does not raise + + peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM") + model = get_peft_model(model, peft_config) + model.generate(**inputs) # does not raise + + def test_prefix_tuning_multiple_devices_encoder_decoder_model(self): + # See issue 2134 + model_id = "hf-internal-testing/tiny-random-T5Model" + tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left") + inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda") + device_map = { + "shared": 0, + "encoder.embed_tokens": 0, + "encoder.block.0": 0, + "encoder.block.1": 0, + "encoder.block.2": 1, + "encoder.block.3": 1, + "encoder.block.4": 1, + "encoder.final_layer_norm": 1, + "decoder.embed_tokens": 0, + "decoder.block.0": 0, + 
"decoder.block.1": 0, + "decoder.block.2": 1, + "decoder.block.3": 1, + "decoder.block.4": 1, + "decoder.final_layer_norm": 1, + "lm_head": 0, + } + model = AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map=device_map) + # sanity check, as the test passes trivially for a single device + assert len({p.device for p in model.parameters()}) > 1 + # sanity check: this should work without peft + model.generate(**inputs) # does not raise + + peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="SEQ_2_SEQ_LM") + model = get_peft_model(model, peft_config) + model.generate(**inputs) # does not raise