From 1284d6553be32d56df7762f2346df6e847a2a7eb Mon Sep 17 00:00:00 2001 From: MustSave Date: Tue, 14 Nov 2023 14:56:44 +0900 Subject: [PATCH] [DataCollatorForCompletionOnlyLM] warn if eos_token_id and pad_token_id are identical Display a warning message if the and values are the same in order to prevent unintended behavior during multi-turn training. --- trl/trainer/utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index b7ba374cec..27cb2f2b0c 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -99,6 +99,14 @@ def __init__( # The user already provides the token ids self.response_token_ids = response_template + if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + warnings.warn( + "The pad_token_id and eos_token_id values of this tokenizer are identical. " + "If you are planning for multi-turn training, " + "it can result in the model continuously generating questions and answers without eos token. " + "To avoid this, set the pad_token_id to a different value." + ) + self.ignore_index = ignore_index def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: