
Introducing the Iterative Trainer #737

Merged (39 commits) on Nov 2, 2023

Conversation

gaetanlop
Contributor

This PR is a follow-up to the Iterative Trainer requested in #704 and #576. It introduces a way to fine-tune models with methods that require custom steps between optimization steps.
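For context, a hedged sketch of the training pattern this targets, assuming a trainer that exposes a `step` method taking already-tokenized samples (as discussed later in this thread); `build_prompts` and `keep_sample` are hypothetical user-supplied helpers:

```python
# Hedged sketch only: `trainer.step(...)` follows the API discussed in this PR,
# while `build_prompts` and `keep_sample` are hypothetical user-supplied helpers.
def iterative_finetune(model, tokenizer, trainer, build_prompts, keep_sample, num_iterations=10):
    for _ in range(num_iterations):
        # Custom step 1: sample completions from the current model.
        prompts = tokenizer(build_prompts(), return_tensors="pt", padding=True)
        generations = model.generate(**prompts, max_new_tokens=64)

        # Custom step 2: filter/score the samples before training on them.
        kept = [ids for ids in generations if keep_sample(ids)]

        # Optimization step on the curated batch.
        trainer.step(input_ids=kept)
```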

@gaetanlop gaetanlop marked this pull request as draft September 4, 2023 21:10
@gaetanlop gaetanlop changed the title from "[WIP] Introducing the Iterative Trainer" to "[WIP] Introducing an Iterative Trainer" on Sep 4, 2023
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@gaetanlop
Contributor Author

@younesbelkada @lvwerra what do you think of the structure of the trainer? Should we add a generate function or should the generation step be done outside of the trainer?

@gaetanlop gaetanlop changed the title from "[WIP] Introducing an Iterative Trainer" to "Introducing the Iterative Trainer" on Sep 6, 2023
@gaetanlop gaetanlop marked this pull request as ready for review September 6, 2023 00:27
@gaetanlop gaetanlop marked this pull request as draft September 6, 2023 01:16
@gaetanlop
Contributor Author

Hello @younesbelkada @lvwerra, any news on this? I think this is a required step before adding an example for #704.

Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot for your great work! Design-wise the PR looks great; I just left minor questions. I expect this API to be public, so we need to add some documentation. Would you be happy to do so? I can also help you if needed.
Thanks!

Attributes:
**config** (`IterativeConfig`) -- Configuration object for IterativeTrainer.
**model** (`PreTrainedModel`) -- Model to be optimized, Hugging Face transformer model with a causal language modeling head.
Check the documentation of `PreTrainedModelWrapper` for more details.
Contributor

Suggested change
Check the documentation of `PreTrainedModelWrapper` for more details.

Contributor

Here we're using a `PreTrainedModel`, right?

raise ValueError(
    f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
)

Contributor

You can also check here if the model is an instance of `PreTrainedModel`.
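For illustration, a minimal sketch of such a check, mirroring the tokenizer check quoted above (the helper name is hypothetical):

```python
from transformers import PreTrainedModel


def _check_model(model) -> None:
    # Hedged sketch: fail early if the model is not a transformers PreTrainedModel,
    # mirroring the tokenizer check quoted above.
    if not isinstance(model, PreTrainedModel):
        raise ValueError(f"model must be a PreTrainedModel, got {type(model)}")
```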

self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")

def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
    if self.is_encoder_decoder:
Contributor

Can it be an encoder-decoder? From the docstring above it seems only decoder-based models are supported.

f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
)

# Step 1: Initialize Accelerator
Contributor

You can also add a stronger check: one can call `model.can_generate()` and raise a warning if that method returns False.
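A possible sketch of that softer guard, assuming `model` is a `transformers.PreTrainedModel` (which exposes `can_generate()`); the helper name and warning text are illustrative:

```python
import warnings


def _warn_if_cannot_generate(model) -> None:
    # Hedged sketch: warn rather than fail, since generation usually happens
    # between `step` calls rather than inside the trainer itself.
    if not model.can_generate():
        warnings.warn(
            f"{type(model).__name__} is not compatible with `.generate()`; "
            "you will not be able to sample from it between optimization steps."
        )
```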

@gaetanlop
Contributor Author

Hello @younesbelkada, thanks for your feedback. I have added the requested modifications to handle seq2seq models and added some documentation. I have also:

  • made the attention_mask and the labels optional in the step function (labels should only be needed for encoder-decoder models);
  • added a safety checker method to verify that input_ids, attention_mask, and labels are lists of tensors;
  • added a logging method to be able to log metrics during training.

A sketch of the resulting step signature is shown below. It should be ready to be merged.
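A hedged sketch of what the resulting `step` signature and safety checker could look like; the argument names follow the discussion above, but the helper and the exact implementation in the PR may differ:

```python
from typing import List, Optional

import torch


def _safety_check(name: str, tensors: Optional[List[torch.Tensor]]) -> None:
    # Every argument passed to `step` must be a list of tensors (or None).
    if tensors is not None and (
        not isinstance(tensors, list) or not all(isinstance(t, torch.Tensor) for t in tensors)
    ):
        raise ValueError(f"'{name}' must be a list of torch.Tensor, got {type(tensors)}")


def step(
    input_ids: List[torch.Tensor],
    attention_mask: Optional[List[torch.Tensor]] = None,  # optional; can default to all ones
    labels: Optional[List[torch.Tensor]] = None,  # only needed for encoder-decoder models
):
    for name, value in [("input_ids", input_ids), ("attention_mask", attention_mask), ("labels", labels)]:
        _safety_check(name, value)
    # ... pad/collate the batch, run forward/backward, optimizer and scheduler step ...
```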

@gaetanlop gaetanlop marked this pull request as ready for review September 15, 2023 03:29
Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot for your great work @gaetanlop! It looks very clean to me.
I left three comments; otherwise it looks really great. Looking forward to merging this PR!

Review comments on docs/source/trainer.mdx, trl/trainer/iterative_config.py, trl/trainer/iterative_trainer.py, and tests/test_iterative_trainer.py (outdated, resolved).
@gaetanlop
Contributor Author

Hey @younesbelkada, thanks for your feedback. I made the requested changes.

@gaetanlop
Contributor Author

gaetanlop commented Sep 27, 2023

Hi @younesbelkada, still interested in this trainer? I don't think it will be useful for #704, as the SFTTrainer will do the job for the training phase, but it would be useful for rejection sampling and for this paper from Google DeepMind (https://arxiv.org/pdf/2306.13649.pdf), which seems to be the new state of the art for LLM knowledge distillation. It uses on-policy generated samples for each optimization step.
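For reference, a hedged sketch of one rejection-sampling round built on top of such a trainer; `reward_fn` is a hypothetical scoring function, and `trainer.step` follows the API discussed in this PR:

```python
import torch


def rejection_sampling_round(model, tokenizer, trainer, reward_fn, prompt_ids, num_candidates=4):
    # Sample several on-policy completions for the prompt.
    candidates = model.generate(
        prompt_ids,
        do_sample=True,
        num_return_sequences=num_candidates,
        max_new_tokens=128,
    )
    # Keep the completion with the highest reward.
    rewards = torch.tensor(
        [reward_fn(tokenizer.decode(ids, skip_special_tokens=True)) for ids in candidates]
    )
    best = candidates[rewards.argmax()]
    # Run one optimization step on the selected sample.
    trainer.step(input_ids=[best])
```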

Member

@lvwerra lvwerra left a comment

Hi @gaetanlop, sorry for the delay; the people working on TRL have been a bit busy with other projects over the past few weeks. We plan to do a release this week, and adding this PR would be nice (if you have time, of course).

My two main points:

  • should we inherit from Trainer to get a lot of upstream functionality for free?
  • there is some preprocessing missing, I believe, to make sure we pad/truncate sequences

Let me know what you think. cc @younesbelkada



@dataclass
class IterativeSFTConfig(object):
Member

Note that we ported the CLI args to tyro since you opened the PR. It's not a big change; you can for example look at https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_config.py

Contributor Author

I have removed the IterativeSFTConfig, as it is no longer needed now that we inherit from Trainer.

Review comments on trl/trainer/iterative_sft_trainer.py (outdated, resolved).
batch_dict = {}
batch_dict.update(model_inputs)

def collator(data):
Member

I think we are missing a bit of preprocessing here:

  • truncation of long documents
  • padding if not all sequences have the same length

We could use the DataCollatorForLanguageModeling to implement some of that logic. Maybe we need to pass some additional kwargs to the step.
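For instance, a hedged sketch of how `DataCollatorForLanguageModeling` could cover the padding part, assuming a causal-LM tokenizer with a pad token set:

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

# Hedged sketch: pad variable-length sequences to the longest one in the batch.
# With mlm=False, labels are a copy of input_ids with padded positions set to -100.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

examples = [tokenizer(text) for text in ["a short example", "a noticeably longer example sentence"]]
batch = collator(examples)  # dict with padded input_ids, attention_mask and labels
```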

Contributor Author

Padding is already done inside prepare_model_inputs. I have added truncation.

Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot for your great work! Agreed with all the points shared by @lvwerra.

My comments are similar to @lvwerra's:

1- Let's inherit from transformers.Trainer to benefit from all of its features; in the current implementation, the saving / pushing-to-Hub mechanisms are missing, for instance.

2- Let's properly tokenize a batch of sequences instead of tokenizing them one by one.

3- Let's use explicit arguments in step to make sure we avoid unexpected behaviours.

4- We should use DataCollatorForLanguageModeling to properly handle padding.

5- We can't use model.is_peft_model, as model is a PreTrainedModel; I proposed an alternative to check whether the model is a PEFT model (see the sketch after this comment).

Thanks!
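Regarding point 5, a minimal sketch of one possible PEFT check that works on a plain `PreTrainedModel`; this illustrates the kind of check meant, not necessarily the exact one proposed in the review:

```python
import importlib.util


def _is_peft_model(model) -> bool:
    # Hedged sketch: `model.is_peft_model` only exists on TRL's wrapper classes,
    # so check for a peft.PeftModel instance instead, guarding the optional import.
    if importlib.util.find_spec("peft") is None:
        return False
    from peft import PeftModel

    return isinstance(model, PeftModel)
```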

Review comments on trl/trainer/iterative_sft_trainer.py (outdated, resolved).
@gaetanlop
Contributor Author

Hi @younesbelkada @lvwerra, thanks for the review. I have made the changes.

The padding part was already done inside the prepare_model_inputs function; however, truncation was not set. I have added two keyword arguments to the init (truncation_mode and max_length). If the user provides texts instead of input ids, truncation is done directly when tokenizing. In any case, we also truncate inside the prepare_model_inputs function in case the user didn't truncate their model_inputs before passing them to the step function. A sketch of the truncation logic is shown below.

I have removed the IterativeSFTConfig, as we can directly use the TrainingArguments now that we inherit from the HF Trainer.
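A hedged sketch of the truncation logic described above; the `truncation_mode` values shown here ("keep_end" / "keep_start") are an assumption:

```python
import torch

# Hedged sketch: truncate already-tokenized inputs to max_length, keeping either
# the end or the start of the sequence depending on truncation_mode.
def truncate(input_ids: torch.Tensor, max_length: int, truncation_mode: str = "keep_end") -> torch.Tensor:
    if input_ids.size(-1) <= max_length:
        return input_ids
    if truncation_mode == "keep_end":
        return input_ids[..., -max_length:]  # drop tokens from the beginning
    if truncation_mode == "keep_start":
        return input_ids[..., :max_length]   # drop tokens from the end
    raise ValueError(f"Unknown truncation_mode: {truncation_mode}")
```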

Contributor

@younesbelkada younesbelkada left a comment

Looking great, thanks! I left one question: I think that is_encoder_decoder has no effect, as it is an attribute of PreTrainedModelWrapper: https://github.com/huggingface/trl/blob/main/trl/models/modeling_value_head.py#L286
Can you try to remove it and see if the tests pass? Or, alternatively, try the change I suggested.

"When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`"
)

self.is_encoder_decoder = hasattr(model, "is_encoder_decoder")
Contributor

Suggested change
self.is_encoder_decoder = hasattr(model, "is_encoder_decoder")
self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)

Contributor Author

Indeed, that's an error

"When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`"
)

self.is_encoder_decoder = hasattr(model, "is_encoder_decoder")
Contributor

Suggested change
self.is_encoder_decoder = hasattr(model, "is_encoder_decoder")

I also wonder if this is needed in the first place?

Contributor Author

Yes, we need it to decide which default collator to use in case the user didn't specify one in the init (see the sketch below).
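For illustration, a hedged sketch of the default-collator selection this refers to; the helper name is hypothetical:

```python
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq


def _default_collator(tokenizer, is_encoder_decoder: bool):
    # Hedged sketch: pick a sensible default collator when the user did not pass one.
    if is_encoder_decoder:
        return DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100)
    return DataCollatorForLanguageModeling(tokenizer, mlm=False)
```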

@gaetanlop
Contributor Author

gaetanlop commented Nov 1, 2023

@younesbelkada It looks like the tests did not start for the last commit.

Member

@lvwerra lvwerra left a comment

Almost good to merge, I would say! One question remains regarding evaluation: I think we can probably just keep track of how many optimization steps have been run and call the Trainer evaluation when it's time.

**optimizers** (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): -- The optimizer and scheduler to use for training.
**data_collator** (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*) -- Data collator to be used for training and
passed along the dataloader.
**eval_dataset** (`datasets.Dataset`): The dataset to use for evaluation.
Member

Currently, we never evaluate, right?

Contributor Author

Indeed, the user had to call the evaluate function of the iterative trainer themselves. I have made some changes.

Comment on lines 124 to 129
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

self.optimizer, self.lr_scheduler = optimizers

Member

Isn't that redundant? The parent class should set the optimizer/scheduler, no?

Contributor Author

Indeed, thanks for catching this. I have removed it and made the necessary changes.

@gaetanlop
Contributor Author

gaetanlop commented Nov 2, 2023

I made some changes to enable simple logging and evaluation. The _maybe_log_save_evaluate function of the HF Trainer could not be used in our case, so I made a simple version of the function (sketched below) with the following behavior:

  • We log and evaluate only if the user has specified logging_steps and eval_steps, respectively (we do not have access to the number of epochs, so the step arguments are the only ones that matter).
  • Logging and evaluation are done every logging_steps and eval_steps, respectively.
  • We only keep track of the loss and the learning rate.

The tests should be fixed.
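A hedged sketch of the simplified logging/evaluation logic described above; the attribute names are illustrative and the real implementation may differ:

```python
def _maybe_log_save_evaluate(self):
    # Evaluate every `eval_steps` optimization steps, if eval_steps was provided.
    if self.args.eval_steps is not None and self.state.global_step % self.args.eval_steps == 0:
        self.evaluate(self.eval_dataset)

    # Log every `logging_steps` optimization steps, if logging_steps was provided.
    if self.args.logging_steps is not None and self.state.global_step % self.args.logging_steps == 0:
        logs = {
            # Only the running loss and the current learning rate are tracked.
            "loss": round(self.tr_loss_sum / self.args.logging_steps, 4),
            "learning_rate": self.lr_scheduler.get_last_lr()[0],
        }
        self.tr_loss_sum = 0.0
        self.log(logs)
```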

Member

@lvwerra lvwerra left a comment

Thanks! LGTM! 🚀

Contributor

@younesbelkada younesbelkada left a comment

Thanks a lot for your great work! Let's 🚢 it!

@younesbelkada younesbelkada merged commit cc1de98 into huggingface:main Nov 2, 2023
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* initial skeleton

* iterative trainer for decoder only

* iterative trainer unittest

* encoder_decoder support

* fix typo in unittest

* init

* fix typo

* fix init typo

* adding loggings and safety checker

* fixed minor issues

* doc

* table of contents update

* add test for seq2seq2 models

* change year

* adding text as step input

* precommit

* fixing typo

* run precommit

* fixing typo in safety checker

* fix text tokenization issue

* add truncate and inherit from trainer

* remove iterative config from tests

* remove iterative config from init

* fix peft model

* change truncation side based on truncation_mode

* removed iterativeconfig autodoc

* fixed typo in trainer.mdx

* remove mention of iterative config in docs

* make sure optimizer and scheduler are created

* adding max_steps to test

* remove log_stats fn

* remove compute loss

* fixing encoder decoder detection

* fix PPODecorator

* run precommit

* fix testing

* fix small typos in iterative trainer

* adapted function log and eval