diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index 967373713d..67733a48a6 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -105,7 +105,9 @@ def _tokenize( full_input_ids = [np.array(f) for f in full_input_ids] for full, concat in zip(full_input_ids, full_concat_input_ids): if len(full) != len(concat): - raise ValueError("Prompt input ids and answer input ids should have the same length.") + raise ValueError( + "Corresponding elements of 'full_input_ids' and 'full_concat_input_ids' must have the same length." + ) # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens # can be merged together when tokenizing prompt+answer. This could result diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 7be86d924e..17ad779a5a 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -110,7 +110,9 @@ def _tokenize( full_input_ids = [np.array(f) for f in full_input_ids] for full, concat in zip(full_input_ids, full_concat_input_ids): if len(full) != len(concat): - raise ValueError("Prompt input ids and answer input ids should have the same length.") + raise ValueError( + "Corresponding elements of 'full_input_ids' and 'full_concat_input_ids' must have the same length." + ) # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens # can be merged together when tokenizing prompt+answer. This could result