Skip to content

Commit

Permalink
[xxxTrainer] Add tags to all trainers in TRL (#1120)
Browse files Browse the repository at this point in the history
* add tags to sfttrainer

* extend it to other trainers

* add for ddpo
  • Loading branch information
younesbelkada authored Dec 21, 2023
1 parent 2aff709 commit b07935f
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 2 deletions.
16 changes: 15 additions & 1 deletion trl/trainer/ddpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
import os
from collections import defaultdict
from concurrent import futures
from functools import wraps
from typing import Any, Callable, Optional, Tuple
from warnings import warn

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from transformers import Trainer

from ..models import DDPOStableDiffusionPipeline
from . import BaseTrainer, DDPOConfig
from .utils import PerPromptStatTracker
from .utils import PerPromptStatTracker, trl_sanitze_kwargs_for_tagging


logger = get_logger(__name__)
Expand All @@ -46,6 +48,8 @@ class DDPOTrainer(BaseTrainer):
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
"""

_tag_name = "trl-ddpo"

def __init__(
self,
config: DDPOConfig,
Expand Down Expand Up @@ -574,3 +578,13 @@ def train(self, epochs: Optional[int] = None):

def _save_pretrained(self, save_directory):
self.sd_pipeline.save_pretrained(save_directory)

@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
14 changes: 13 additions & 1 deletion trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections import defaultdict
from contextlib import nullcontext
from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
Expand All @@ -40,7 +41,7 @@

from ..import_utils import is_peft_available, is_wandb_available
from ..models import PreTrainedModelWrapper, create_reference_model
from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length
from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, trl_sanitze_kwargs_for_tagging


if is_peft_available():
Expand Down Expand Up @@ -120,6 +121,7 @@ class DPOTrainer(Trainer):
ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when instantiating the ref model from a string
"""
_tag_name = "trl-dpo"

def __init__(
self,
Expand Down Expand Up @@ -1131,3 +1133,13 @@ def log(self, logs: Dict[str, float]) -> None:
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs)

@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
15 changes: 15 additions & 0 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import typing
import warnings
from contextlib import nullcontext
from functools import wraps
from typing import Callable, List, Optional, Union

import datasets
Expand All @@ -35,6 +36,7 @@
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
Trainer,
)

from ..core import (
Expand All @@ -55,6 +57,7 @@
from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
from .utils import trl_sanitze_kwargs_for_tagging


if is_deepspeed_available():
Expand Down Expand Up @@ -137,6 +140,8 @@ class PPOTrainer(BaseTrainer):
**lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
"""

_tag_name = "trl-ppo"

def __init__(
self,
config: PPOConfig = None,
Expand Down Expand Up @@ -1440,3 +1445,13 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model

@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
12 changes: 12 additions & 0 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
DataCollatorForCompletionOnlyLM,
neftune_post_forward_hook,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
)


Expand Down Expand Up @@ -114,6 +115,7 @@ class SFTTrainer(Trainer):
dataset_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when creating packed or non-packed datasets
"""
_tag_name = "trl-sft"

def __init__(
self,
Expand Down Expand Up @@ -326,6 +328,16 @@ def train(self, *args, **kwargs):

return output

@wraps(Trainer.push_to_hub)
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)

def _prepare_dataset(
self,
dataset,
Expand Down
11 changes: 11 additions & 0 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,3 +637,14 @@ def peft_module_casting_to_bf16(model):
if hasattr(module, "weight"):
if module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)


def trl_sanitze_kwargs_for_tagging(tag_name, kwargs=None):
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = [tag_name]
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].append(tag_name)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"], tag_name]
return kwargs

0 comments on commit b07935f

Please sign in to comment.