-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Before update the tr_loss, make sure tr_loss_step is in the same device. #1439
Before update the tr_loss, make sure tr_loss_step is in the same device. #1439
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a suggestion
Co-authored-by: guy1992l <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great ! Thanks so much @pengwei715 @guy1992l !
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
just double checking @kashif is this all good in your opinion? |
…ce. (huggingface#1439) * before update the loss from dpo, make sure it's in the same device of tr_loss * Update trl/trainer/dpo_trainer.py Co-authored-by: guy1992l <[email protected]> --------- Co-authored-by: guy1992l <[email protected]>
What does this PR do?
Fixes: #958
Fixes: #1399
It's a very strong assumption that
tr_loss
andtr_loss_step
are on the same device.arg.device
may not be as same as the current device.For example, the
dpo_trainer
is child class oftrainer
. If you are running a DPO training job with one node multiple GPUs. If you set thedevice_map='auto'
. Thetr_loss_step
could be on different device after computing thedpo_loss
.Based on the feedback of @guy1992l, @younesbelkada, @amyeroberts, I have open a PR in the transformers
Trainer
class to do a assert checking.huggingface/transformers#29695
This PR is put the
tr_loss_step
to theself.args.device
before update.Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@younesbelkada
@amyeroberts
@guy1992l
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.