KeyError in load_model() #717

Open
ManuelZ opened this issue Sep 16, 2024 · 2 comments
Comments


ManuelZ commented Sep 16, 2024

This line in load_model() is generating a RuntimeError:

RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", ...

This happened after running this example, replacing trainer.train(num_epochs=num_epochs) with:

start_epoch = hooks.load_latest_saved_models(trainer, model_folder, device)
trainer.train(start_epoch, num_epochs=num_epochs)

And running trainer.train(start_epoch, num_epochs=num_epochs) a second time.

Also note that I ran the example:

  • Without pip install pytorch-metric-learning[with-hooks] because that downgraded my PML.
  • With dataloader_num_workers=0.

PyTorch version: 2.4.1+cu121
PyTorch Metric Learning version: 2.6.0

ManuelZ changed the title from "load_model() catches KeyError instead of RuntimeError" to "KeyError in load_model()" on Sep 16, 2024

ManuelZ commented Sep 16, 2024

I guess my use of dataloader_num_workers=0 is causing the keys of the weights dictionary to be saved without the module. prefix. Since the trunk and embedder are wrapped in torch.nn.DataParallel, model_def.load_state_dict expects a dictionary whose keys are prepended with module..
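
For reference, here is a minimal, self-contained sketch of the mismatch and the key-remapping workaround I mean (it uses a torchvision ResNet as a stand-in for the trunk and a freshly built state_dict in place of the saved checkpoint; none of this is PML code):

```python
import torch
from torchvision import models

# Reproduce the mismatch: the model we load into is wrapped in DataParallel,
# but the checkpoint was saved from an unwrapped model, so its keys have no
# "module." prefix.
trunk = torch.nn.DataParallel(models.resnet18())

# Stand-in for torch.load("<checkpoint>.pth"): a state_dict without "module." keys.
state_dict = models.resnet18().state_dict()

# Workaround: prepend "module." to every key before calling load_state_dict.
if not all(k.startswith("module.") for k in state_dict):
    state_dict = {f"module.{k}": v for k, v in state_dict.items()}

trunk.load_state_dict(state_dict)
```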

My problem is solved by not using torch.nn.DataParallel, but I don't know if this is a behavior that could be fixed in PML or if I was just using it incorrectly.


KevinMusgrave commented Oct 15, 2024

I'm not sure what's causing this. I recommend using another library for the training loop, like Lightning, Ignite, Transformers, or Timm. You can also look at Open Metric Learning, which I think has training loops that are compatible with many parts of this library.
