From 5239b9462dccd8751936f6f4e181a6db0c2bef7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Sun, 22 Dec 2024 12:19:17 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=A7=20Generalize=20`disable=5Fdropout`?= =?UTF-8?q?=20(#2511)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trl/trainer/bco_config.py | 3 +++ trl/trainer/bco_trainer.py | 11 +++++------ trl/trainer/cpo_trainer.py | 1 + trl/trainer/dpo_trainer.py | 1 + trl/trainer/gkd_config.py | 2 +- trl/trainer/gkd_trainer.py | 1 + trl/trainer/kto_config.py | 2 +- trl/trainer/kto_trainer.py | 3 +-- trl/trainer/online_dpo_config.py | 2 +- trl/trainer/online_dpo_trainer.py | 4 +++- trl/trainer/orpo_trainer.py | 1 + trl/trainer/prm_config.py | 3 +++ trl/trainer/prm_trainer.py | 6 +++++- trl/trainer/reward_config.py | 3 +++ trl/trainer/reward_trainer.py | 5 +++++ 15 files changed, 35 insertions(+), 13 deletions(-) diff --git a/trl/trainer/bco_config.py b/trl/trainer/bco_config.py index 10cd82b9f5..b3398ae914 100644 --- a/trl/trainer/bco_config.py +++ b/trl/trainer/bco_config.py @@ -46,6 +46,8 @@ class BCOConfig(TrainingArguments): truncation_mode (`str`, *optional*, defaults to `"keep_end"`): Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. This argument is required if you want to use the default data collator. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. generate_during_eval (`bool`, *optional*, defaults to `False`): If `True`, generates and logs completions from both the model and the reference model to W&B during evaluation. 
@@ -78,6 +80,7 @@ class BCOConfig(TrainingArguments): label_pad_token_id: int = -100 padding_value: Optional[int] = None truncation_mode: str = "keep_end" + disable_dropout: bool = True generate_during_eval: bool = False is_encoder_decoder: Optional[bool] = None precompute_ref_log_probs: bool = False diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index c2d58ab3f2..1c26516793 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -309,8 +309,6 @@ class BCOTrainer(Trainer): The function to use to preprocess the logits before computing the metrics. peft_config (`dict`, defaults to `None`): The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. - disable_dropout (`bool`, defaults to `True`): - Whether or not to disable dropouts in `model` and `ref_model`. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. 
@@ -538,10 +536,11 @@ def make_inputs_require_grad(module, input, output): else: self.use_dpo_data_collator = False - # disable dropout in the model and reference model - disable_dropout_in_model(model) - if self.ref_model is not None: - disable_dropout_in_model(self.ref_model) + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) self.max_length = max_length self.generate_during_eval = args.generate_during_eval diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 6d236cfb37..4720517b6a 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -268,6 +268,7 @@ def make_inputs_require_grad(module, input, output): else: self.use_dpo_data_collator = False + # Disable dropout in the model if args.disable_dropout: disable_dropout_in_model(model) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 04b0583237..d820857de1 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -376,6 +376,7 @@ def make_inputs_require_grad(module, input, output): if data_collator is None: data_collator = PreferenceCollator(pad_token_id=self.padding_value) + # Disable dropout in the model and reference model if args.disable_dropout: disable_dropout_in_model(model) if self.ref_model is not None: diff --git a/trl/trainer/gkd_config.py b/trl/trainer/gkd_config.py index e9b9d76363..e110b047d1 100644 --- a/trl/trainer/gkd_config.py +++ b/trl/trainer/gkd_config.py @@ -41,7 +41,7 @@ class GKDConfig(SFTConfig): Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model from a string. disable_dropout (`bool`, *optional*, defaults to `True`): - Whether or not to disable dropouts in `model`. + Whether to disable dropout in the model. 
seq_kd (`bool`, *optional*, defaults to `False`): Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index be48f1925b..f212b5a296 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -126,6 +126,7 @@ def __init__( else: teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs) + # Disable dropout in the model if args.disable_dropout: disable_dropout_in_model(self.model) diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py index 563d0cdbc9..e5feb2dbad 100644 --- a/trl/trainer/kto_config.py +++ b/trl/trainer/kto_config.py @@ -77,7 +77,7 @@ class KTOConfig(TrainingArguments): dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. + Whether to disable dropout in the model and reference model. """ learning_rate: float = 1e-6 diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index d054d97e7d..b19955c145 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -304,8 +304,6 @@ class KTOTrainer(Trainer): The function to use to preprocess the logits before computing the metrics. peft_config (`dict`, defaults to `None`): The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. - disable_dropout (`bool`, defaults to `True`): - Whether or not to disable dropouts in `model` and `ref_model`. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. 
@@ -526,6 +524,7 @@ def make_inputs_require_grad(module, input, output): else: self.use_dpo_data_collator = False + # Disable dropout in the model and reference model if args.disable_dropout: disable_dropout_in_model(model) if self.ref_model is not None: diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index 0b06c79cb5..5e75ede883 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -57,7 +57,7 @@ class OnlineDPOConfig(TrainingArguments): dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. disable_dropout (`bool`, *optional*, defaults to `True`): - Whether to disable dropout in the model. + Whether to disable dropout in the model and reference model. """ learning_rate: float = 5e-7 diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 68008881f5..ebab5cdcfc 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -196,9 +196,11 @@ def __init__( # Get peft model with the given config model = get_peft_model(model, peft_config) - # Disable dropout in the model if specified + # Disable dropout in the model and reference model if args.disable_dropout: disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) # Handle the ref_model # Usually, the user wants the ref model to be the initial version of the model. 
When using PEFT, it's easy to diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 50392526db..65d80802be 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -282,6 +282,7 @@ def make_inputs_require_grad(module, input, output): else: self.use_dpo_data_collator = False + # Disable dropout in the model if args.disable_dropout: disable_dropout_in_model(model) diff --git a/trl/trainer/prm_config.py b/trl/trainer/prm_config.py index 4558084572..21a4fc5662 100644 --- a/trl/trainer/prm_config.py +++ b/trl/trainer/prm_config.py @@ -35,6 +35,8 @@ class PRMConfig(TrainingArguments): Maximum length of the sequences (prompt + completion) used for truncation. max_completion_length (`Optional[int]`, *optional*, defaults to `None`): Maximum length of the completion used for truncation. The completion is the concatenation of the steps. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. step_separator (`str`, *optional*, defaults to `"\n"`): Separator used to separate each step of the reasoning process. 
train_on_last_step_only (`bool`, *optional*, defaults to `False`): @@ -46,6 +48,7 @@ class PRMConfig(TrainingArguments): learning_rate: float = 1e-5 max_length: Optional[int] = None max_completion_length: Optional[int] = None + disable_dropout: bool = True step_separator: str = "\n" train_on_last_step_only: bool = False dataset_num_proc: Optional[int] = None diff --git a/trl/trainer/prm_trainer.py b/trl/trainer/prm_trainer.py index dbb3558d57..47d73ce19c 100644 --- a/trl/trainer/prm_trainer.py +++ b/trl/trainer/prm_trainer.py @@ -39,7 +39,7 @@ from transformers.utils import is_peft_available from .prm_config import PRMConfig -from .utils import compute_accuracy, generate_model_card +from .utils import compute_accuracy, disable_dropout_in_model, generate_model_card if is_peft_available(): @@ -130,6 +130,10 @@ def __init__( model = get_peft_model(model, peft_config) + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + if compute_metrics is None: compute_metrics = compute_accuracy diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py index 6e3eeab372..8018a2844c 100644 --- a/trl/trainer/reward_config.py +++ b/trl/trainer/reward_config.py @@ -31,6 +31,8 @@ class RewardConfig(TrainingArguments): max_length (`Optional[int]`, *optional*, defaults to `None`): Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want to use the default data collator. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. dataset_num_proc (`int`, *optional*, defaults to `None`): Number of processes to use for processing the dataset. 
center_rewards_coefficient (`float`, *optional*, defaults to `None`): @@ -42,6 +44,7 @@ class RewardConfig(TrainingArguments): """ max_length: Optional[int] = None + disable_dropout: bool = True dataset_num_proc: Optional[int] = None center_rewards_coefficient: Optional[float] = None remove_unused_columns: bool = False diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 109d8a47cf..79b237b9e7 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -47,6 +47,7 @@ RewardDataCollatorWithPadding, compute_accuracy, decode_and_strip_padding, + disable_dropout_in_model, generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment, @@ -169,6 +170,10 @@ def __init__( model = get_peft_model(model, peft_config) + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + if compute_metrics is None: compute_metrics = compute_accuracy