Adds model kwargs to SFT and DPO trainers (#951)
* adds model kwargs to SFT and DPO trainers

* adds checks for model_kwarg passing when model is not str

* changed warning to ValueError

* renames model_kwargs to model_init_kwargs

* corrects argument names in
edbeeching authored Nov 6, 2023
1 parent 6c6ff24 commit c273b18
Showing 2 changed files with 52 additions and 12 deletions.
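For context, a minimal usage sketch of the behavior this commit enables, assuming an illustrative model id and the IMDB dataset (neither is part of the commit): the model is passed as a string and model_init_kwargs is forwarded to `AutoModelForCausalLM.from_pretrained`.

import torch
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Illustrative dataset and model id; substitute your own.
train_dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    model="facebook/opt-350m",  # a string model id, not an instantiated model
    model_init_kwargs={"torch_dtype": torch.bfloat16},  # forwarded to from_pretrained
    args=TrainingArguments(output_dir="sft-out", per_device_train_batch_size=2),
    train_dataset=train_dataset,
    dataset_text_field="text",
)
trainer.train()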
46 changes: 43 additions & 3 deletions trl/trainer/dpo_trainer.py
@@ -24,7 +24,14 @@
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
from transformers import (
AutoModelForCausalLM,
DataCollator,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainingArguments,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput

@@ -100,12 +107,17 @@ class DPOTrainer(Trainer):
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function to use to compute the metrics. Must take an `EvalPrediction` and return
a dictionary mapping strings to metric values.
model_init_kwargs (`Optional[Dict]`, *optional*):
Dict of optional kwargs to pass when instantiating the model from a string.
ref_model_init_kwargs (`Optional[Dict]`, *optional*):
Dict of optional kwargs to pass when instantiating the ref model from a string.
"""

def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
model: Union[PreTrainedModel, nn.Module, str] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
loss_type: Literal["sigmoid", "hinge"] = "sigmoid",
args: TrainingArguments = None,
@@ -131,7 +143,35 @@ def __init__(
disable_dropout: bool = True,
generate_during_eval: bool = False,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
model_init_kwargs: Optional[Dict] = None,
ref_model_init_kwargs: Optional[Dict] = None,
):
if model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.")

if ref_model_init_kwargs is None:
ref_model_init_kwargs = {}
elif not isinstance(ref_model, str):
raise ValueError(
"You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated."
)

if isinstance(model, str):
warnings.warn(
"You passed a model_id to the DPOTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

if isinstance(ref_model, str):
warnings.warn(
"You passed a ref model_id to the DPOTrainer. This will automatically create an "
"`AutoModelForCausalLM`"
)
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
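A hedged sketch of the DPO side of this change, assuming a tiny in-memory preference dataset with prompt/chosen/rejected columns and an illustrative model id (all assumptions, not part of the commit); both model and ref_model are given as strings so the new *_init_kwargs are applied when loading them.

import torch
from datasets import Dataset
from transformers import AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_id = "facebook/opt-350m"  # illustrative model id
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Toy preference data in the prompt/chosen/rejected format DPOTrainer expects.
train_dataset = Dataset.from_dict({
    "prompt": ["How are you?"],
    "chosen": [" I'm doing well, thank you!"],
    "rejected": [" None of your business."],
})

trainer = DPOTrainer(
    model=model_id,        # string -> loaded via AutoModelForCausalLM.from_pretrained
    ref_model=model_id,    # string -> a separate reference model is loaded
    model_init_kwargs={"torch_dtype": torch.bfloat16},
    ref_model_init_kwargs={"torch_dtype": torch.bfloat16},
    beta=0.1,
    args=TrainingArguments(output_dir="dpo-out", per_device_train_batch_size=1),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    max_length=128,
    max_prompt_length=64,
)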
18 changes: 9 additions & 9 deletions trl/trainer/sft_trainer.py
@@ -106,6 +106,8 @@ class SFTTrainer(Trainer):
neftune_noise_alpha (`Optional[float]`):
If not `None`, this will activate NEFTune noise embeddings. This has been shown to drastically improve model performance for instruction
fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune
model_init_kwargs (`Optional[Dict]`, *optional*):
Dict of optional kwargs to pass when instantiating the model from a string.
"""

def __init__(
@@ -132,20 +134,25 @@ def __init__(
dataset_num_proc: Optional[int] = None,
dataset_batch_size: int = 1000,
neftune_noise_alpha: Optional[float] = None,
model_init_kwargs: Optional[Dict] = None,
):
if model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.")

if isinstance(model, str):
warnings.warn(
"You passed a model_id to the SFTTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM):
raise ValueError(
"You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
)

supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

if is_peft_available() and peft_config is not None:
if not isinstance(peft_config, PeftConfig):
raise ValueError(
@@ -154,11 +161,6 @@ def __init__(
)

if not isinstance(model, PeftModel):
if not isinstance(model, PreTrainedModel):
model = AutoModelForCausalLM.from_pretrained(
model,
)

if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
_support_gc_kwargs = hasattr(
args, "gradient_checkpointing_kwargs"
@@ -179,8 +181,6 @@

if callbacks is None:
callbacks = [PeftSavingCallback]
elif not isinstance(model, supported_classes):
model = AutoModelForCausalLM.from_pretrained(model)

if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
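Conversely, the new guard means init kwargs are only accepted together with a string model id. A sketch of the expected failure mode (the model id is illustrative); the check runs at the top of `__init__`, so the error is raised immediately.

from transformers import AutoModelForCausalLM
from trl import SFTTrainer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # already instantiated

try:
    # model_init_kwargs only applies when `model` is a string, so this raises.
    SFTTrainer(model=model, model_init_kwargs={"trust_remote_code": True})
except ValueError as err:
    print(err)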
