From 74e20cbbbcbac7ac8d426df09eda5f310c637def Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 4 Nov 2024 21:08:41 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=AA=20Check=20with=20`token=5Fid`=20in?= =?UTF-8?q?stead=20of=20`token`=20in=20`DPOTrainer`=20(#2324)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/dpo_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index a8192e5905..b563cab2f5 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -572,9 +572,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l # Add special tokens (typically for encoder-decoder models) if add_special_tokens: - if tokenizer.bos_token is not None: + if tokenizer.bos_token_id is not None: prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids - if tokenizer.eos_token is not None: + if tokenizer.eos_token_id is not None: prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] @@ -607,9 +607,9 @@ def process_row(features, processing_class, max_prompt_length, max_completion_le # Add special tokens (typically for encoder-decoder models) if add_special_tokens: - if tokenizer.bos_token is not None: + if tokenizer.bos_token_id is not None: prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids - if tokenizer.eos_token is not None: + if tokenizer.eos_token_id is not None: prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]