Commit

Fix FSDP error
Fixes an error that occurs when the `loss` field of the model output is non-empty, in which case indexing the output with `[0]` returns the loss instead of the logits. This can happen with FSDP.
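To illustrate the failure mode described above, here is a minimal sketch. `ToyModelOutput` is a hypothetical stand-in written for this example, not the class the trainer actually receives; it assumes the output behaves like a Hugging Face-style model output, where integer indexing skips fields that are `None` and string indexing looks a field up by name.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class ToyModelOutput:
    """Hypothetical stand-in for a model output (assumption, not library code)."""

    loss: Optional[Any] = None
    logits: Optional[Any] = None

    def to_tuple(self):
        # Only fields that are not None take part in positional indexing.
        values = (getattr(self, f.name) for f in fields(self))
        return tuple(v for v in values if v is not None)

    def __getitem__(self, key):
        if isinstance(key, str):
            # Name-based lookup is independent of which fields are set.
            return getattr(self, key)
        return self.to_tuple()[key]


# Usual case: no loss in the output, so position 0 is the logits.
without_loss = ToyModelOutput(logits=[0.7])
first_without_loss = without_loss[0]

# Case from the commit message: loss is populated, so position 0 is the loss.
with_loss = ToyModelOutput(loss=0.25, logits=[0.7])
first_with_loss = with_loss[0]

# Indexing by name returns the logits in both cases.
by_name = with_loss["logits"]
```

Under these assumptions, positional indexing silently changes meaning once `loss` is set, while `["logits"]` does not, which is the motivation for the change in this commit.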
mgerstgrasser committed Jan 9, 2024
1 parent 3267be0 commit 989f4a4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trl/trainer/reward_trainer.py
@@ -220,11 +220,11 @@ def compute_loss(
         rewards_chosen = model(
             input_ids=inputs["input_ids_chosen"],
             attention_mask=inputs["attention_mask_chosen"],
-        )[0]
+        )["logits"]
         rewards_rejected = model(
             input_ids=inputs["input_ids_rejected"],
             attention_mask=inputs["attention_mask_rejected"],
-        )[0]
+        )["logits"]
         # calculate loss, optionally modulate with margin
         if "margin" in inputs:
             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
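For readers unfamiliar with the margin branch shown in the hunk above, the following is a rough pure-Python sketch of the same formula, the mean of `-logsigmoid(chosen - rejected - margin)`. The function names `log_sigmoid` and `pairwise_margin_loss` and the sample numbers are made up for illustration; the trainer itself uses `nn.functional.logsigmoid` on tensors, and the branch without a margin is not shown in this diff.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)) for a single float."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def pairwise_margin_loss(chosen, rejected, margin):
    """Mean of -logsigmoid(chosen - rejected - margin) over the batch."""
    terms = [-log_sigmoid(c - r - m) for c, r, m in zip(chosen, rejected, margin)]
    return sum(terms) / len(terms)


# Illustrative values only: the same reward gap, with and without a margin.
loss_no_margin = pairwise_margin_loss([2.0], [0.0], [0.0])
loss_with_margin = pairwise_margin_loss([2.0], [0.0], [1.0])
```

A positive margin shrinks the effective gap between the chosen and rejected rewards, so the loss is larger for the same pair of scores. If `rewards_chosen` were accidentally the scalar loss rather than the logits, this difference would no longer compare two reward scores at all.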