💧 Generalize disable_dropout (#2511)
qgallouedec authored Dec 22, 2024
1 parent 8fb267f commit 5239b94
Showing 15 changed files with 35 additions and 13 deletions.
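All of the call sites below route through the same `disable_dropout_in_model` helper in `trl/trainer/utils.py`, which walks the module tree and zeroes every dropout probability. A minimal sketch of what that helper does (paraphrased, not the verbatim TRL source):

    import torch.nn as nn

    def disable_dropout_in_model(model: nn.Module) -> None:
        # Set every dropout probability to zero so forward passes are
        # deterministic even while the model stays in training mode.
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.p = 0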
3 changes: 3 additions & 0 deletions trl/trainer/bco_config.py
@@ -46,6 +46,8 @@ class BCOConfig(TrainingArguments):
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
This argument is required if you want to use the default data collator.
+ disable_dropout (`bool`, *optional*, defaults to `True`):
+     Whether to disable dropout in the model and reference model.
generate_during_eval (`bool`, *optional*, defaults to `False`):
If `True`, generates and logs completions from both the model and the reference model to W&B during
evaluation.
@@ -78,6 +80,7 @@ class BCOConfig(TrainingArguments):
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
+ disable_dropout: bool = True
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
precompute_ref_log_probs: bool = False
11 changes: 5 additions & 6 deletions trl/trainer/bco_trainer.py
@@ -309,8 +309,6 @@ class BCOTrainer(Trainer):
The function to use to preprocess the logits before computing the metrics.
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
- disable_dropout (`bool`, defaults to `True`):
-     Whether or not to disable dropouts in `model` and `ref_model`.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
@@ -538,10 +536,11 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

- # disable dropout in the model and reference model
- disable_dropout_in_model(model)
- if self.ref_model is not None:
-     disable_dropout_in_model(self.ref_model)
+ # Disable dropout in the model and reference model
+ if args.disable_dropout:
+     disable_dropout_in_model(model)
+     if self.ref_model is not None:
+         disable_dropout_in_model(self.ref_model)

self.max_length = max_length
self.generate_during_eval = args.generate_during_eval
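The motivation for gating this in preference-style trainers like BCO: the loss compares log-probabilities from the policy and reference model, and active dropout makes those log-probabilities stochastic, injecting noise into the implicit reward. A self-contained illustration (not part of the commit; the toy layer and shapes are arbitrary):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    layer = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
    x = torch.randn(1, 8)

    layer.train()
    # Dropout samples a fresh mask on every call, so two identical
    # forward passes disagree while the model is in train mode.
    print(torch.allclose(layer(x), layer(x)))  # almost always False

    # Zeroing the dropout probability (what disable_dropout_in_model
    # effectively does) restores determinism without calling .eval().
    for m in layer.modules():
        if isinstance(m, nn.Dropout):
            m.p = 0
    print(torch.allclose(layer(x), layer(x)))  # True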
1 change: 1 addition & 0 deletions trl/trainer/cpo_trainer.py
@@ -268,6 +268,7 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

+ # Disable dropout in the model
  if args.disable_dropout:
      disable_dropout_in_model(model)

1 change: 1 addition & 0 deletions trl/trainer/dpo_trainer.py
@@ -376,6 +376,7 @@ def make_inputs_require_grad(module, input, output):
if data_collator is None:
data_collator = PreferenceCollator(pad_token_id=self.padding_value)

+ # Disable dropout in the model and reference model
  if args.disable_dropout:
      disable_dropout_in_model(model)
      if self.ref_model is not None:
2 changes: 1 addition & 1 deletion trl/trainer/gkd_config.py
@@ -41,7 +41,7 @@ class GKDConfig(SFTConfig):
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
from a string.
disable_dropout (`bool`, *optional*, defaults to `True`):
-     Whether or not to disable dropouts in `model`.
+     Whether to disable dropout in the model.
seq_kd (`bool`, *optional*, defaults to `False`):
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
on teacher-generated output).
1 change: 1 addition & 0 deletions trl/trainer/gkd_trainer.py
@@ -126,6 +126,7 @@ def __init__(
else:
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

+ # Disable dropout in the model
  if args.disable_dropout:
      disable_dropout_in_model(self.model)

2 changes: 1 addition & 1 deletion trl/trainer/kto_config.py
@@ -77,7 +77,7 @@ class KTOConfig(TrainingArguments):
dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
disable_dropout (`bool`, *optional*, defaults to `True`):
-     Whether to disable dropout in the model.
+     Whether to disable dropout in the model and reference model.
"""

learning_rate: float = 1e-6
3 changes: 1 addition & 2 deletions trl/trainer/kto_trainer.py
@@ -304,8 +304,6 @@ class KTOTrainer(Trainer):
The function to use to preprocess the logits before computing the metrics.
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
- disable_dropout (`bool`, defaults to `True`):
-     Whether or not to disable dropouts in `model` and `ref_model`.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
@@ -526,6 +524,7 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

+ # Disable dropout in the model and reference model
  if args.disable_dropout:
      disable_dropout_in_model(model)
      if self.ref_model is not None:
2 changes: 1 addition & 1 deletion trl/trainer/online_dpo_config.py
@@ -57,7 +57,7 @@ class OnlineDPOConfig(TrainingArguments):
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
disable_dropout (`bool`, *optional*, defaults to `True`):
-     Whether to disable dropout in the model.
+     Whether to disable dropout in the model and reference model.
"""

learning_rate: float = 5e-7
4 changes: 3 additions & 1 deletion trl/trainer/online_dpo_trainer.py
@@ -196,9 +196,11 @@ def __init__(
# Get peft model with the given config
model = get_peft_model(model, peft_config)

- # Disable dropout in the model if specified
+ # Disable dropout in the model and reference model
  if args.disable_dropout:
      disable_dropout_in_model(model)
+     if self.ref_model is not None:
+         disable_dropout_in_model(self.ref_model)

# Handle the ref_model
# Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
1 change: 1 addition & 0 deletions trl/trainer/orpo_trainer.py
@@ -282,6 +282,7 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

+ # Disable dropout in the model and reference model
  if args.disable_dropout:
      disable_dropout_in_model(model)

3 changes: 3 additions & 0 deletions trl/trainer/prm_config.py
@@ -35,6 +35,8 @@ class PRMConfig(TrainingArguments):
Maximum length of the sequences (prompt + completion) used for truncation.
max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
+ disable_dropout (`bool`, *optional*, defaults to `True`):
+     Whether to disable dropout in the model.
step_separator (`str`, *optional*, defaults to `"\n"`):
Separator used to separate each step of the reasoning process.
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
@@ -46,6 +48,7 @@ class PRMConfig(TrainingArguments):
learning_rate: float = 1e-5
max_length: Optional[int] = None
max_completion_length: Optional[int] = None
+ disable_dropout: bool = True
step_separator: str = "\n"
train_on_last_step_only: bool = False
dataset_num_proc: Optional[int] = None
6 changes: 5 additions & 1 deletion trl/trainer/prm_trainer.py
@@ -39,7 +39,7 @@
from transformers.utils import is_peft_available

from .prm_config import PRMConfig
- from .utils import compute_accuracy, generate_model_card
+ from .utils import compute_accuracy, disable_dropout_in_model, generate_model_card


if is_peft_available():
@@ -130,6 +130,10 @@ def __init__(

model = get_peft_model(model, peft_config)

+ # Disable dropout in the model
+ if args.disable_dropout:
+     disable_dropout_in_model(model)

if compute_metrics is None:
compute_metrics = compute_accuracy

3 changes: 3 additions & 0 deletions trl/trainer/reward_config.py
@@ -31,6 +31,8 @@ class RewardConfig(TrainingArguments):
max_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
+ disable_dropout (`bool`, *optional*, defaults to `True`):
+     Whether to disable dropout in the model.
dataset_num_proc (`int`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
center_rewards_coefficient (`float`, *optional*, defaults to `None`):
@@ -42,6 +44,7 @@
"""

max_length: Optional[int] = None
+ disable_dropout: bool = True
dataset_num_proc: Optional[int] = None
center_rewards_coefficient: Optional[float] = None
remove_unused_columns: bool = False
5 changes: 5 additions & 0 deletions trl/trainer/reward_trainer.py
@@ -47,6 +47,7 @@
RewardDataCollatorWithPadding,
compute_accuracy,
decode_and_strip_padding,
+ disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
@@ -169,6 +170,10 @@ def __init__(

model = get_peft_model(model, peft_config)

+ # Disable dropout in the model
+ if args.disable_dropout:
+     disable_dropout_in_model(model)

if compute_metrics is None:
compute_metrics = compute_accuracy

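After this commit the flag is exposed uniformly across the trainer configs, so opting back into dropout is a one-line config change. A hypothetical usage sketch for the newly covered reward-model case (the output directory name is a placeholder):

    from trl import RewardConfig

    # disable_dropout defaults to True; pass False to keep dropout
    # active during reward-model training.
    config = RewardConfig(output_dir="reward-model", disable_dropout=False)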
