diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py
index 0f1cee12f1..b6cd432b18 100644
--- a/trl/trainer/ddpo_trainer.py
+++ b/trl/trainer/ddpo_trainer.py
@@ -15,6 +15,7 @@
 import os
 from collections import defaultdict
 from concurrent import futures
+from functools import wraps
 from typing import Any, Callable, Optional, Tuple
 from warnings import warn
@@ -22,10 +23,11 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
+from transformers import Trainer

 from ..models import DDPOStableDiffusionPipeline
 from . import BaseTrainer, DDPOConfig
-from .utils import PerPromptStatTracker
+from .utils import PerPromptStatTracker, trl_sanitze_kwargs_for_tagging


 logger = get_logger(__name__)
@@ -46,6 +48,8 @@ class DDPOTrainer(BaseTrainer):
         **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
     """

+    _tag_name = "trl-ddpo"
+
     def __init__(
         self,
         config: DDPOConfig,
@@ -574,3 +578,13 @@ def train(self, epochs: Optional[int] = None):

     def _save_pretrained(self, save_directory):
         self.sd_pipeline.save_pretrained(save_directory)
+
+    @wraps(Trainer.push_to_hub)
+    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
+        """
+        Overrides the `push_to_hub` method in order to force-add the tag "trl-ddpo" when pushing the
+        model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
+        """
+        kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
+
+        return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 8d8822751a..114c6ee581 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -18,6 +18,7 @@
 from collections import defaultdict
 from contextlib import nullcontext
 from copy import deepcopy
+from functools import wraps
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import numpy as np
@@ -40,7 +41,7 @@
 from ..import_utils import is_peft_available, is_wandb_available
 from ..models import PreTrainedModelWrapper, create_reference_model
-from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length
+from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, trl_sanitze_kwargs_for_tagging


 if is_peft_available():
@@ -120,6 +121,7 @@ class DPOTrainer(Trainer):
         ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
             Dict of Optional kwargs to pass when instantiating the ref model from a string
     """
+    _tag_name = "trl-dpo"

     def __init__(
         self,
@@ -1131,3 +1133,13 @@ def log(self, logs: Dict[str, float]) -> None:
             logs[key] = torch.tensor(metrics).mean().item()
         del self._stored_metrics[train_eval]
         return super().log(logs)
+
+    @wraps(Trainer.push_to_hub)
+    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
+        """
+        Overrides the `push_to_hub` method in order to force-add the tag "trl-dpo" when pushing the
+        model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
+ """ + kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 2df1e2fee5..bd466226f3 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -18,6 +18,7 @@ import typing import warnings from contextlib import nullcontext +from functools import wraps from typing import Callable, List, Optional, Union import datasets @@ -35,6 +36,7 @@ PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, + Trainer, ) from ..core import ( @@ -55,6 +57,7 @@ from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments +from .utils import trl_sanitze_kwargs_for_tagging if is_deepspeed_available(): @@ -137,6 +140,8 @@ class PPOTrainer(BaseTrainer): **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training. """ + _tag_name = "trl-ppo" + def __init__( self, config: PPOConfig = None, @@ -1440,3 +1445,13 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper): model, *_ = deepspeed.initialize(model=model, config=config_kwargs) model.eval() return model + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 11c08c1083..78243758cc 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -42,6 +42,7 @@ DataCollatorForCompletionOnlyLM, neftune_post_forward_hook, peft_module_casting_to_bf16, + trl_sanitze_kwargs_for_tagging, ) @@ -114,6 +115,7 @@ class SFTTrainer(Trainer): dataset_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when creating packed or non-packed datasets """ + _tag_name = "trl-sft" def __init__( self, @@ -326,6 +328,16 @@ def train(self, *args, **kwargs): return output + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. 
+ """ + kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + def _prepare_dataset( self, dataset, diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 13eaa91db5..33d6c6f092 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -637,3 +637,14 @@ def peft_module_casting_to_bf16(model): if hasattr(module, "weight"): if module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) + + +def trl_sanitze_kwargs_for_tagging(tag_name, kwargs=None): + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = [tag_name] + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].append(tag_name) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + kwargs["tags"] = [kwargs["tags"], tag_name] + return kwargs