diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py
index 9cf258c947..803bda6699 100644
--- a/trl/trainer/orpo_trainer.py
+++ b/trl/trainer/orpo_trainer.py
@@ -778,12 +778,18 @@ def cross_entropy_loss(logits, labels):
             loss = loss_fct(logits, labels)
             return loss
 
-        labels = concatenated_batch["concatenated_labels"].clone()
+        if self.is_encoder_decoder:
+            labels = concatenated_batch["concatenated_labels"].clone()
+        else:
+            labels = concatenated_batch["concatenated_input_ids"].clone()
+            attention_mask = concatenated_batch["concatenated_attention_mask"]
+            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+        # orpo chosen nll loss is computed over the full prompt and response
         chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
         all_logps = self.get_batch_logps(
             all_logits,
-            labels,
+            concatenated_batch["concatenated_labels"],
             average_log_prob=True,
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
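For decoder-only models, this change rebuilds the NLL-loss labels from `concatenated_input_ids` (so the loss spans the full prompt and response, per the added comment) while masking out padding positions, and it keeps `get_batch_logps` on the response-only `concatenated_labels` for the log-odds term. Below is a minimal sketch of the masking step with toy tensors standing in for the real `concatenated_batch` entries; the tensor values and the default `label_pad_token_id` of -100 are illustrative assumptions, not taken from the PR.

```python
# Minimal sketch of the masking step above, with toy tensors in place of
# the real concatenated_batch entries (values here are illustrative only).
import torch

label_pad_token_id = -100  # assumed default; matches CrossEntropyLoss's ignore_index

# Two right-padded sequences: padding positions have attention_mask == 0.
input_ids = torch.tensor([[5, 6, 7, 8, 0], [9, 10, 0, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 0, 0, 0]])

# Decoder-only path: labels cover the full prompt + response, with padding
# replaced so the cross-entropy loss ignores those positions.
labels = torch.where(attention_mask == 1, input_ids.clone(), label_pad_token_id)
print(labels)
# tensor([[   5,    6,    7,    8, -100],
#         [   9,   10, -100, -100, -100]])
```

Masking via `torch.where` rather than slicing keeps the label tensor the same shape as the logits, so `cross_entropy_loss` can shift and flatten it without any extra bookkeeping for variable-length padding.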