From ac8fee3d90396ed3f88a0e8ffdcb2bff10e77809 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 8 Dec 2023 09:52:43 +0000 Subject: [PATCH] fix dpo trainer bug (#1073) --- trl/trainer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 3c29bedfb2..7fb17bd0a6 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -453,7 +453,7 @@ def collate(self, batch): if "prompt" in k: to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] else: - to_pad = [torch.LongTensor(ex[k]) for ex in batch] + to_pad = [torch.LongTensor(ex[k][:-1]) for ex in batch] if k.endswith("_input_ids"): padding_value = self.tokenizer.pad_token_id elif k.endswith("_labels"):