From a64a522fccf4c21fed692ff91a7987b0e41dcfba Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:27:49 +0100 Subject: [PATCH] Update dpo_trainer.py (#941) --- trl/trainer/dpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 50e72ce0e5..1a38c7f41f 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -556,7 +556,7 @@ def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[ """Generate samples from the model and reference model for the given batch of inputs.""" policy_output = model.generate( - batch["prompt_input_ids"], + input_ids=batch["prompt_input_ids"], attention_mask=batch["prompt_attention_mask"], max_length=self.max_length, do_sample=True,