From b271f2a99dd37923e7512e350ccf479000accb53 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Dec 2023 16:48:01 +0000 Subject: [PATCH 1/3] add tags to sfttrainer --- trl/trainer/sft_trainer.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 11c08c1083..855938072d 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -326,6 +326,21 @@ def train(self, *args, **kwargs): return output + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + if "tags" not in kwargs: + kwargs["tags"] = ["sft"] + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].append("sft") + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + kwargs["tags"] = [kwargs["tags"], "sft"] + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + def _prepare_dataset( self, dataset, From 65fc0e838ccc886195e85217a6c354752070c2a5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Dec 2023 16:59:26 +0000 Subject: [PATCH 2/3] extend it to other trainers --- trl/trainer/dpo_trainer.py | 14 +++++++++++++- trl/trainer/ppo_trainer.py | 15 +++++++++++++++ trl/trainer/sft_trainer.py | 9 +++------ trl/trainer/utils.py | 11 +++++++++++ 4 files changed, 42 insertions(+), 7 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 8d8822751a..114c6ee581 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -18,6 +18,7 @@ from collections import defaultdict from contextlib import nullcontext from copy import deepcopy +from functools import wraps from typing import Any, Callable, Dict, List, Literal, 
Optional, Tuple, Union import numpy as np @@ -40,7 +41,7 @@ from ..import_utils import is_peft_available, is_wandb_available from ..models import PreTrainedModelWrapper, create_reference_model -from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length +from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length, trl_sanitze_kwargs_for_tagging if is_peft_available(): @@ -120,6 +121,7 @@ class DPOTrainer(Trainer): ref_model_init_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when instantiating the ref model from a string """ + _tag_name = "trl-dpo" def __init__( self, @@ -1131,3 +1133,13 @@ def log(self, logs: Dict[str, float]) -> None: logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] return super().log(logs) + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "trl-dpo" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. 
+ """ + kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 2df1e2fee5..bd466226f3 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -18,6 +18,7 @@ import typing import warnings from contextlib import nullcontext +from functools import wraps from typing import Callable, List, Optional, Union import datasets @@ -35,6 +36,7 @@ PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, + Trainer, ) from ..core import ( @@ -55,6 +57,7 @@ from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments +from .utils import trl_sanitze_kwargs_for_tagging if is_deepspeed_available(): @@ -137,6 +140,8 @@ class PPOTrainer(BaseTrainer): **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training. """ + _tag_name = "trl-ppo" + def __init__( self, config: PPOConfig = None, @@ -1440,3 +1445,13 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper): model, *_ = deepspeed.initialize(model=model, config=config_kwargs) model.eval() return model + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "trl-ppo" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. 
+ """ + kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 855938072d..78243758cc 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -42,6 +42,7 @@ DataCollatorForCompletionOnlyLM, neftune_post_forward_hook, peft_module_casting_to_bf16, + trl_sanitze_kwargs_for_tagging, ) @@ -114,6 +115,7 @@ class SFTTrainer(Trainer): dataset_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when creating packed or non-packed datasets """ + _tag_name = "trl-sft" def __init__( self, @@ -332,12 +334,7 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ - if "tags" not in kwargs: - kwargs["tags"] = ["sft"] - elif "tags" in kwargs and isinstance(kwargs["tags"], list): - kwargs["tags"].append("sft") - elif "tags" in kwargs and isinstance(kwargs["tags"], str): - kwargs["tags"] = [kwargs["tags"], "sft"] + kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 13eaa91db5..33d6c6f092 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -637,3 +637,14 @@ def peft_module_casting_to_bf16(model): if hasattr(module, "weight"): if module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) + + +def trl_sanitze_kwargs_for_tagging(tag_name, kwargs=None): + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = [tag_name] + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].append(tag_name) + elif "tags" in kwargs and 
isinstance(kwargs["tags"], str): + kwargs["tags"] = [kwargs["tags"], tag_name] + return kwargs From 174685529f8855e703f94dc723865a38a0e7176b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 21 Dec 2023 15:55:56 +0000 Subject: [PATCH 3/3] add for ddpo --- trl/trainer/ddpo_trainer.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py index 0f1cee12f1..b6cd432b18 100644 --- a/trl/trainer/ddpo_trainer.py +++ b/trl/trainer/ddpo_trainer.py @@ -15,6 +15,7 @@ import os from collections import defaultdict from concurrent import futures +from functools import wraps from typing import Any, Callable, Optional, Tuple from warnings import warn @@ -22,10 +23,11 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed +from transformers import Trainer from ..models import DDPOStableDiffusionPipeline from . import BaseTrainer, DDPOConfig -from .utils import PerPromptStatTracker +from .utils import PerPromptStatTracker, trl_sanitze_kwargs_for_tagging logger = get_logger(__name__) @@ -46,6 +48,8 @@ class DDPOTrainer(BaseTrainer): **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images """ + _tag_name = "trl-ddpo" + def __init__( self, config: DDPOConfig, @@ -574,3 +578,13 @@ def train(self, epochs: Optional[int] = None): def _save_pretrained(self, save_directory): self.sd_pipeline.save_pretrained(save_directory) + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "trl-ddpo" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. 
+ """ + kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)