
Add option to check dataset labels in SFTTrainer #1414

Closed
wants to merge 8 commits into from

Conversation

geronimi73

Hi everyone!

This PR introduces a new option check_dataset_labels in SFTTrainer. When enabled, the trainer calls the collator on the first sample of the training set and logs each token_id, the decoded token, and the corresponding label. This helps uncover mistakes early, such as setting the tokenizer's pad_token to eos_token. The idea is taken from axolotl, where such an option already exists.

Problem(s):
It's common practice to set the tokenizer's pad_token to eos_token. This is problematic because DataCollatorForCompletionOnlyLM sets the label for all occurrences of pad_token to -100. If pad=eos, then the model will never learn to output eos. This issue can be hard to debug and many people struggle with it.
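The effect can be sketched in a few lines (an illustrative toy, not the actual DataCollatorForCompletionOnlyLM implementation; the token ids are made up): if pad_token and eos_token share an id, masking pad positions also masks every eos.

```python
def mask_pad_labels(input_ids, pad_id):
    # Toy version of the pad-masking behaviour: the label becomes -100
    # (ignored by the loss) wherever the token equals the pad token.
    return [-100 if tok == pad_id else tok for tok in input_ids]

EOS_ID = 2
PAD_ID = EOS_ID  # the common (and problematic) pad=eos setup

input_ids = [1, 24126, 446, EOS_ID]  # e.g. <bos> Lu ke <eos>
labels = mask_pad_labels(input_ids, PAD_ID)
print(labels)  # the final <eos> gets label -100, so eos is never trained on
```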

Additionally, when using instruction_template and response_template, logging the tokens and labels helps verify that the labels are set to train only on the output and ignore the instruction. Furthermore, tokenizers can be complex, and the output makes it quick to spot tokenization issues such as the handling of special tokens.

Solution:

  • Add a new option check_dataset_labels to the trainer.
  • When check_dataset_labels is enabled, the trainer calls the collator on the first sample of the training set and logs each token_id, the decoded token, and the corresponding label.
  • This creates transparency and helps to uncover mistakes early.
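The logging itself could look roughly like this (a minimal sketch, not the PR's actual code; the decode callback and the toy vocabulary are hypothetical stand-ins for a real tokenizer):

```python
def log_labels(input_ids, labels, decode):
    # One line per token: token_id, decoded token, label
    # (-100 means the position is ignored by the loss).
    return [f"{tok} {decode(tok)!r} {lab}" for tok, lab in zip(input_ids, labels)]

# hypothetical token ids standing in for a real tokenizer's vocabulary
vocab = {32000: "<|im_start|>", 1792: "user", 13: "<0x0A>"}
for line in log_labels([32000, 1792, 13], [-100, -100, -100], vocab.get):
    print(line)
```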

Usage:

from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# load model and tokenizer
...

messages = [
    {"role": "user", "content": "Hello who are you?"},
    {"role": "assistant", "content": "Luke, I am your father"},
    {"role": "user", "content": "WTF"},
]
dataset = Dataset.from_list([dict(messages=messages)])

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    data_collator = DataCollatorForCompletionOnlyLM(
        instruction_template = "<|im_start|>user", 
        response_template = "<|im_start|>assistant", 
        tokenizer = tokenizer, 
        mlm = False),
    check_dataset_labels = True,
    dataset_kwargs=dict(add_special_tokens=False),
    args = TrainingArguments(output_dir = "out")
)

output:

check_dataset_labels:
<|im_start|> user
Hello who are you? <|im_end|> 
 <|im_start|> assistant
Luke, I am your father <|im_end|> 
 <|im_start|> user
WTF <|im_end|> 

32000 '<|im_start|>' -100
1792 'user' -100
13 '<0x0A>' -100
10994 'Hello' -100
1058 'who' -100
526 'are' -100
366 'you' -100
29973 '?' -100
32001 '<|im_end|>' -100
13 '<0x0A>' -100
32000 '<|im_start|>' -100
465 'ass' -100
22137 'istant' -100
13 '<0x0A>' 13
24126 'Lu' 24126
446 'ke' 446
29892 ',' 29892
306 'I' 306
626 'am' 626
596 'your' 596
4783 'father' 4783
32001 '<|im_end|>' 32001
13 '<0x0A>' 13
32000 '<|im_start|>' -100
1792 'user' -100
13 '<0x0A>' -100
29956 'W' -100
8969 'TF' -100
32001 '<|im_end|>' -100
13 '<0x0A>' -100

Related issues:

Contributor

@younesbelkada younesbelkada left a comment

Amazing work!
Can you add the docstring of that arg in SFTTrainer's docstring, together with a few lines explaining what you posted in the PR? 🙏

Contributor

@younesbelkada younesbelkada left a comment

Thanks @geronimi73
the changes look great! Can you just run the styling checks? make precommit, then we can merge

@geronimi73
Author

Thanks @geronimi73 the changes look great! Can you just run the styling checks? make precommit, then we can merge

What about the print statements I mentioned? Leave them like this?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@younesbelkada younesbelkada left a comment

Yes, I think we can leave the print statements.

trl/trainer/sft_trainer.py (outdated review comment, resolved)
Contributor

@younesbelkada younesbelkada left a comment

Hi @geronimi73 !
Again, thank you very much for this contribution. After thinking a bit, I think we should standardize and use logger.info instead of print, so that the approach is harmonized with coding practice across the HF codebase - sorry for that, as you already asked the question and I said yes! Would you mind switching the print statements to logger.info? You will need to import logging at the top of the file and import the logger properly. Let me know if you need any help, or if you think we should keep the print statements.

@younesbelkada
Copy link
Contributor

(to fix the failing tests you just have to rebase with main)

@geronimi73
Author

you will need to import logging at the top of the file and import the logger properly

I tried. It works, but I'm not sure if this is the way to do it. Please check my comment on the code.

Contributor

@younesbelkada younesbelkada left a comment

Thanks! I left one comment. Can you also run the styling checks? make precommit

trl/trainer/sft_trainer.py (outdated review comment, resolved)
geronimi73 and others added 2 commits March 14, 2024 10:09
@younesbelkada
Contributor

Hi @geronimi73 sorry for all the iteration ! can you re-run the styling checks again? 🙏

@geronimi73
Author

geronimi73 commented Mar 15, 2024

Hi @geronimi73 sorry for all the iteration ! can you re-run the styling checks again? 🙏

Sure! But I'm still wondering whether it works correctly. The problem is, I never see the output of logger.info()!

How am I supposed to set the log level (as a user of SFTTrainer)? I tried this, and it doesn't work:

import transformers
transformers.logging.set_verbosity_info()

@geronimi73
Author
Also, this one does not enable it either:

args = TrainingArguments(
    output_dir = "out",
    log_level = "info"
    )
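(For reference, a sketch using only the stdlib logging module; the logger name below assumes the PR creates its logger with __name__ inside trl/trainer/sft_trainer.py. Setting the level on that logger directly makes the INFO records visible:)

```python
import logging

# Ensure a handler exists on the root logger so records are printed at all
logging.basicConfig(level=logging.WARNING)

# A logger obtained via __name__ in trl/trainer/sft_trainer.py is named
# "trl.trainer.sft_trainer"; lowering its level lets INFO records through
# regardless of the transformers verbosity setting.
logger = logging.getLogger("trl.trainer.sft_trainer")
logger.setLevel(logging.INFO)
logger.info("this INFO record is now emitted")
```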

@geronimi73
Author

Sorry to keep bothering you @younesbelkada, but I think the problem is that we are using logging from transformers.utils and obtaining a logger via __name__ (= trl.trainer.sft_trainer). Setting log_level in TrainingArguments therefore has no effect on the logger we obtained, because this line in trainer.py affects only the transformers root logger, not our trl.trainer.sft_trainer logger.

To do this correctly I think, we would have to either

  • add a trl.utils.logger analogous to transformers.utils.logger (too much effort?)
  • OR call logger.setLevel(log_level) in SFTTrainer.__init__(). Otherwise we will not see the output of logger.info().

what do you think? option 2?
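The hierarchy point above can be reproduced with the stdlib logging module alone (a sketch): logger levels only propagate down a dotted-name hierarchy, so nothing configured on "transformers" governs "trl.trainer.sft_trainer".

```python
import logging

transformers_logger = logging.getLogger("transformers")
trl_logger = logging.getLogger("trl.trainer.sft_trainer")
trl_logger.setLevel(logging.NOTSET)  # clean slate for the demo

# Setting INFO on the "transformers" logger...
transformers_logger.setLevel(logging.INFO)

# ...does not affect "trl.trainer.sft_trainer", which is not one of its
# descendants; it still inherits the Python root default (WARNING).
assert trl_logger.getEffectiveLevel() == logging.WARNING

# Option 2 from the comment above: set the level on the trl logger itself.
trl_logger.setLevel(logging.INFO)
assert trl_logger.getEffectiveLevel() == logging.INFO
```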


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@github-actions github-actions bot closed this Apr 18, 2024
3 participants