Fix FSDP error #1196
Fixes error when `loss` field of model output is non-empty, and indexing as [0] returns loss instead of logits. Can happen with FSDP.
Thanks a lot for the fix! I propose to force-set `return_dict=True` in case users train their RM with `return_dict` being explicitly set to `False` in the config. Can you try that out? 🙏
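To illustrate why forcing `return_dict=True` matters here, the following is a minimal sketch using a toy forward function, not the real model or trainer code. The names `toy_forward` and `read_logits` are hypothetical; the only assumption carried over from the discussion is that a model called with `return_dict=False` hands back a plain tuple, which cannot be indexed by `"logits"`.

```python
def toy_forward(logits, loss=None, return_dict=False):
    """Hypothetical stand-in for a model forward pass.

    With return_dict=False it returns a plain tuple of the non-None
    fields; with return_dict=True it returns a mapping keyed by name.
    """
    if return_dict:
        return {"loss": loss, "logits": logits}
    return tuple(v for v in (loss, logits) if v is not None)


def read_logits(forward, logits, loss=None):
    """Always request a mapping, so the lookup by name is safe
    regardless of what a (hypothetical) config default would be."""
    return forward(logits, loss=loss, return_dict=True)["logits"]


# A plain tuple cannot be indexed by field name.
try:
    toy_forward([0.3, -1.2], return_dict=False)["logits"]
    tuple_lookup_failed = False
except TypeError:
    tuple_lookup_failed = True

safe_logits = read_logits(toy_forward, [0.3, -1.2], loss=0.7)
```

Passing the flag at the call site, rather than relying on the config, keeps the name-based lookup valid even when a user has disabled dict outputs.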
c9ed225 to 989f4a4
force return_dict Co-authored-by: Younes Belkada <[email protected]>
Yes, that works! Thanks for spotting that.
Thanks again for all your contributions @mgerstgrasser ! You are on fire!
* Fix FSDP error

  Fixes error when `loss` field of model output is non-empty, and indexing as [0] returns loss instead of logits. Can happen with FSDP.

* Apply suggestions from code review

  force return_dict

  Co-authored-by: Younes Belkada <[email protected]>

---------

Co-authored-by: Younes Belkada <[email protected]>
This fixes an error in reward training when the `loss` field of model output is non-empty, and `model(...)[0]` returns loss instead of logits. That can happen with FSDP. This PR changes it so the trainer explicitly grabs `model(...)["logits"]` instead.

Closes #1195
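The failure mode described above can be reproduced without any model. The sketch below uses a hypothetical `ToyOutput` class that mimics one behavior of transformers-style output objects: integer indexing goes through a tuple of the non-`None` fields, so position 0 shifts from logits to loss as soon as `loss` is populated. It is an illustration under that assumption, not the trainer's actual code.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class ToyOutput:
    """Hypothetical stand-in for a model output with optional loss."""
    loss: Optional[float] = None
    logits: Any = None

    def to_tuple(self):
        # Only fields that are set take part in positional indexing.
        values = (getattr(self, f.name) for f in fields(self))
        return tuple(v for v in values if v is not None)

    def __getitem__(self, key):
        if isinstance(key, str):
            return getattr(self, key)   # lookup by field name
        return self.to_tuple()[key]     # lookup by position


without_loss = ToyOutput(logits=[0.3, -1.2])
with_loss = ToyOutput(loss=0.7, logits=[0.3, -1.2])

first_without_loss = without_loss[0]    # position 0 is the logits
first_with_loss = with_loss[0]          # position 0 is now the loss
by_name = with_loss["logits"]           # unaffected by the loss field
```

Indexing by name is stable whether or not `loss` is present, which is why the lookup by `"logits"` is the more robust choice.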