Skip to content

Commit

Permalink
Update utils.py (#1012)
Browse files Browse the repository at this point in the history
* Update utils.py

update compute_accuracy to deal with the cases where str_chosen and str_rej got the same scores, which is probably what the developers don't want

* Update utils.py

updated so only warning is reserved

* Update trl/trainer/utils.py

Co-authored-by: Leandro von Werra <[email protected]>

---------

Co-authored-by: Leandro von Werra <[email protected]>
  • Loading branch information
ZihanWang314 and lvwerra authored Nov 29, 2023
1 parent 55d7c95 commit 4b67af3
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,10 @@ def compute_accuracy(eval_pred) -> Dict[str, float]:
predictions, labels = eval_pred
# Here, predictions is rewards_chosen and rewards_rejected.
# We want to see how much of the time rewards_chosen > rewards_rejected.
if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
warnings.warn(
f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
)
predictions = np.argmax(predictions, axis=1)

accuracy = np.array(predictions == labels, dtype=float).mean().item()
Expand Down

0 comments on commit 4b67af3

Please sign in to comment.