Skip to content

Commit

Permalink
Merge pull request #6465 from hiyouga/hiyouga/fix_eval_loss
Browse files (browse the repository at this point in the history)
[trainer] fix eval loss
  • Loading branch information
hiyouga authored Dec 27, 2024
2 parents f68074d + 2719867 commit b558902
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 27 deletions.
1 change: 1 addition & 0 deletions src/llamafactory/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,6 @@ def main():
else:
raise NotImplementedError(f"Unknown command: {command}.")


if __name__ == "__main__":
main()
6 changes: 5 additions & 1 deletion src/llamafactory/data/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
)
if self.tokenizer.padding_side == "right":
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
Expand All @@ -116,6 +119,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

batch_images = fake_images
batch_imglens[0] = 1
batch_input_ids[0] = features[0]["input_ids"]

mm_inputs = self.template.mm_plugin.get_mm_inputs(
Expand Down
10 changes: 5 additions & 5 deletions src/llamafactory/train/dpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
self.simpo_gamma = finetuning_args.simpo_gamma

Trainer.__init__(self, model=model, **kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")

Expand Down Expand Up @@ -274,15 +275,14 @@ def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
loss = loss / self.args.gradient_accumulation_steps

return loss

Expand Down
10 changes: 5 additions & 5 deletions src/llamafactory/train/kto/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
self.ftx_gamma = finetuning_args.pref_ftx

Trainer.__init__(self, model=model, **kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")

Expand Down Expand Up @@ -252,15 +253,14 @@ def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
loss = loss / self.args.gradient_accumulation_steps

return loss

Expand Down
12 changes: 5 additions & 7 deletions src/llamafactory/train/pt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from transformers import Trainer
from typing_extensions import override

from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

Expand Down Expand Up @@ -78,15 +78,13 @@ def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
# other model should not scale the loss
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
loss = loss / self.args.gradient_accumulation_steps

return loss
5 changes: 3 additions & 2 deletions src/llamafactory/train/rm/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
kwargs["processing_class"] = kwargs.pop("tokenizer")

super().__init__(**kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
self.finetuning_args = finetuning_args
self.can_return_loss = True # override property to return eval_loss
self.add_callback(FixValueHeadModelCallback)
Expand Down Expand Up @@ -107,8 +108,8 @@ def compute_loss(

loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()

if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0-4.46.1

if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)
Expand Down
12 changes: 5 additions & 7 deletions src/llamafactory/train/sft/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

Expand Down Expand Up @@ -93,16 +93,14 @@ def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
# other model should not scale the loss
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
loss = loss / self.args.gradient_accumulation_steps

return loss

Expand Down

0 comments on commit b558902

Please sign in to comment.