From a9b54a852ee12ff508773edb02e1c243817e71ae Mon Sep 17 00:00:00 2001 From: Dawid Motyka Date: Wed, 22 Jan 2025 12:24:42 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=AB=B7=20Include=20stop=20token=20in=20po?= =?UTF-8?q?licy=20model's=20generation=5Fconfig=20(#2528)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Include stop token in policy model's generation_config * Fix formatting * Update trl/trainer/ppo_trainer.py * Update trl/trainer/ppo_trainer.py * don't modify args * clarify doc * more nice doc * missing no [ci skip] * really don't modify args * oups --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- trl/trainer/ppo_trainer.py | 26 ++++++++++++++++---------- trl/trainer/utils.py | 20 ++++++++++++++++---- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index ef29461a70..83926cfd6a 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -138,10 +138,18 @@ def __init__( if data_collator is None: data_collator = DataCollatorWithPadding(self.processing_class) - self.policy_model.generation_config.eos_token_id = ( - None # disable `pad_token_id` and `eos_token_id` because we just want to - ) - self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding + # Handle stop token settings: update policy model's generation_config to use provided stop token + if args.stop_token and args.stop_token_id: + raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") + elif args.stop_token: + if args.stop_token == "eos": + self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id + else: + raise ValueError( + f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." + ) + else: + self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int # peft support if not is_peft_available() and peft_config is not None: @@ -220,8 +228,6 @@ def __init__( for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: if module is not None: disable_dropout_in_model(module) - if args.stop_token and args.stop_token == "eos": - args.stop_token_id = processing_class.eos_token_id self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) self.model.config = self.policy_model.config # needed for pushing to hub self.create_optimizer_and_scheduler( @@ -449,9 +455,9 @@ def repeat_generator(): # Response Processing 1. truncate response after the first occurrence of `stop_token_id` postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) # Response Processing 2. run reward model on the truncated responses @@ -706,9 +712,9 @@ def generate_completions(self, sampling: bool = False): ) response = query_response[:, context_length:] postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 postprocessed_response = truncate_response( - args.stop_token_id, processing_class.pad_token_id, response + self.stop_token_id, processing_class.pad_token_id, response ) table["query"].extend( gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 1228dc7ece..719d952f1f 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -993,9 +993,15 @@ class OnPolicyConfig(TrainingArguments): response_length (`int`, *optional*, defaults to `53`): Length of the response. stop_token (`str` or `None`, *optional*, defaults to `None`): - Stop token. + Specifies the stop token to use for text generation. This parameter is mutually exclusive with + `stop_token_id`. + + - `None`: No stop token is applied, unless `stop_token_id` is specified. + - `'eos'`: Uses the tokenizer's `eos_token`. + stop_token_id (`int` or `None`, *optional*, defaults to `None`): - Truncation token id. + Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, + unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. temperature (`float`, *optional*, defaults to `0.7`): Sampling temperature. missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): @@ -1054,11 +1060,17 @@ class OnPolicyConfig(TrainingArguments): ) stop_token: Optional[Literal["eos"]] = field( default=None, - metadata={"help": "Stop token."}, + metadata={ + "help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " + "`stop_token_id`." + }, ) stop_token_id: Optional[int] = field( default=None, - metadata={"help": "Truncation token id."}, + metadata={ + "help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " + "applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." + }, ) temperature: float = field( default=0.7,