[xxxTrainer] Add tags to all trainers in TRL #1120

Merged · 3 commits · Dec 21, 2023
Changes from 2 commits
14 changes: 13 additions & 1 deletion trl/trainer/dpo_trainer.py
@@ -18,6 +18,7 @@
from collections import defaultdict
from contextlib import nullcontext
from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
@@ -40,7 +41,7 @@

from ..import_utils import is_peft_available, is_wandb_available
from ..models import PreTrainedModelWrapper, create_reference_model
from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length
from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, trl_sanitze_kwargs_for_tagging


if is_peft_available():
@@ -120,6 +121,7 @@ class DPOTrainer(Trainer):
ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the ref model from a string
"""
_tag_name = "trl-dpo"

def __init__(
self,
@@ -1131,3 +1133,13 @@ def log(self, logs: Dict[str, float]) -> None:
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "trl-dpo" when pushing the
model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
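
For users, the practical effect is that every Hub push from a `DPOTrainer` carries the "trl-dpo" tag on top of any tags they pass themselves. A hypothetical call (the trainer construction is elided) would look roughly like this:

# `trainer` is assumed to be an already-configured DPOTrainer instance.
# The override sanitizes `kwargs` before delegating to `transformers.Trainer.push_to_hub`.
trainer.push_to_hub(commit_message="End of training", tags=["my-experiment"])
# The resulting model card on the Hub carries both "my-experiment" and "trl-dpo".
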
15 changes: 15 additions & 0 deletions trl/trainer/ppo_trainer.py
@@ -18,6 +18,7 @@
import typing
import warnings
from contextlib import nullcontext
from functools import wraps
from typing import Callable, List, Optional, Union

import datasets
@@ -35,6 +36,7 @@
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
Trainer,
)

from ..core import (
@@ -55,6 +57,7 @@
from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
from .utils import trl_sanitze_kwargs_for_tagging


if is_deepspeed_available():
@@ -137,6 +140,8 @@ class PPOTrainer(BaseTrainer):
**lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
"""

_tag_name = "trl-ppo"

def __init__(
self,
config: PPOConfig = None,
@@ -1440,3 +1445,13 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model

@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "trl-ppo" when pushing the
model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
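
Note that `PPOTrainer` derives from TRL's `BaseTrainer` rather than `transformers.Trainer`, so `Trainer` is added to the imports above so that `functools.wraps` can copy the signature and docstring of `Trainer.push_to_hub` onto the override. A minimal, self-contained sketch of what `@wraps` contributes (names are illustrative, not from the diff):

from functools import wraps

def base_push_to_hub(commit_message="End of training", blocking=True):
    """Upload the model to the Hub."""

@wraps(base_push_to_hub)
def tagged_push_to_hub(commit_message="End of training", blocking=True, **kwargs):
    # Pretend to sanitize kwargs here, then delegate to the wrapped function.
    return base_push_to_hub(commit_message, blocking)

# The wrapper inherits the wrapped function's metadata:
print(tagged_push_to_hub.__doc__)   # Upload the model to the Hub.
print(tagged_push_to_hub.__name__)  # base_push_to_hub
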
12 changes: 12 additions & 0 deletions trl/trainer/sft_trainer.py
@@ -42,6 +42,7 @@
DataCollatorForCompletionOnlyLM,
neftune_post_forward_hook,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
)


@@ -114,6 +115,7 @@ class SFTTrainer(Trainer):
dataset_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when creating packed or non-packed datasets
"""
_tag_name = "trl-sft"

def __init__(
self,
@@ -326,6 +328,16 @@ def train(self, *args, **kwargs):

return output

@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "trl-sft" when pushing the
model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)

def _prepare_dataset(
self,
dataset,
11 changes: 11 additions & 0 deletions trl/trainer/utils.py
@@ -637,3 +637,14 @@ def peft_module_casting_to_bf16(model):
if hasattr(module, "weight"):
if module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)


def trl_sanitze_kwargs_for_tagging(tag_name, kwargs=None):
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = [tag_name]
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].append(tag_name)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"], tag_name]
return kwargs
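
A quick sketch (not part of the diff) of how the helper behaves for the cases it handles:

# No kwargs at all: there is nothing to sanitize and None is returned unchanged.
trl_sanitze_kwargs_for_tagging("trl-sft")                    # -> None
# No "tags" key: a fresh single-element list is created.
trl_sanitze_kwargs_for_tagging("trl-sft", {})                # -> {"tags": ["trl-sft"]}
# Existing list of tags: the trainer tag is appended.
trl_sanitze_kwargs_for_tagging("trl-sft", {"tags": ["a"]})   # -> {"tags": ["a", "trl-sft"]}
# Single string tag: it is promoted to a list containing both tags.
trl_sanitze_kwargs_for_tagging("trl-sft", {"tags": "a"})     # -> {"tags": ["a", "trl-sft"]}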