From 5628493825ef8465ce3d7c7f81f89ab7bf2717b0 Mon Sep 17 00:00:00 2001 From: peng wei Date: Mon, 18 Mar 2024 09:48:25 -0700 Subject: [PATCH 1/2] before update the loss from dpo, make sure it's in the same device of tr_loss --- trl/trainer/dpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index fb901a4510..a438e542fe 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1080,6 +1080,7 @@ def compute_loss( with compute_loss_context_manager(): loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + loss = loss.to(self.args.device) # force log the metrics self.store_metrics(metrics, train_eval="train") From 7f5df2a3b7c3972ddbef2382ef1b8998d92d1523 Mon Sep 17 00:00:00 2001 From: Peng Wei <33500509+pengwei715@users.noreply.github.com> Date: Mon, 18 Mar 2024 10:08:55 -0700 Subject: [PATCH 2/2] Update trl/trainer/dpo_trainer.py Co-authored-by: guy1992l <83535508+guy1992l@users.noreply.github.com> --- trl/trainer/dpo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index a438e542fe..dac980610d 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1080,6 +1080,7 @@ def compute_loss( with compute_loss_context_manager(): loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: loss = loss.to(self.args.device) # force log the metrics self.store_metrics(metrics, train_eval="train")