Skip to content

Commit

Permalink
[DPO] IPO Training loss (#1022)
Browse files Browse the repository at this point in the history
* initial IPO loss

* fix loss

* fixed comments

* added docs

* fix doc-strings

* add tests

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Leandro von Werra <[email protected]>

* fixes for review

* Added doc about beta in the Trainer's docstring

---------

Co-authored-by: Leandro von Werra <[email protected]>
  • Loading branch information
kashif and lvwerra authored Nov 24, 2023
1 parent 3719f7a commit 55d7c95
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ Note that the `beta` is the temperature parameter for the DPO loss, typically so
Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the DPO authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer.

## Logging

While training and evaluating we record the following reward metrics:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _init_dummy_dataset(self):
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)

@parameterized.expand([["gpt2", "sigmoid"], ["t5", "hinge"]])
@parameterized.expand([["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"]])
def test_dpo_trainer(self, name, loss_type):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
Expand Down
52 changes: 45 additions & 7 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class DPOTrainer(Trainer):
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
beta (`float`, defaults to 0.1):
The beta factor in DPO loss. Higher beta means less divergence from the initial policy.
The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.
loss_type (`str`, defaults to `"sigmoid"`):
The type of DPO loss to use. Either `"sigmoid"` the default DPO loss or `"hinge"` loss from SLiC paper.
The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from SLiC paper or `"ipo"` from IPO paper.
args (`transformers.TrainingArguments`):
The arguments to use for training.
data_collator (`transformers.DataCollator`):
Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(
model: Union[PreTrainedModel, nn.Module, str] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
loss_type: Literal["sigmoid", "hinge"] = "sigmoid",
loss_type: Literal["sigmoid", "hinge", "ipo"] = "sigmoid",
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
Expand Down Expand Up @@ -428,7 +428,6 @@ def dpo_loss(
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
Returns:
Expand All @@ -437,13 +436,15 @@ def dpo_loss(
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps

if reference_free:
ref_logratios = 0
else:
ref_logratios = reference_chosen_logps - reference_rejected_logps

logits = pi_logratios - ref_logratios

# The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0.
if self.loss_type == "sigmoid":
losses = -F.logsigmoid(self.beta * logits)
elif self.loss_type == "hinge":
Expand All @@ -456,6 +457,38 @@ def dpo_loss(

return losses, chosen_rewards, rejected_rewards

def ipo_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the IPO loss for a batch of policy and reference model log probabilities.
Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the IPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps + reference_rejected_logps
ref_logratios = policy_rejected_logps + reference_chosen_logps

logits = pi_logratios - ref_logratios
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2

chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

return losses, chosen_rewards, rejected_rewards

def _get_batch_logps(
self,
logits: torch.FloatTensor,
Expand Down Expand Up @@ -560,7 +593,12 @@ def get_batch_metrics(
_,
) = self.concatenated_forward(self.ref_model, batch)

losses, chosen_rewards, rejected_rewards = self.dpo_loss(
if self.loss_type == "ipo":
loss_fn = self.ipo_loss
else:
loss_fn = self.dpo_loss

losses, chosen_rewards, rejected_rewards = loss_fn(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
Expand Down

0 comments on commit 55d7c95

Please sign in to comment.