Skip to content

Commit

Permalink
🪪 Check with token_id instead of token in DPOTrainer (#2324)
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec authored Nov 4, 2024
1 parent 27b9e3a commit 74e20cb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,9 @@ def tokenize_row(features, processing_class, max_prompt_length, max_completion_l

# Add special tokens (typically for encoder-decoder models)
if add_special_tokens:
if tokenizer.bos_token is not None:
if tokenizer.bos_token_id is not None:
prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
if tokenizer.eos_token is not None:
if tokenizer.eos_token_id is not None:
prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
Expand Down Expand Up @@ -607,9 +607,9 @@ def process_row(features, processing_class, max_prompt_length, max_completion_le

# Add special tokens (typically for encoder-decoder models)
if add_special_tokens:
if tokenizer.bos_token is not None:
if tokenizer.bos_token_id is not None:
prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
if tokenizer.eos_token is not None:
if tokenizer.eos_token_id is not None:
prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
Expand Down

0 comments on commit 74e20cb

Please sign in to comment.