[Multi-Adapter PPO] Fix and Refactor reward model adapter #982
Conversation
more flexible, clearer args
unwrap model since it is DDP; downside: with the reward adapter it seems we need to use find_unused_parameters=True
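For reference, the find_unused_parameters flag can be passed to DDP through Accelerate's kwargs handler; a minimal sketch of what that might look like, assuming the trainer is driven by an Accelerator (values are illustrative, not the PR's exact setup):

```python
from accelerate import Accelerator, DistributedDataParallelKwargs

# With the reward adapter attached, some parameters never receive gradients during
# the PPO update, so DDP has to be told to tolerate unused parameters.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```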
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Great clean up! Thanks a ton for working on this @mnoukhov !
I left only one question, what do you think?
trl/models/modeling_base.py (Outdated)
@@ -68,7 +66,7 @@ class PreTrainedModelWrapper(nn.Module):
        The list of arguments that are supported by the wrapper class.
    """
    transformers_parent_class = None
-   supported_args = None
+   supported_args = ("score_module", "supports_rm_adapter", "rm_adapter_name")
Do you know if these are used? It seems we always overwrite them: https://github.com/huggingface/trl/blob/main/trl/models/modeling_value_head.py#L90
I made this change so it would be a bit more obvious which args are being passed in. I'm happy to change it back if you'd like.
I actually don't think we use supported_args anywhere; I assumed it would be a future feature for some sort of arg parsing. To be more in line with transformers, we would probably want a config class instead of supported args anyway, so it isn't a big deal either way.
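For context, a rough sketch of how a supported_args tuple could drive keyword-argument routing in the wrapper; the method name and exact behaviour here are assumptions for illustration, not a verbatim copy of trl's code:

```python
class PreTrainedModelWrapper:
    # Subclasses list the kwargs the wrapper itself consumes.
    supported_args = ("score_module", "supports_rm_adapter", "rm_adapter_name")

    @classmethod
    def _split_kwargs(cls, kwargs):
        """Split kwargs into those meant for the wrapper and those for the base model."""
        supported, unsupported = {}, {}
        for key, value in kwargs.items():
            if key in (cls.supported_args or ()):
                supported[key] = value
            else:
                unsupported[key] = value
        return supported, unsupported
```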
OK I see, thanks!
Yes, it is not a big deal. Technically I set it to None so that PreTrainedModelWrapper can't be used as-is (I should have made that class abstract). It would be great if you can change this back, then I think we can merge.
Done!
trl/models/modeling_base.py (Outdated)
@@ -68,7 +66,7 @@ class PreTrainedModelWrapper(nn.Module):
        The list of arguments that are supported by the wrapper class.
    """
    transformers_parent_class = None
-   supported_args = None
+   supported_args = ("score_module", "supports_rm_adapter", "rm_adapter_name")
Suggested change:
-   supported_args = ("score_module", "supports_rm_adapter", "rm_adapter_name")
+   supported_args = None
Thanks again!
…e#982)
* reward adapter loaded as part of init: more flexible, clearer args
* fixed script for multi gpu: unwrap model since it is DDP; downside, with the reward adapter it seems we need to use find_unused_parameters=True
* remove gradient from reward score calculation
* change supported_args back to None
A simpler and cleaner version of #472 focused on the main issues:
* examples/multi_adapter_ppo.py fails with multi-GPU: model.compute_reward_score breaks if model is wrapped in DDP, so unwrap_model first
* to avoid RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one...., wrap the reward score calculation in torch.no_grad() and make sure that the reward adapter's score has requires_grad=False
* in modeling_base.py, create the reward adapter before initializing the model and pass score into init, using PeftModel.load_adapter, which includes a lot of useful logic
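A minimal sketch of the reward-scoring pattern described above; the attribute names (pretrained_model, score) and adapter names are assumptions for illustration, not the exact trl API:

```python
import torch
from torch.nn.parallel import DistributedDataParallel

def compute_reward_score(model, input_ids, attention_mask=None):
    # Unwrap the policy model if the trainer/accelerator wrapped it in DDP.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    # The reward pass must not contribute gradients; otherwise DDP complains that
    # some parameters never finished reduction before the next iteration.
    with torch.no_grad():
        model.pretrained_model.set_adapter("reward_model_adapter")  # assumed reward adapter name
        last_hidden = model.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-1]
        scores = model.score(last_hidden)  # score head passed into init, requires_grad=False
        model.pretrained_model.set_adapter("default")  # switch back to the policy adapter
    return scores
```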