Skip to content

Commit

Permalink
🫷 Include stop token in policy model's generation_config (#2528)
Browse files Browse the repository at this point in the history
* Include stop token in policy model's generation_config

* Fix formatting

* Update trl/trainer/ppo_trainer.py

* Update trl/trainer/ppo_trainer.py

* don't modify args

* clarify doc

* more nice doc

* missing no [ci skip]

* really don't modify args

* oups

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
3 people authored Jan 22, 2025
1 parent d4222a1 commit a9b54a8
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
26 changes: 16 additions & 10 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,18 @@ def __init__(
if data_collator is None:
data_collator = DataCollatorWithPadding(self.processing_class)

self.policy_model.generation_config.eos_token_id = (
None # disable `pad_token_id` and `eos_token_id` because we just want to
)
self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding
# Handle stop token settings: update policy model's generation_config to use provided stop token
if args.stop_token and args.stop_token_id:
raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
elif args.stop_token:
if args.stop_token == "eos":
self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
else:
raise ValueError(
f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
)
else:
self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int

# peft support
if not is_peft_available() and peft_config is not None:
Expand Down Expand Up @@ -220,8 +228,6 @@ def __init__(
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
if module is not None:
disable_dropout_in_model(module)
if args.stop_token and args.stop_token == "eos":
args.stop_token_id = processing_class.eos_token_id
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
self.model.config = self.policy_model.config # needed for pushing to hub
self.create_optimizer_and_scheduler(
Expand Down Expand Up @@ -449,9 +455,9 @@ def repeat_generator():

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, processing_class.pad_token_id, response
self.stop_token_id, processing_class.pad_token_id, response
)

# Response Processing 2. run reward model on the truncated responses
Expand Down Expand Up @@ -706,9 +712,9 @@ def generate_completions(self, sampling: bool = False):
)
response = query_response[:, context_length:]
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
postprocessed_response = truncate_response(
args.stop_token_id, processing_class.pad_token_id, response
self.stop_token_id, processing_class.pad_token_id, response
)
table["query"].extend(
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
Expand Down
20 changes: 16 additions & 4 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,9 +993,15 @@ class OnPolicyConfig(TrainingArguments):
response_length (`int`, *optional*, defaults to `53`):
Length of the response.
stop_token (`str` or `None`, *optional*, defaults to `None`):
Stop token.
Specifies the stop token to use for text generation. This parameter is mutually exclusive with
`stop_token_id`.
- `None`: No stop token is applied, unless `stop_token_id` is specified.
- `'eos'`: Uses the tokenizer's `eos_token`.
stop_token_id (`int` or `None`, *optional*, defaults to `None`):
Truncation token id.
Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied,
unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`.
temperature (`float`, *optional*, defaults to `0.7`):
Sampling temperature.
missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
Expand Down Expand Up @@ -1054,11 +1060,17 @@ class OnPolicyConfig(TrainingArguments):
)
stop_token: Optional[Literal["eos"]] = field(
default=None,
metadata={"help": "Stop token."},
metadata={
"help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with "
"`stop_token_id`."
},
)
stop_token_id: Optional[int] = field(
default=None,
metadata={"help": "Truncation token id."},
metadata={
"help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is "
"applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`."
},
)
temperature: float = field(
default=0.7,
Expand Down

0 comments on commit a9b54a8

Please sign in to comment.