
Inconsistent Parameter Mismatches After Merging PEFT and Base Models #2289

Open
2 of 4 tasks
enhulu-ms opened this issue Dec 19, 2024 · 23 comments

Comments

@enhulu-ms

enhulu-ms commented Dec 19, 2024

System Info

peft 0.14.0, transformers 4.45.2, accelerate 1.0.1, Python 3.11.9, Windows

Who can help?

@BenjaminBossan @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from transformers_custom.modeling import CustomConfig
from transformers_custom.tokenization import CustomTokenizer
from transformers_custom.multitask_model import CustomForSequenceClassificationMultitask
from peft import PeftModel
import torch    

def compare_model_params(model1, model2):
    # Extract state dictionaries
    sd1 = model1.state_dict()
    sd2 = model2.state_dict()
    
    # First, check if they have the same keys
    keys1 = set(sd1.keys())
    keys2 = set(sd2.keys())
    
    # Find parameters that are not present in both
    missing_in_model2 = keys1 - keys2
    missing_in_model1 = keys2 - keys1
    
    if missing_in_model2:
        print("Parameters missing in model2:", missing_in_model2)
    if missing_in_model1:
        print("Parameters missing in model1:", missing_in_model1)
        
    # Now compare parameters that exist in both
    mismatch_names = []
    for key in sorted(keys1.intersection(keys2)):
        param1 = sd1[key]
        param2 = sd2[key]
        
        # Check for shape mismatch
        if param1.shape != param2.shape:
            mismatch_names.append(key)
            continue
        
        # Check for value mismatch
        if not torch.allclose(param1, param2):
            print("Mismatched values for parameter:", key, f"model1: {param1}", f"model2: {param2}")
            mismatch_names.append(key)
    
    # Print out results
    if mismatch_names:
        print("Mismatched parameters:", mismatch_names)
    else:
        print("All parameters match perfectly.")

base_model_path = r"C:\models\tms\download\base2"
peft_path = r"C:\models\tms\download\adapter2"
merged_model_path = r"C:\models\tms\download\adapter2_merged\peft_merged"

config = CustomConfig.from_pretrained(
    base_model_path,
    num_labels=8,
    finetuning_task=None,
    cache_dir=None,
    revision="main",
)

base_model = CustomForSequenceClassificationMultitask.from_pretrained(
    base_model_path,
    config=config,
    cache_dir=None,
    revision="main",
) 

peft_model = PeftModel.from_pretrained(base_model, peft_path)

peft_model_merged = peft_model.merge_and_unload()
peft_model_merged.eval()

merged_config = CustomConfig.from_pretrained(
    merged_model_path,
    num_labels=8,
    finetuning_task=None,
    cache_dir=None,
    revision="main",
)

merged_model = CustomForSequenceClassificationMultitask.from_pretrained(
    merged_model_path,
    config=merged_config,
    cache_dir=None,
    revision="main",
)
merged_model.eval()

compare_model_params(peft_model_merged, merged_model)

Expected behavior

I saved the base model and the merged model (using save_pretrained) after training and calling merge_and_unload(). I also saved the PEFT model (via trainer.save_model). After loading the PEFT parameters on top of the base model and calling merge_and_unload(), I compared the newly merged model with the previously saved merged model. Some parameters do not match, and the specific mismatched parameters change from one comparison run to the next. For example, sometimes the mismatched parameters are ['classifier2.class_dense.bias', 'classifier2.class_dense.weight', ...] and other times ['custom.encoder.layer.19.attention.self.query.weight'].

How can I resolve this issue? Ideally, there should be no mismatches, or at least the mismatches should be consistent across runs.

@enhulu-ms enhulu-ms changed the title PEFT loaded parameters are randomized after merge_and_unload Inconsistent Parameter Mismatches After Merging PEFT and Base Models Dec 19, 2024
@enhulu-ms
Author

enhulu-ms commented Dec 19, 2024

Although each run produces a different set of mismatched components, I also noticed that when the same component mismatches across multiple runs, it shows the same mismatched values. For example: Mismatched values for parameter: classifier4.out_proj.bias model1: tensor([ 0.0013, -0.0003, -0.0003, -0.0004, -0.0004, -0.0005, -0.0007, -0.0010]) model2: tensor([ 2.3966e-01, 6.4979e-03, 8.9810e-04, 1.0589e-04, -2.9830e-03, -5.1880e-03, -1.1035e-02, -2.7188e-02]), where model1 is the newly merged model and model2 is the merged model saved after training.

@enhulu-ms
Author

I also noticed that just adding a second, unused PEFT model load (shown below) makes all parameters mismatch... I guess it might be an incorrect memory pointer somewhere in the PEFT implementation?

peft_model1 = PeftModel.from_pretrained(base_model, peft_path)
peft_model1_merged = peft_model1.merge_and_unload()
peft_model1_merged.eval()

@enhulu-ms
Author

enhulu-ms commented Dec 19, 2024

I noticed a really weird behavior while debugging. It seems that both PeftModel.from_pretrained(base_model, peft_path) and peft_model.merge_and_unload() will change the parameters in base_model? Below is the code to reproduce:

base_model = CustomForSequenceClassificationMultitask.from_pretrained(
    base_model_path,
    config=config,
    cache_dir=None,
    revision="main",
) 

base_model1 = CustomForSequenceClassificationMultitask.from_pretrained(
    base_model_path,
    config=config,
    cache_dir=None,
    revision="main",
) 

print("Comparing base_model and base_model1 before loading peft model")
compare_model_params(base_model, base_model1)

peft_model = PeftModel.from_pretrained(base_model, peft_path)
peft_model_merged = peft_model.merge_and_unload()
peft_model_merged.eval()

peft_model1 = PeftModel.from_pretrained(base_model, peft_path)
peft_model1_merged = peft_model1.merge_and_unload()
peft_model1_merged.eval()

print("Comparing base_model and base_model1")
compare_model_params(base_model, base_model1)

It results in mismatches in all components between base_model and base_model1... Any explanation?

@enhulu-ms
Author

enhulu-ms commented Dec 19, 2024

For the PEFT configuration, I am using modules_to_save = ["classifier","classifier2","classifier3","classifier4"], and the random mismatches happen mostly in "classifier2", "classifier3", "classifier4". For example: Mismatched parameters: ['classifier2.class_dense.bias', 'classifier2.class_dense.weight', 'classifier2.out_proj.bias', 'classifier2.out_proj.weight', 'classifier3.class_dense.bias', 'classifier3.class_dense.weight', 'classifier3.out_proj.bias', 'classifier3.out_proj.weight', 'classifier4.class_dense.bias', 'classifier4.class_dense.weight', 'classifier4.out_proj.bias', 'classifier4.out_proj.weight']

@githubnemo
Collaborator

Hey :) Thanks for raising an issue.

I noticed a really weird behavior while debugging. It seems that both PeftModel.from_pretrained(base_model, peft_path) and peft_model.merge_and_unload() will change the parameters in base_model? Below is the code to reproduce:

I think this behavior is expected and documented. This is done to save memory on large models. See from_pretrained and merge_and_unload.

Could this already explain the discrepancies you're seeing? It is not really possible for me to reproduce your setup exactly, since I don't know your exact LoRA config nor how your model behaves.
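For illustration, here is a minimal self-contained sketch of this documented in-place behavior, using facebook/opt-125m purely as a stand-in and get_peft_model instead of PeftModel.from_pretrained so that no adapter checkpoint is needed:

import copy

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
reference = copy.deepcopy(base)  # keep an untouched copy for comparison

lora_config = LoraConfig(target_modules="all-linear", init_lora_weights=False)
peft_model = get_peft_model(base, lora_config)
merged = peft_model.merge_and_unload()

# `base` is the same object tree that PEFT modified in place: after the merge,
# the weights of its targeted Linear layers differ from the untouched copy.
changed = [
    name
    for name, param in base.state_dict().items()
    if not torch.allclose(param, reference.state_dict()[name])
]
print(f"{len(changed)} parameters of the original base model object were changed in place")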

@enhulu-ms
Author

enhulu-ms commented Dec 19, 2024

@githubnemo, thanks for the explanation. Unfortunately, the in-place modification of the base model does not explain the mismatches in this case. The issue is that I get different model parameters in each session in which I load the PEFT model (load base -> apply LoRA -> merge_and_unload). If I load two models within the same session, they are identical, so the discrepancy appears between sessions, not between instances within one session. I also tried setting the random seed and the problem persists. I suspect it is related to the modules_to_save feature. The LoRA configuration I am using is below:

  • lora_rank: 128, lora_alpha: 256, lora_dropout: 0.1
  • lora target_modules: ['query', 'value', 'key', 'dense'], saved_modules: ['classifier', 'classifier2', 'classifier3', 'classifier4', 'gate_ur_linear']
  • The base model architecture is below
CustomForSequenceClassificationMultitask(
  (Custom): CustomModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(500002, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): CustomEncoder(
      (layer): ModuleList(
        (0-23): 24 x CustomLayer(
          (attention): CustomAttention(
            (self): CustomSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=False)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (gate_ur_linear): Linear(in_features=64, out_features=8, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=4096, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=1024, out_features=1024, bias=True)
      (activation): Tanh()
    )
    (rel_pos_bias): Linear(in_features=32, out_features=16, bias=False)
  )
  (classifier): CustomClassificationHead(
    (class_dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (out_proj): Linear(in_features=1024, out_features=8, bias=True)
  )
  (classifier2): CustomClassificationHead(
    (class_dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (out_proj): Linear(in_features=1024, out_features=8, bias=True)
  )
  (classifier3): CustomClassificationHead(
    (class_dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (out_proj): Linear(in_features=1024, out_features=8, bias=True)
  )
  (classifier4): CustomClassificationHead(
    (class_dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (out_proj): Linear(in_features=1024, out_features=8, bias=True)
  )
)

@githubnemo
Collaborator

githubnemo commented Dec 19, 2024

I think this is expected. If CustomForSequenceClassificationMultitask works similarly to AutoModelForSequenceClassification, then a model, say gpt2, will receive a new (untrained, freshly initialized) classification head. That would be the classifier* heads in your case. If you use the task type SEQ_CLS in your LoraConfig, the classification head(s) will automatically be added to LoraConfig.modules_to_save; thus, they will be saved, but they are not adapters and therefore are not merged onto the base model.

Therefore the apparent mystery is that you are comparing base models that have the same merged adapters but differently initialized classification heads, which of course differ.

You should not see a difference when comparing the PeftModel.from_pretrained() instances.

peft_model1 = PeftModel.from_pretrained(base_model1, peft_path)
peft_model1_merged = peft_model1.merge_and_unload()
peft_model1_merged.eval()

peft_model2 = PeftModel.from_pretrained(base_model2, peft_path)
peft_model2_merged = peft_model2.merge_and_unload()
peft_model2_merged.eval()

print("Comparing base_model1 and base_model2")
compare_model_params(base_model1, base_model2) # expecting a difference

print("Comparing peft_model1 and peft_model2")
compare_model_params(peft_model1, peft_model2) # no difference

@enhulu-ms
Author

enhulu-ms commented Dec 19, 2024

@githubnemo, CustomForSequenceClassificationMultitask does include the parameters for the classification heads, so there should not be any re-initialization of the classification heads. What I am comparing is the previously trained-and-merged model against the newly loaded-and-merged model. The trained-and-merged model is saved as CustomForSequenceClassificationMultitask, so there is no re-initialization there either. As mentioned above, within the same session, if I load two models (load base -> apply LoRA -> merge_and_unload), there is no difference between them. However, in each session I compare the freshly loaded model with the checkpoint of the trained-and-merged one, which should not change across sessions since I just load it from disk. Below is the same code attached in the issue description. In each session, the mismatched components are different. If it were an initialization issue, the mismatched components should stay the same. Do you have any explanation for that?

base_model_path = r"C:\models\tms\download\base2"
peft_path = r"C:\models\tms\download\adapter2"
merged_model_path = r"C:\models\tms\download\adapter2_merged\peft_merged"

config = CustomConfig.from_pretrained(
    base_model_path,
    num_labels=8,
    finetuning_task=None,
    cache_dir=None,
    revision="main",
)

base_model = CustomForSequenceClassificationMultitask.from_pretrained(
    base_model_path,
    config=config,
    cache_dir=None,
    revision="main",
) 

peft_model = PeftModel.from_pretrained(base_model, peft_path)
peft_model_merged = peft_model.merge_and_unload()
peft_model_merged.eval()

merged_config = CustomConfig.from_pretrained(
    merged_model_path,
    num_labels=8,
    finetuning_task=None,
    cache_dir=None,
    revision="main",
)

merged_model = CustomForSequenceClassificationMultitask.from_pretrained(
    merged_model_path,
    config=merged_config,
    cache_dir=None,
    revision="main",
)
merged_model.eval()

print("Comparing base_model and base_model1")
compare_model_params(base_model, base_model1)

@githubnemo
Collaborator

If I understand you correctly, you are wondering why the classification heads are replaced even though you are passing pretrained classifiers. That is understandable and surprising when you simply want to fine-tune the in-between layers instead of the classification heads. PEFT assumes that if your task type is classification (LoraConfig.task_type), the classifier* or score* layers should be trained as well.

What you are seeing is likely because you set the task type to SEQ_CLS, and once you use get_peft_model the classification heads are retrained as well.

Try setting LoraConfig.task_type to None or using the PeftModel class directly when training your adapters. You should get your expected behavior then.
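A small self-contained sketch of this difference (facebook/opt-125m is used purely as a stand-in; the exact head names that get added depend on the model class):

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model

for task_type in (TaskType.SEQ_CLS, None):
    base = AutoModelForSequenceClassification.from_pretrained("facebook/opt-125m", num_labels=2)
    config = LoraConfig(task_type=task_type, target_modules=["q_proj", "v_proj"])
    peft_model = get_peft_model(base, config)
    # with SEQ_CLS, PEFT adds the classification head names (e.g. "classifier"/"score")
    # to modules_to_save; with task_type=None it stays as configured (here: None)
    print(task_type, "->", peft_model.peft_config["default"].modules_to_save)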

@enhulu-ms
Author

enhulu-ms commented Dec 20, 2024

@githubnemo, when I was training the model, I did not pass task_type, so the default should be None? The pretrained model I load includes the classification head parameters as well, and I put the classification heads into modules_to_save. Do you mean that the default task_type is not None in this case? I checked the saved adapter_config.json, and task_type appears to be null.

target_modules = ["query","value","key","dense", "gate_ur_linear"]
saved_modules = ["classifier","classifier2","classifier3","classifier4"]
peft_config = LoraConfig(
r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, target_modules=target_modules,, modules_to_save=saved_modules
 )
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

below is the saved adapter_config.json

{
  "alpha_pattern": {},
  "auto_mapping": {
    "base_model_class": "CustomForSequenceClassificationMultitask",
    "parent_library": "transformers_custom.multitask_model"
  },
  "base_model_name_or_path": "<local>",
  "bias": "none",
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 256,
  "lora_bias": false,
  "lora_dropout": 0.1,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "classifier",
    "classifier2",
    "classifier3",
    "classifier4"
  ],
  "peft_type": "LORA",
  "r": 128,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "key",
    "query",
    "value",
    "dense"
  ],
  "task_type": null,
  "use_dora": false,
  "use_rslora": false
}

@githubnemo
Collaborator

Ah, yes. Thanks for providing more context. I think you assume that modules_to_save means whichever modules you put in will simply be saved. That is not the case, though:

modules_to_save: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include model’s custom head that is randomly initialized for the fine-tuning task.

So everything you pass in that list will be trained again. It has the same effect as passing the classification task type, for example, which automatically adds the classification head to modules_to_save, which is why I mentioned the task_type parameter.

Not passing the classification heads to modules_to_save should give you the expected behavior.
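A hedged sketch of this suggestion, reusing the module names from the config posted earlier (keep the pretrained heads out of modules_to_save so PEFT does not create separate trainable copies of them):

from peft import LoraConfig

peft_config = LoraConfig(
    r=128,
    lora_alpha=256,
    lora_dropout=0.1,
    target_modules=["query", "value", "key", "dense", "gate_ur_linear"],
    modules_to_save=None,  # no classifier/classifier2/classifier3/classifier4 here
    task_type=None,
)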

@enhulu-ms
Author

@githubnemo I do expect that everything I put in modules_to_save is trained. I did train them and saved them. However, that does not explain why I get different values in each session when I load them from a checkpoint...

@BenjaminBossan
Member

@enhulu-ms Thanks for the patience in explaining your problem, which indeed looks bizarre. It is very hard for us to debug it properly since you use a custom model with custom layers. Would it be possible for you to create a reproducer that uses an open model and go through the steps? Please be very precise in describing how to call the scripts, since you mention that the discrepancy appears to be session based.

Moreover, you showed a numerical example of the bias differing, and I noticed that the values were all very close to 0. Did you see other numerical differences that are more substantial or could this be somehow related to numerical imprecision?
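One way to check the imprecision hypothesis is to rerun the comparison with explicit, looser tolerances; a small helper sketch (torch.allclose defaults to rtol=1e-5, atol=1e-8):

import torch

def count_mismatches(model1, model2, rtol=1e-4, atol=1e-6):
    # like compare_model_params above, but with explicit tolerances so that
    # tiny floating point differences are not reported as mismatches
    sd1, sd2 = model1.state_dict(), model2.state_dict()
    mismatched = []
    for key in sorted(sd1.keys() & sd2.keys()):
        if sd1[key].shape != sd2[key].shape or not torch.allclose(sd1[key], sd2[key], rtol=rtol, atol=atol):
            mismatched.append(key)
    return mismatched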

@enhulu-ms
Author

@BenjaminBossan, the numerical example is just to illustrate the issue. What I care about is the accuracy of the model, where I do see a significant drop compared with the saved merged model. About using open models, do you mean using the default classification head for a single task?

@enhulu-ms
Author

@BenjaminBossan, the discrepancy does appear in each session. Within the same session, if I load the same checkpoint multiple times, the results are consistent.

@BenjaminBossan
Member

@enhulu-ms What I mean is: Would it be possible to replace your custom model with an openly available model (e.g. one of the Llama models) and share a self-contained script that illustrates the issue?

@enhulu-ms
Author

@BenjaminBossan, I found an even easier way to reproduce. No training is needed: I load base -> apply LoRA checkpoint -> merge_and_unload to get merged_model_1 and save it; then, in a second session, I do the same (load base -> apply LoRA checkpoint -> merge_and_unload) to get merged_model_2, load the saved merged_model_1, and compare it with merged_model_2. There is a mismatch! You can use my script to do so, no need to even train...

@enhulu-ms
Author

enhulu-ms commented Jan 8, 2025

@BenjaminBossan, if you could provide an example base model file, its LoRA adapter file, and the code to load the base model and apply the LoRA adapter, then I can do the rest to reproduce. I would prefer a small model, e.g. BERT.
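For reference, a minimal sketch of what such a repro setup could look like, using bert-base-uncased purely as a stand-in and placeholder paths (no training involved, so the adapter is randomly initialized):

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, PeftModel, get_peft_model

# create and save a small base model with a classification head
base = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=8)
base.save_pretrained("./repro/base")

# create and save a LoRA adapter for it
lora_config = LoraConfig(
    target_modules=["query", "key", "value", "dense"],
    init_lora_weights=False,  # non-zero LoRA_B so that merging actually changes the weights
)
peft_model = get_peft_model(base, lora_config)
peft_model.save_pretrained("./repro/adapter")

# load the base model and apply the LoRA adapter
base_reloaded = AutoModelForSequenceClassification.from_pretrained("./repro/base")
peft_reloaded = PeftModel.from_pretrained(base_reloaded, "./repro/adapter")
merged = peft_reloaded.merge_and_unload()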

@BenjaminBossan
Member

Thanks for describing the exact steps @enhulu-ms. Based on this, I created a script that walks through these exact steps. I re-used your compare_model_params method from above.

# script.py

import copy
import sys

import torch
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model

# this function is the same as posted in the issue
def compare_model_params(model1, model2):
    # Extract state dictionaries
    sd1 = model1.state_dict()
    sd2 = model2.state_dict()

    # First, check if they have the same keys
    keys1 = set(sd1.keys())
    keys2 = set(sd2.keys())

    # Find parameters that are not present in both
    missing_in_model2 = keys1 - keys2
    missing_in_model1 = keys2 - keys1

    if missing_in_model2:
        print("Parameters missing in model2:", missing_in_model2)
    if missing_in_model1:
        print("Parameters missing in model1:", missing_in_model1)

    # Now compare parameters that exist in both
    mismatch_names = []
    for key in sorted(keys1.intersection(keys2)):
        param1 = sd1[key]
        param2 = sd2[key]

        # Check for shape mismatch
        if param1.shape != param2.shape:
            mismatch_names.append(key)
            continue

        # Check for value mismatch
        if not torch.allclose(param1, param2):
            print("Mismatched values for parameter:", key, f"model1: {param1}", f"model2: {param2}")
            mismatch_names.append(key)

    # Print out results
    if mismatch_names:
        print("Mismatched parameters:", mismatch_names)
    else:
        print("All parameters match perfectly.")

def create_lora_adapter(model_id, path):
    torch.manual_seed(0)
    # load a model with a classification head
    base_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
    base_model_copy = copy.deepcopy(base_model)

    # create a LoRA adapter
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        target_modules="all-linear",
        init_lora_weights=False,  # to ensure that LoRA_B is not zeros
    )
    model = get_peft_model(base_model, lora_config)

    # merge the LoRA weights and save
    merged_and_unloaded_1 = model.merge_and_unload()
    print("Sanity check: Comparing the base model to the merged PEFT model: The parameters should be different")
    compare_model_params(base_model_copy, merged_and_unloaded_1)

    merged_and_unloaded_1.save_pretrained(path)
    print(f"Saved the merged model to {path}")

def load_lora_adapter(model_id, path):
    # load the model from the checkpoint
    torch.manual_seed(0)
    merged_and_unloaded_1 = AutoModelForSequenceClassification.from_pretrained(path)
    print("Loaded the merged model")

    # re-create the model as previously
    torch.manual_seed(0)
    base_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        target_modules="all-linear",
        init_lora_weights=False,
    )
    model = get_peft_model(base_model, lora_config)
    merged_and_unloaded_2 = model.merge_and_unload()

    print("The loaded model and the re-created model should have the same parameters")
    compare_model_params(merged_and_unloaded_1, merged_and_unloaded_2)

if __name__ == "__main__":
    model_id = sys.argv[1]
    path = sys.argv[2]
    mode = sys.argv[3]
    if mode == "create":
        create_lora_adapter(model_id, path)
    elif mode == "load":
        load_lora_adapter(model_id, path)
    else:
        raise ValueError("Invalid mode. Use 'create' or 'load'.")

As you can see, this script mirrors the 2 sessions that you mentioned. First, call this script in "create" mode, where it creates the merged model. The first argument is the model id, the second is the path where to save the model, and the third is the mode:

python script.py "facebook/opt-125m" /tmp/peft/2289 create

Note that here I perform a sanity check that the merged model must be different from the base model, or else the check would pass trivially.

Then, in the 2nd session, we use the "load" mode to load the merged model and compare it to a newly created one using the same random seed:

python script.py "facebook/opt-125m" /tmp/peft/2289 load

When we compare these models, we should expect them to be the same, and this is indeed the case when I run this locally.

@enhulu-ms
Author

@BenjaminBossan, thanks for trying to reproduce. This is a really weird issue. If I use newly initialized LoRA weights, I cannot reproduce it either; the key to reproducing it seems to be using the trained LoRA adapter. I ran the following four experiments to check whether the newly merged model has any mismatches with the saved merged model, running each experiment in three sessions with the same settings:

  • newly initialized base + newly initialized LoRA: pass/pass/pass
  • original base + newly initialized LoRA: pass/pass/pass
  • original base + trained LoRA: pass/fail/fail
  • newly initialized base + trained LoRA: pass/fail/fail

So the trained adapter seems to play an important role in reproducing the issue. It is not always reproducible, but so far at least one out of three sessions hits it...

@enhulu-ms
Author

And the discrepancy happens mostly in the classification heads, but never in classifier, only in classifier2/classifier3/classifier4... I also tried saving the adapter again after the discrepancy appeared, and the newly saved adapter is identical to the previously saved one...

@BenjaminBossan
Member

You need to provide more details on what exactly you did or else I cannot try to reproduce the error. Did you modify my script and work based on that? If so, please share your modification and describe each step precisely.

Note that if you add a training step in my script inside of create_lora_adapter, the weights are expected to be different because inside of load_lora_adapter, we compare to a freshly created LoRA adapter.

@enhulu-ms
Author

Unfortunately, I cannot share the training code or the model in this case... Once the LoRA adapter is trained, I use the following code:

from peft import PeftModel, get_peft_model
from transformers_custom.multitask_model import CustomSequenceClassificationMultitask
import torch    
import random

import os
import shutil
# random.seed(10)
# print(random.random())

def compare_model_params(model1, model2):
    # Extract state dictionaries
    sd1 = model1.state_dict()
    sd2 = model2.state_dict()
    
    # First, check if they have the same keys
    keys1 = set(sd1.keys())
    keys2 = set(sd2.keys())
    
    # Find parameters that are not present in both
    missing_in_model2 = keys1 - keys2
    missing_in_model1 = keys2 - keys1
    
    if missing_in_model2:
        print("Parameters missing in model2:", missing_in_model2)
    if missing_in_model1:
        print("Parameters missing in model1:", missing_in_model1)
        
    # Now compare parameters that exist in both
    mismatch_names = []
    for key in sorted(keys1.intersection(keys2)):
        param1 = sd1[key]
        param2 = sd2[key]
        
        # Check for shape mismatch
        if param1.shape != param2.shape:
            mismatch_names.append(key)
            continue
        
        # Check for value mismatch
        if not torch.allclose(param1, param2):
            print("Mismatched values for parameter:", key, f"model1: {param1}", f"model2: {param2}")
            mismatch_names.append(key)
    
    # Print out results
    if len(mismatch_names)>0:
        print("Mismatched parameters:", mismatch_names)
    else:
        print("All shared parameters match perfectly.")

base_model_path = r"C:\models\base"
lora_model_path = r"C:\models\adapter1"
merged_model_path = r"C:\models\adapter1_merged"
merged_newly_model_path =r"c:\models\adapter1_merged_newly"
lora_model_save_path = r"C:\models\adapter1_newly_saved"
if os.path.exists(lora_model_save_path):
    shutil.rmtree(lora_model_save_path)

print(f"Load base model from {base_model_path}")
base_model = CustomSequenceClassificationMultitask.from_pretrained(base_model_path)
print(f"Load peft model from {lora_model_path}")
peft_model = PeftModel.from_pretrained(base_model, lora_model_path)

print(f"Save peft model to {lora_model_save_path}")
peft_model.save_pretrained(lora_model_save_path)

print("Merge peft model")
peft_model_merged = peft_model.merge_and_unload()

print(f'Save merged model to {merged_newly_model_path}')
peft_model_merged.save_pretrained(merged_newly_model_path)

print("Load previously merged model")
peft_model_loaded = CustomSequenceClassificationMultitask.from_pretrained(merged_model_path)
print("Compare previously merged model with newly merged model")
compare_model_params(peft_model_merged, peft_model_loaded)


