diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index a0a61d3b16..f26e3f9c4a 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -39,6 +39,8 @@ class GRPOConfig(TrainingArguments):
 
         > Parameters that control the data preprocessing
 
+        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+            Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
         num_generations (`int` or `None`, *optional*, defaults to `8`):
             Number of generations per prompt to sample.
         temperature (`float`, *optional*, defaults to `0.9`):
@@ -65,6 +67,12 @@ class GRPOConfig(TrainingArguments):
     )
 
     # Parameters that control the data preprocessing
+    max_prompt_length: Optional[int] = field(
+        default=512,
+        metadata={
+            "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
+        },
+    )
     num_generations: Optional[int] = field(
         default=8,
         metadata={"help": "Number of generations to sample."},
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 0e830b6822..579f55ec38 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -138,6 +138,7 @@ def data_collator(features):  # No data collation is needed in GRPO
             return features
 
         # Training arguments
+        self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
         self.num_generations = args.num_generations  # = G in the GRPO paper
         self.generation_config = GenerationConfig(
@@ -203,6 +204,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         )
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
 
+        if self.max_prompt_length is not None:
+            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
+            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
+
         # Generate completions
         with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
             prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
@@ -211,13 +216,20 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
 
         # Get the per-token log probabilities for the completions for the model and the reference model
         def get_per_token_logps(model, input_ids):
-            logits = model(input_ids).logits
-            logits = torch.roll(logits, shifts=1, dims=1)  # Shape (B*G, L)
-            per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=input_ids.unsqueeze(2)).squeeze(2)
-            return per_token_logps
+            logits = model(input_ids).logits  # (B, L, V)
+            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
+            # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
+            per_token_logps = []
+            for logits_row, input_ids_row in zip(logits, input_ids):
+                log_probs = logits_row.log_softmax(dim=-1)
+                token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+                per_token_logps.append(token_log_prob)
+            return torch.stack(per_token_logps)
 
         per_token_logps = get_per_token_logps(model, prompt_completion_ids)
-        per_token_logps = per_token_logps[:, prompt_length:]  # get rid of the prompt
+        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
+        per_token_logps = per_token_logps[:, prompt_length - 1 :]
 
         with torch.inference_mode():
             if self.ref_model is not None:
@@ -225,7 +237,7 @@ def get_per_token_logps(model, input_ids):
             else:
                 with self.accelerator.unwrap_model(model).disable_adapter():
                     ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
-        ref_per_token_logps = ref_per_token_logps[:, prompt_length:]  # get rid of the prompt
+        ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
 
         # Compute the KL divergence between the model and the reference model
         per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
@@ -287,9 +299,9 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
         metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
         logs = {**logs, **metrics}
         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
-            return super().log(logs, start_time)
+            super().log(logs, start_time)
         else:  # transformers<=4.46
-            return super().log(logs)
+            super().log(logs)
         self._metrics = {key: [] for key in self._metrics}
 
     def create_model_card(
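
For context (not part of the patch): the core of the trainer change is the one-token shift between logits and labels in `get_per_token_logps`. Position `t` of the logits predicts token `t + 1`, so the last logit and the first input ID are dropped, and the completion slice therefore starts at `prompt_length - 1` rather than `prompt_length`. Below is a minimal, self-contained sketch of that index arithmetic on random tensors; the helper name `per_token_logps_from_logits` is made up for the example, and it gathers in one vectorized call instead of looping per row the way the patched trainer does to limit the memory peak.

```python
import torch


def per_token_logps_from_logits(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Log-probability of each observed token given full-sequence logits.

    logits[:, t] is the distribution over token t + 1, so the last logit has no
    target and the first input ID has no prediction; both are dropped.
    """
    logits = logits[:, :-1, :]  # (B, L-1, V)
    targets = input_ids[:, 1:]  # (B, L-1)
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=2, index=targets.unsqueeze(2)).squeeze(2)  # (B, L-1)


# Toy shapes: batch of 2 prompt+completion sequences, 5 tokens each, vocab of 11.
B, L, V, prompt_length = 2, 5, 11, 3
torch.manual_seed(0)
logits = torch.randn(B, L, V)
input_ids = torch.randint(V, (B, L))

logps = per_token_logps_from_logits(logits, input_ids)  # (B, L-1)
# Column t of `logps` scores input_ids[:, t + 1], so the completion tokens
# (positions prompt_length .. L-1) land in columns prompt_length-1 .. L-2.
completion_logps = logps[:, prompt_length - 1 :]
assert completion_logps.shape == (B, L - prompt_length)
```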