
DataCollatorForCompletionOnlyLM defaults can create a model that never generates eos_token_id #976

Closed
robertgshaw2-redhat opened this issue Nov 9, 2023 · 10 comments


robertgshaw2-redhat commented Nov 9, 2023

Copied from: #456 (comment)

First of all - thanks for the great library

I have been trying to fine-tune some models using this collator, but the resulting model does not generate the stop token </s>. I have been using Mistral, which of course does not have a pad_token_id, so I set it to eos_token_id.

Looking into the implementation of DataCollatorForCompletionOnlyLM, I see that it inherits from DataCollatorForLanguageModeling, which sets label=-100 for all the pad tokens. Since we set pad_token_id=eos_token_id, the eos_token gets a label of -100, meaning we never teach the model to end a sentence :). When I use the model trained with this flow, I see that it rarely (if ever) generates the eos_token. I wanted to flag this because I am sure many others will face the same issue.


Is it possible to set tokenizer.pad_token_id=tokenizer.bos_token_id as a workaround?

  • Conceptually, not backpropagating loss for bos_token_id seems okay (I don't think we ever end up predicting this token, since it should only ever be the first token in the sequence with attention_mask=1). However, I may not know what I don't know.
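
A quick way to check that workaround might look like this (just a sketch; it assumes the usual Mistral/Llama special tokens, <s> as bos and </s> as eos, and that DataCollatorForLanguageModeling masks pad_token_id positions):

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.bos_token  # pad with <s> instead of </s>

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
batch = collator([
    tokenizer("Hi" + tokenizer.eos_token),
    tokenizer("A somewhat longer example" + tokenizer.eos_token),
])
# Padding positions (and the leading <s>) get label -100,
# but every </s> keeps its label, so the model is still taught to stop.
print(batch["labels"])
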
@younesbelkada
Contributor

Thanks a lot for pointing this out @rsnm2!
What you said makes sense and is definitely a common scenario for users. I am not sure about the right fix either; calling tokenizer.pad_token_id = tokenizer.bos_token_id might cause issues for models that have been specifically pre-trained with that token. Maybe the right move is to upstream a fix on the transformers side to properly handle the use case where tokenizer.pad_token_id == tokenizer.eos_token_id.

@younesbelkada
Contributor

cc @tomaarsen can you elaborate more on how you subclassed DataCollatorForLanguageModeling and fixed your issue? 🙏

@tomaarsen
Member

Of course!

My issue at the time was this:

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
data_collator([tokenizer(formatted_dataset[0]["text"])])
{'input_ids': tensor([[    1,   523, 18529,  9670, 12628,   272,  2296,   808,   302, 11382,
         28725,  3133,  8373, 28747,  5936, 16280,  4969,  1059,  9697,   438,
          1830,   647,   464, 20746, 18566,  9917,  3578,  1996,   378,   533,
          5446, 28705, 28770,  2421,   647,   464,  1733,   824,  2516,  9746,
          7230,  5573, 10487,  3578,  1421,  2063,  4372,   272,  2996,   464,
          5985,   272,  2078,  5944,   297,  1745,  3725,   395,   264,   464,
          5613, 28742,   442,   464,  2501,  4135,   523, 10093,  9670,  1770,
             2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[    1,   523, 18529,  9670, 12628,   272,  2296,   808,   302, 11382,
         28725,  3133,  8373, 28747,  5936, 16280,  4969,  1059,  9697,   438,
          1830,   647,   464, 20746, 18566,  9917,  3578,  1996,   378,   533,
          5446, 28705, 28770,  2421,   647,   464,  1733,   824,  2516,  9746,
          7230,  5573, 10487,  3578,  1421,  2063,  4372,   272,  2996,   464,
          5985,   272,  2078,  5944,   297,  1745,  3725,   395,   264,   464,
          5613, 28742,   442,   464,  2501,  4135,   523, 10093,  9670,  1770,
          -100]])}

The labels end with -100 for the input_id of 2 (which was my EOS).

My resolution was to create the following class:

from transformers import DataCollatorForSeq2Seq, BatchEncoding

class DataCollatorForLanguageModelingWithEOS(DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None) -> BatchEncoding:
        for feature in features:
            if "labels" not in feature:
                # Copy input_ids verbatim so the EOS token keeps its label;
                # DataCollatorForSeq2Seq then pads the labels with -100 as needed.
                feature["labels"] = feature["input_ids"].copy()
        return super().__call__(features, return_tensors=return_tensors)

data_collator = DataCollatorForLanguageModelingWithEOS(tokenizer)

DataCollatorForSeq2Seq also allows me to pass labels directly, so I subclassed it and created the labels by copying input_ids.

This is in contrast to DataCollatorForLanguageModeling, which sets labels for pad_token_id to -100 here:
https://github.com/huggingface/transformers/blob/10f3e7b31bef9b4c7508e328f65e1f7ef186f945/src/transformers/data/data_collator.py#L745-L748

Since I set my pad token to my EOS token (the model does not have a pad token natively), my EOS token's label gets replaced by -100 with DataCollatorForLanguageModeling. See also the docstring for mlm, which explains this a bit more: https://github.com/huggingface/transformers/blob/10f3e7b31bef9b4c7508e328f65e1f7ef186f945/src/transformers/data/data_collator.py#L615-L618
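
As a quick sanity check of the subclass above (reusing tokenizer and formatted_dataset from the earlier snippet), I would expect something like:

batch = data_collator([tokenizer(formatted_dataset[0]["text"])])
print(batch["labels"][0, -1])  # tensor(2): the EOS label is kept instead of -100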

  • Tom Aarsen

@MustSave
Contributor

MustSave commented Nov 14, 2023

In my case, I used Vicuna's training approach, which sets the pad token to the unk token.

I'm not sure this can be applied to all models, but it might be a solution.
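
For example, a minimal version of that workaround (assuming a Llama/Mistral-style tokenizer where <unk> is a distinct token) would be:

# Vicuna-style workaround: pad with <unk> instead of </s>,
# so the EOS labels are no longer masked out by the collator.
tokenizer.pad_token = tokenizer.unk_token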

@younesbelkada
Contributor

Thanks a lot @tomaarsen for sharing your solution, and thanks @MustSave for yours as well.

@tomaarsen - if I am not mistaken, your solution would work well when one uses packing (i.e. no padding tokens in the input sequences); for the padded case, @MustSave's solution would be better, but I am not sure whether all tokenizers have an unk_token attribute.

@younesbelkada
Contributor

younesbelkada commented Nov 14, 2023

All tokenizers should have an unk_token attribute according to @ArthurZucker, so all good on that end. Note, however, that not all tokenizers have an unk_token_id that differs from eos_token_id (GPT-2, for example, maps both to 50256):

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.unk_token
>>> tokenizer(["hello", "hi I am here to"], return_tensors="pt", padding=True)
{'input_ids': tensor([[31373, 50256, 50256, 50256, 50256],
        [ 5303,   314,   716,   994,   284]]), 'attention_mask': tensor([[1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1]])}
>>> tokenizer.eos_token
'<|endoftext|>'
>>> tokenizer.eos_token_id
50256
>>> tokenizer.pad_token_id
50256
>>> tokenizer.unk_token_id
50256

@tomaarsen
Member

So it would still not be completely foolproof in that scenario, I imagine. Definitely tricky.

@robertgshaw2-redhat
Author

Thanks for the feedback. I think using a custom collator that sets label[i] = -100 only where attention_mask[i] = 0 is the approach to take.
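
Here is a rough sketch of that idea (the class name is mine, and it is built directly on DataCollatorForLanguageModeling; with DataCollatorForCompletionOnlyLM you would additionally need to preserve its prompt masking):

from transformers import DataCollatorForLanguageModeling

class MaskPaddingOnlyCollator(DataCollatorForLanguageModeling):
    """Masks labels via the attention mask instead of the pad token id,
    so a real EOS token keeps its label even when pad_token_id == eos_token_id."""

    def torch_call(self, examples):
        batch = super().torch_call(examples)  # pads the batch and builds the initial labels
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100  # ignore only true padding positions
        batch["labels"] = labels
        return batch

collator = MaskPaddingOnlyCollator(tokenizer, mlm=False)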

@luffycodes

Sorry for the question, but I am also facing this issue. Is there a temporary solution that works for Mistral (with or without packing)?


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
