FIX: Prefix tuning with model on multiple devices (#2189)
See #2134

After introducing DynamicCache for prefix tuning, a bug could occur if the model
is dispatched to multiple devices, because the key and value cache for each
layer needs to be moved to that layer's respective device.

The new code is mostly copied from transformers so that PEFT stays consistent
with how transformers solves this.
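
For context, a minimal sketch of the scenario this fixes (not part of the commit):
it assumes a machine with at least two GPUs; the model id, prompt, and use of
device_map="auto" are illustrative, mirroring the decoder test added below.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PrefixTuningConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")

# dispatch the layers across the available GPUs
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)

# before this fix, generate could raise a device mismatch error, because the
# DynamicCache created in get_prompt lived on a single device
model.generate(**inputs)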
BenjaminBossan authored Nov 1, 2024
1 parent 8eeae0a commit 5cda3a8
Showing 4 changed files with 115 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/peft/peft_model.py
@@ -79,6 +79,7 @@
id_tensor_storage,
infer_device,
load_peft_weights,
map_cache_to_layer_device_map,
set_peft_model_state_dict,
shift_tokens_right,
)
@@ -742,6 +743,7 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -
past_key_values.is_updated = {
layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))
}
map_cache_to_layer_device_map(self.get_base_model(), past_key_values) # no-op if not a Cache instance
return past_key_values
else:
if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -18,6 +18,7 @@
# limitations under the License.

# from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
from .integrations import map_cache_to_layer_device_map
from .loftq_utils import replace_lora_weights_loftq
from .peft_types import PeftType, TaskType
from .other import (
50 changes: 50 additions & 0 deletions src/peft/utils/integrations.py
@@ -120,3 +120,53 @@ def get_bnb_param_type(param: torch.nn.Parameter) -> Literal[False, "4bit", "8bi
if param.__class__.__name__ == "Int8Params":
return "8bit"
return False


# adapted from:
# https://github.com/huggingface/transformers/blob/eab6c491d439e83d5e31c660df6f7e36592eb0a2/src/transformers/generation/utils.py#L1617-L1643
def get_layer_device_map(model):
"""
Derive the device map for the layers of the model.
"""
main_device = [d for d in model.hf_device_map.values() if d not in ["cpu", "disk"]][0]

execution_device_map = {
name: main_device if device in ["cpu", "disk"] else device for name, device in model.hf_device_map.items()
}

if execution_device_map is None:
return None

if len(execution_device_map) == 1 and "" in execution_device_map:
return {idx: execution_device_map[""] for idx in range(model.config.num_hidden_layers)}

layer_device_map = {}
for layer in execution_device_map:
for idx in range(model.config.num_hidden_layers):
if f".{idx}." in f"{layer}.":
layer_device_map[idx] = execution_device_map[layer]
break
for idx in range(model.config.num_hidden_layers):
if idx not in layer_device_map:
raise RuntimeError(f"layer {idx} has not been mapped to a device.")
return layer_device_map


# adapted from:
# https://github.com/huggingface/transformers/blob/eab6c491d439e83d5e31c660df6f7e36592eb0a2/src/transformers/cache_utils.py#L1159-L1179
def map_cache_to_layer_device_map(model, cache) -> None:
"""
Ensure that the key and value cache of the model are on the same device as their corresponding layers.
"""
if not (isinstance(cache, transformers.Cache) and hasattr(model, "hf_device_map")):
return

if isinstance(cache, transformers.EncoderDecoderCache):
map_cache_to_layer_device_map(model, cache.self_attention_cache)
return

layer_device_map = get_layer_device_map(model)
for idx in range(model.config.num_hidden_layers):
layer_device = layer_device_map[idx]
cache.key_cache[idx] = cache.key_cache[idx].to(layer_device)
cache.value_cache[idx] = cache.value_cache[idx].to(layer_device)
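
To make the new helpers concrete, a short hedged walk-through (not part of the
diff): it assumes model is a transformers model dispatched with an hf_device_map
like the one in the decoder test below, and past_key_values is the DynamicCache
built in get_prompt.

from peft.utils.integrations import get_layer_device_map, map_cache_to_layer_device_map

# model.hf_device_map (illustrative) might look like:
#   {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "model.norm": 1, "lm_head": 1}
# get_layer_device_map resolves one execution device per hidden layer:
layer_device_map = get_layer_device_map(model)  # -> {0: 0, 1: 1}

# map_cache_to_layer_device_map then moves each layer's key/value tensors to that
# layer's device; it is a no-op if the cache is not a transformers Cache or the
# model has no hf_device_map
map_cache_to_layer_device_map(model, past_key_values)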
63 changes: 62 additions & 1 deletion tests/test_gpu_examples.py
@@ -1,4 +1,4 @@
# Copyright 2023-present the HuggingFace Inc. team.#
# Copyright 2023-present the HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -53,6 +53,7 @@
LoftQConfig,
LoraConfig,
PeftModel,
PrefixTuningConfig,
PromptEncoderConfig,
TaskType,
VeraConfig,
@@ -3888,3 +3889,63 @@ def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, devi

assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
assert {p.device.type for p in model.parameters()} == {device_model}


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
@pytest.mark.multi_gpu_tests
class TestPrefixTuning:
def test_prefix_tuning_multiple_devices_decoder_model(self):
# See issue 2134
model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")

device_map = {
"model.embed_tokens": 0,
"model.layers.0": 0,
"model.layers.1": 1,
"model.norm": 1,
"lm_head": 1,
}
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
# sanity check, as the test passes trivially for a single device
assert len({p.device for p in model.parameters()}) > 1
# sanity check: this should work without peft
model.generate(**inputs) # does not raise

peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
model = get_peft_model(model, peft_config)
model.generate(**inputs) # does not raise

def test_prefix_tuning_multiple_devices_encoder_decoder_model(self):
# See issue 2134
model_id = "hf-internal-testing/tiny-random-T5Model"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")
device_map = {
"shared": 0,
"encoder.embed_tokens": 0,
"encoder.block.0": 0,
"encoder.block.1": 0,
"encoder.block.2": 1,
"encoder.block.3": 1,
"encoder.block.4": 1,
"encoder.final_layer_norm": 1,
"decoder.embed_tokens": 0,
"decoder.block.0": 0,
"decoder.block.1": 0,
"decoder.block.2": 1,
"decoder.block.3": 1,
"decoder.block.4": 1,
"decoder.final_layer_norm": 1,
"lm_head": 0,
}
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map=device_map)
# sanity check, as the test passes trivially for a single device
assert len({p.device for p in model.parameters()}) > 1
# sanity check: this should work without peft
model.generate(**inputs) # does not raise

peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="SEQ_2_SEQ_LM")
model = get_peft_model(model, peft_config)
model.generate(**inputs) # does not raise
