Introducing the Iterative Trainer #737
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@younesbelkada @lvwerra what do you think of the structure of the trainer? Should we add a generate function or should the generation step be done outside of the trainer?
Hello @younesbelkada @lvwerra, any news on this? I think that's a required step before adding an example for #704.
Thanks a lot for your great work! Design-wise the PR looks great, I just left minor questions. I expect this API to be public, therefore we need to add some documentation; would you be happy doing so? I can also help you if needed!
Thanks!
trl/trainer/iterative_trainer.py
Outdated
Attributes:
    **config** (`IterativeConfig`) -- Configuration object for IterativeTrainer.
    **model** (`PreTrainedModel`) -- Model to be optimized, Hugging Face transformer model with a causal language modeling head.
        Check the documentation of `PreTrainedModelWrapper` for more details.
Here we're using a `PreTrainedModel`, right?
trl/trainer/iterative_trainer.py
Outdated
raise ValueError(
    f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
)
You can also check here if the model is an instance of `PreTrainedModel`.
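For example, something along these lines (the error message wording is just an illustration, not the exact text to use):

```python
from transformers import PreTrainedModel

def check_model(model):
    # Mirror the tokenizer check above: fail early if the model is not a
    # transformers PreTrainedModel.
    if not isinstance(model, PreTrainedModel):
        raise ValueError(
            f"model must be a PreTrainedModel, got {type(model)}"
        )
```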
trl/trainer/iterative_trainer.py
Outdated
self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")

def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
    if self.is_encoder_decoder:
Can it be an encoder-decoder? From the docstring above it seems only decoder-based models are supported.
trl/trainer/iterative_trainer.py
Outdated
f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}" | ||
) | ||
|
||
# Step 1: Initialize Accelerator |
You can also add a stronger check: one can check `model.can_generate()` and raise a warning if that method returns `False`.
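A minimal sketch of that check (the warning text is an assumption):

```python
import warnings

def warn_if_not_generative(model):
    # Stronger check: PreTrainedModel.can_generate() tells us whether the model
    # can be used with .generate(); warn instead of erroring out.
    if not model.can_generate():
        warnings.warn(
            f"The model {model.__class__.__name__} does not support generation; "
            "iterative training on generated samples may not behave as expected."
        )
```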
Hello @younesbelkada, thanks for your feedback. I have added the requested modifications to handle seq2seq models, along with some documentation and a few other additions. It should be ready to be merged.
Thanks a lot for your great work @gaetanlop!! Looks very clean to me!
I left three comments, otherwise it looks really great! Looking forward to merging this PR!
Hey @younesbelkada, thanks for your feedback. I made the requested changes.
Hi @younesbelkada, still interested in this trainer? I don't think this will be useful for #704 as the SFTTrainer will do the job for the training phase, but it would be useful for rejection sampling and for this paper from Google DeepMind (https://arxiv.org/pdf/2306.13649.pdf), which seems to be the new state of the art for LLM knowledge distillation. It uses on-policy generated samples at each optimization step.
Hi @gaetanlop, sorry for the delay, people working on TRL have been a bit busy with other projects the past few weeks. We plan to do a release this week and adding this PR would be nice (if you have time of course).
My two main points:
- should we inherit from `Trainer` to get a lot of upstream functionality for free? (rough sketch below)
- there is some preprocessing missing, I believe: making sure we pad/truncate sequences
Let me know what you think. cc @younesbelkada
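For the first point, a rough sketch of what inheriting from `transformers.Trainer` could look like (the `step` signature and body here are assumptions, not the final implementation):

```python
from transformers import Trainer

class IterativeSFTTrainer(Trainer):
    """Sketch: inheriting from Trainer gives saving, logging and hub pushing
    for free; only the per-batch optimization step is custom."""

    def step(self, input_ids, attention_mask=None, labels=None):
        # Assumes self.optimizer / self.lr_scheduler were created beforehand,
        # e.g. via self.create_optimizer_and_scheduler(num_training_steps=...).
        self.model.train()
        if labels is None:
            labels = input_ids  # causal LM: the model shifts labels internally
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()
        self.optimizer.zero_grad()
        return outputs.loss.detach()
```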
trl/trainer/iterative_sft_config.py
Outdated
@dataclass
class IterativeSFTConfig(object):
Note that we ported the CLI args to `tyro` since you opened the PR. Not a big change; you can for example look at https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_config.py
I have removed the `IterativeSFTConfig` as it is not useful anymore now that we inherit from `Trainer`.
batch_dict = {}
batch_dict.update(model_inputs)

def collator(data):
I think we are missing a bit of preprocessing here:
- truncation of long docs
- padding if not all sequences have the same length
We could use the `DataCollatorForLanguageModeling` to implement some of that logic. Maybe we need to pass some additional kwargs to the `step` method.
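For illustration, a rough sketch of that idea (the tokenizer and `max_length` here are just examples); truncation would happen at tokenization time, while the collator takes care of padding and labels:

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Truncate long documents when tokenizing...
texts = ["a short example", "a much longer example " * 200]
encodings = tokenizer(texts, truncation=True, max_length=512)

# ...and let the collator pad the batch; with mlm=False it also builds
# causal-LM labels (padding positions set to -100).
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
batch = collator([{"input_ids": ids} for ids in encodings["input_ids"]])
print(batch["input_ids"].shape, batch["labels"].shape)
```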
Padding is already done inside `prepare_model_inputs`. I have added truncation.
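For reference, the truncation added here follows roughly this logic (a simplified sketch; the actual implementation lives in the trainer and respects `truncation_mode`):

```python
def truncate(input_ids, max_length, truncation_mode="keep_end"):
    # Keep the last `max_length` tokens ("keep_end") or the first ones otherwise.
    if len(input_ids) <= max_length:
        return input_ids
    if truncation_mode == "keep_end":
        return input_ids[-max_length:]
    return input_ids[:max_length]
```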
Thanks a lot for your great work! Agreed with all the points shared by @lvwerra.
My comments are similar to @lvwerra's comments:
1- Let's inherit from `transformers.Trainer` to benefit from all the features of the trainer; in the current implementation the saving / pushing-to-hub mechanisms are missing, for instance.
2- Let's properly tokenize a batch of sequences instead of tokenizing them one by one.
3- Let's use explicit arguments in `step` to make sure we avoid unexpected behaviours.
4- We should use `DataCollatorForLanguageModeling` to properly handle padding.
5- We can't use `model.is_peft_model` as `model` is a `PreTrainedModel`; I proposed an alternative to check if the model is a PEFT model (see the sketch after this list).
Thanks!
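Regarding point 5, a possible version of the check (a sketch; the PR may do it slightly differently):

```python
def model_is_peft(model):
    # Detect a PEFT-wrapped model without relying on PreTrainedModelWrapper attributes.
    try:
        from peft import PeftModel
    except ImportError:
        return False
    return isinstance(model, PeftModel)
```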
Hi @younesbelkada @lvwerra, thanks for the review. I have made the changes. The padding part was already done inside `prepare_model_inputs`. I have removed the `IterativeSFTConfig`.
Looking great, thanks! I left one question: I think that `is_encoder_decoder` has no effect as it is an attribute of `PreTrainedModelWrapper`: https://github.com/huggingface/trl/blob/main/trl/models/modeling_value_head.py#L286. Can you try to remove it and see if the tests pass? Or alternatively, try the alternative I suggested.
trl/trainer/iterative_sft_trainer.py
Outdated
"When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`" | ||
) | ||
|
||
self.is_encoder_decoder = hasattr(model, "is_encoder_decoder") |
Suggested change:
- self.is_encoder_decoder = hasattr(model, "is_encoder_decoder")
+ self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
Indeed, that's an error
trl/trainer/iterative_sft_trainer.py
Outdated
"When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`" | ||
) | ||
|
||
self.is_encoder_decoder = hasattr(model, "is_encoder_decoder") |
self.is_encoder_decoder = hasattr(model, "is_encoder_decoder")
I also wonder if this is needed in the first place?
Yes, we need it to decide which collator to use in case the user didn't specify one in the init.
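i.e. roughly this logic (a sketch of the intent, not the exact code in the PR):

```python
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

def pick_default_collator(tokenizer, is_encoder_decoder, data_collator=None):
    # Keep a user-provided collator; otherwise choose one based on the architecture.
    if data_collator is not None:
        return data_collator
    if is_encoder_decoder:
        return DataCollatorForSeq2Seq(tokenizer)
    return DataCollatorForLanguageModeling(tokenizer, mlm=False)
```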
@younesbelkada Looks like the tests do not start for the last commit.
Almost good to merge, I would say! One question remaining regarding evaluation: I think we can probably just keep track of how many optimization steps have been run and call the `Trainer` evaluation when it's time.
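Something along these lines could work (a hypothetical helper; `state.global_step` and `evaluate` are the standard `Trainer` attributes):

```python
def maybe_evaluate(trainer, eval_steps):
    # Hypothetical helper: bump the optimization-step counter after each call to
    # `step` and run the regular Trainer evaluation loop every `eval_steps` steps.
    trainer.state.global_step += 1
    if eval_steps is not None and trainer.state.global_step % eval_steps == 0:
        trainer.evaluate()
```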
**optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): -- The optimizer and scheduler to use for training.
**data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and
    passed along the dataloader.
**eval_dataset** (`datasets.Dataset`): The dataset to use for evaluation.
Currently, we never evaluate, right?
Indeed, the user had to call the evaluate function of the iterative trainer. I have made some changes.
trl/trainer/iterative_sft_trainer.py
Outdated
    optimizers=optimizers,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

self.optimizer, self.lr_scheduler = optimizers
Isn't that redundant? The parent class should set the optimizer/scheduler, no?
Indeed, thanks for catching this. I have removed it and made the necessary changes.
I made some changes to enable simple logging and evaluation. The tests should be fixed.
Thanks! LGTM! 🚀
Thanks a lot for your great work! Let's 🚢 it!
* initial skeleton
* iterative trainer for decoder only
* iterative trainer unittest
* encoder_decoder support
* fix typo in unittest
* init
* fix typo
* fix init typo
* adding loggings and safety checker
* fixed minor issues
* doc
* table of contents update
* add test for seq2seq2 models
* change year
* adding text as step input
* precommit
* fixing typo
* run precommit
* fixing typo in safety checker
* fix text tokenization issue
* add truncate and inherit from trainer
* remove iterative config from tests
* remove iterative config from init
* fix peft model
* change truncation side based on truncation_mode
* removed iterativeconfig autodoc
* fixed typo in trainer.mdx
* remove mention of iterative config in docs
* make sure optimizer and scheduler are created
* adding max_steps to test
* remove log_stats fn
* remove compute loss
* fixing encoder decoder detection
* fix PPODecorator
* run precommit
* fix testing
* fix small typos in iterative trainer
* adapted function log and eval
This PR is a follow-up to the Iterative Trainer requested in #704 and #576. It introduces a way to finetune models with methods that require some steps (e.g. generation) between optimization steps.
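For reference, iterative finetuning with this trainer could look roughly like the following (a sketch based on this PR; exact argument names may differ in the released version):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import IterativeSFTTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

args = TrainingArguments(output_dir="iterative-sft", max_steps=100, per_device_train_batch_size=2)
trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer)

for _ in range(10):
    # 1. Generate samples with the current model (generation stays outside the trainer).
    prompts = ["The capital of France is"] * 2
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    generations = model.generate(**inputs, max_new_tokens=16, pad_token_id=tokenizer.eos_token_id)
    texts = tokenizer.batch_decode(generations, skip_special_tokens=True)

    # 2. Optionally filter / rank the generations here (rejection sampling, reward model, ...).

    # 3. Run an optimization pass on the selected texts.
    trainer.step(texts=texts)
```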