state_dict keys don't match load_state_dict #865

Closed
samuele-bortolato opened this issue Oct 12, 2023 · 4 comments
@samuele-bortolato

I'm using AutoModelForCausalLMWithValueHead with a custom training loop. While trying to copy the model's weights to a target model with the same architecture, I noticed that the keys returned by state_dict don't match the keys expected by load_state_dict:

model.load_state_dict(model.state_dict())

RuntimeError: Error(s) in loading state_dict for AutoModelForCausalLMWithValueHead:
	Missing key(s) in state_dict: "pretrained_model.base_model.model.model.decoder.embed_tokens.weight", ....
	Unexpected key(s) in state_dict: "base_model.model.model.decoder.embed_tokens.weight", ....

Apparently it's just a key-prefix mismatch, and modifying the state_dict method fixes it:

def state_dict(self, *args, **kwargs):
    r"""
    Returns the state dictionary of the model. Keys of the wrapped model are
    prefixed with `pretrained_model.` and keys of the value head with `v_head.`,
    matching the module hierarchy that `load_state_dict` expects.
    """

    state_dict = {}

    if not self.is_peft_model:
        pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
    else:
        # if it is a peft model, only save the v_head
        pretrained_model_state_dict = {}
    for k, v in pretrained_model_state_dict.items():
        state_dict[f"pretrained_model.{k}"] = v

    v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
    for k, v in v_head_state_dict.items():
        state_dict[f"v_head.{k}"] = v

    return state_dict
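
For context, this matches how torch.nn.Module builds keys: every parameter is named by its attribute path, so a submodule stored as self.pretrained_model has to show up under the pretrained_model. prefix for load_state_dict to find it. A minimal sketch with toy modules (not the actual trl classes):

import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-ins for the wrapped LM and the value head
        self.pretrained_model = nn.Linear(4, 4)
        self.v_head = nn.Linear(4, 1)

print(list(Wrapper().state_dict().keys()))
# ['pretrained_model.weight', 'pretrained_model.bias', 'v_head.weight', 'v_head.bias']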

I don't know if this was done to make some high-level API work, but I feel the basic PyTorch API should still be kept compatible.

@lvwerra (Member) commented Oct 13, 2023

I remember @younesbelkada had to work on that to make the models work in trl. Maybe this could be fixed; what do you think?

@younesbelkada (Contributor)

It looks like you are trying to load from a PEFT model state dict? In that case you only need to load the v_head, since all other parameters are kept untouched, right?

model.load_state_dict(model.state_dict(), strict=False)

should do the trick.
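
With strict=False, load_state_dict skips mismatched keys instead of raising, and returns them so you can verify that nothing important was dropped. A small sketch of that check (same model object as above):

result = model.load_state_dict(model.state_dict(), strict=False)
print(result.missing_keys)     # keys the model expected but did not receive
print(result.unexpected_keys)  # keys in the state dict the model did not use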

@samuele-bortolato (Author)

Actually, I think all parameters are trained, not only the v_head, since I'm using the language model itself as the policy (I'm not doing RLHF following the tutorials; I'm doing research on RL with LLMs and have a custom training loop).
To make it work I ended up completely overriding the method, substituting it with the original state_dict() from torch.nn.Module:

import torch
from trl import AutoModelForCausalLMWithValueHead

# restore the default nn.Module implementation so state_dict keys
# match what load_state_dict expects
AutoModelForCausalLMWithValueHead.state_dict = torch.nn.Module.state_dict
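
With the patch applied, a strict round-trip works again, e.g. when syncing weights into a target network (the checkpoint name below is just an illustrative placeholder):

model = AutoModelForCausalLMWithValueHead.from_pretrained("facebook/opt-125m")
target = AutoModelForCausalLMWithValueHead.from_pretrained("facebook/opt-125m")
target.load_state_dict(model.state_dict())  # keys now match, no RuntimeError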

I'm now wondering what the reason was for changing it in the first place.


github-actions bot commented Dec 1, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
