Skip to content

Commit

Permalink
↩️ Revert ORPO loss changes (#2527)
Browse files Browse the repository at this point in the history
* revert orpo changes

* add comment
  • Loading branch information
kashif authored Jan 8, 2025
1 parent f2d42fa commit beb892b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,18 @@ def cross_entropy_loss(logits, labels):
loss = loss_fct(logits, labels)
return loss

labels = concatenated_batch["concatenated_labels"].clone()
if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
# orpo chosen nll loss is computed over the full prompt and response
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
labels,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
Expand Down

0 comments on commit beb892b

Please sign in to comment.