The provided `peft_type` 'PROMPT_TUNING' is not compatible with the `PeftMixedModel`. #2307

Comments
Could you please share what you have tried so far and how it fails? Since prompt tuning is a "prompt learning" technique, I think it is not necessary to support it via `PeftMixedModel`.
Hi @BenjaminBossan, so basically I am fine-tuning an LLM for retrieval. I can fine-tune with a fixed `TASK_PROMPT`, but I want to optimize it at the same time; otherwise it becomes no more than a constant. `TASK_PROMPT` would be the virtual tokens in the prompt tuning. I could also optimize the prompt on the base model first, and then fine-tune the model with the optimized prompt, but I think I'll get better results if I optimize both at the same time. I am also not sure if I could integrate the learned prompt. What do you think? Does it make sense?
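
(For concreteness, a minimal sketch of the fixed-prompt baseline described above; the prompt string and model id here are illustrative, not from the issue:)

```python
from transformers import AutoTokenizer

# Illustrative fixed task prompt; with plain fine-tuning it stays a constant
# string, whereas prompt tuning would replace it with trainable virtual tokens.
TASK_PROMPT = "Represent this sentence for retrieval:"

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
inputs = tokenizer(TASK_PROMPT + " some query text", return_tensors="pt")
# these inputs are fed to the model during fine-tuning; the prompt itself
# receives no gradient updates in this setup
```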
This is an interesting idea. I have never seen this before, so I have no idea how well it will work. Since this is not exactly an expected use case in PEFT, there is no out-of-the-box way to achieve this. However, with a few lines of code, I think I could create a model that fits your description:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PromptTuningConfig, TaskType, get_peft_model

model_id = "facebook/opt-125m"
base_model = AutoModelForCausalLM.from_pretrained(model_id)

lora_config = LoraConfig()  # add your LoRA settings here
lora_model = get_peft_model(base_model, lora_config)
print("Number of params of LoRA model")
lora_model.print_trainable_parameters()

# unwrap the PeftModel to get the transformer with the LoRA layers injected
lora_base_model = lora_model.base_model.model

pt_config = PromptTuningConfig(num_virtual_tokens=20, task_type=TaskType.CAUSAL_LM)
combined_model = get_peft_model(lora_base_model, pt_config)
print("Number of params after applying prompt tuning")
combined_model.print_trainable_parameters()

# prompt tuning freezes all parameters except for the prompt embedding;
# re-activate the LoRA parameters to make them trainable
for n, p in combined_model.named_parameters():
    if ".lora_" in n:
        p.requires_grad_(True)
print("Number of params after re-activating LoRA")
combined_model.print_trainable_parameters()
```

Explaining the non-obvious parts:

- `lora_model.base_model.model` unwraps the first `PeftModel` so that `get_peft_model` can be called a second time, this time with the prompt tuning config, on a model that already contains the LoRA layers.
- Applying the prompt tuning config freezes everything except the new prompt embedding, which is why the final loop sets `requires_grad` back to `True` on the LoRA parameters.
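
To sanity-check that both adapters actually train together, a minimal forward/backward pass (a sketch, assuming the snippet above has been run) could look like this:

```python
import torch

# dummy batch of token ids; using the inputs as labels gives a causal LM loss
input_ids = torch.randint(0, base_model.config.vocab_size, (2, 16))
outputs = combined_model(input_ids=input_ids, labels=input_ids)
outputs.loss.backward()

# both the prompt embedding and the re-activated LoRA weights should now
# have non-None gradients
for n, p in combined_model.named_parameters():
    if p.requires_grad:
        assert p.grad is not None, n
```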
Perhaps you can adjust this snippet to your needs and then give this a try. I haven't looked into efficiently saving and loading this model; as is, you'd probably have to save the full model.
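
If saving the full model is too heavy, one workaround (a sketch, not a supported PEFT API for this combination; the file name is arbitrary) is to save only the trainable parameters and load them back into a freshly constructed combined model:

```python
import torch

# collect only the trainable weights (prompt embedding + LoRA parameters)
trainable_state = {
    n: p.detach().cpu()
    for n, p in combined_model.named_parameters()
    if p.requires_grad
}
torch.save(trainable_state, "combined_adapters.pt")

# later: rebuild the combined model exactly as above, then restore the weights
state = torch.load("combined_adapters.pt")
missing, unexpected = combined_model.load_state_dict(state, strict=False)
assert not unexpected  # every saved key should match a parameter in the model
```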
Feature request
PROMPT_TUNING is a useful adapter and it would be great if we could combine it with LoRA.
Motivation
Lots of fine-tunes on consumer-grade hardware leverage LoRA. It would be great if we could mix prompt tuning with LoRA as plug and play.
Your contribution
I would like to submit a PR if there is interest.