💾 Reduce memory peak in GRPO by adding max_prompt_length and loop usage in logp computation (#2598)

* add max_prompt_length to config

* truncate prompt and compute log probs line by line
qgallouedec authored Jan 21, 2025
1 parent d9f0568 commit b6a084c
Showing 2 changed files with 28 additions and 8 deletions.
8 changes: 8 additions & 0 deletions trl/trainer/grpo_config.py
@@ -39,6 +39,8 @@ class GRPOConfig(TrainingArguments):
> Parameters that control the data preprocessing
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
num_generations (`int` or `None`, *optional*, defaults to `8`):
Number of generations per prompt to sample.
temperature (`float`, *optional*, defaults to `0.9`):
@@ -65,6 +67,12 @@ class GRPOConfig(TrainingArguments):
)

# Parameters that control the data preprocessing
max_prompt_length: Optional[int] = field(
default=512,
metadata={
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
},
)
num_generations: Optional[int] = field(
default=8,
metadata={"help": "Number of generations to sample."},
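A minimal usage sketch (not part of this commit) of how the new option is set when building the config; the output directory is an illustrative placeholder, and the numeric values simply restate the defaults documented in this diff:

from trl import GRPOConfig

# Hypothetical configuration sketch: the values mirror the defaults documented
# in this diff; output_dir is a placeholder, not taken from the commit.
config = GRPOConfig(
    output_dir="grpo-demo",
    max_prompt_length=512,   # prompts longer than this are truncated from the left
    num_generations=8,       # completions sampled per prompt (G in the GRPO paper)
    temperature=0.9,
)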
28 changes: 20 additions & 8 deletions trl/trainer/grpo_trainer.py
@@ -138,6 +138,7 @@ def data_collator(features): # No data collation is needed in GRPO
return features

# Training arguments
self.max_prompt_length = args.max_prompt_length
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.generation_config = GenerationConfig(
@@ -203,6 +204,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)

if self.max_prompt_length is not None:
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]

# Generate completions
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
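To make the left-truncation above concrete (a toy sketch, not trainer code): slicing with a negative start keeps only the most recent max_prompt_length positions, so the oldest prompt tokens are the ones dropped.

import torch

# Toy example of left truncation: keep the last max_prompt_length tokens.
max_prompt_length = 4
input_ids = torch.tensor([[101, 7, 8, 9, 10, 11]])       # prompt of length 6
attention_mask = torch.ones_like(input_ids)
input_ids = input_ids[:, -max_prompt_length:]             # tensor([[ 8,  9, 10, 11]])
attention_mask = attention_mask[:, -max_prompt_length:]   # mask trimmed the same way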
@@ -211,21 +216,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):

# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids):
logits = model(input_ids).logits
logits = torch.roll(logits, shifts=1, dims=1) # Shape (B*G, L)
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=input_ids.unsqueeze(2)).squeeze(2)
return per_token_logps
logits = model(input_ids).logits # (B, L, V)
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)

per_token_logps = get_per_token_logps(model, prompt_completion_ids)
per_token_logps = per_token_logps[:, prompt_length:] # get rid of the prompt
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
per_token_logps = per_token_logps[:, prompt_length - 1 :]

with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
ref_per_token_logps = ref_per_token_logps[:, prompt_length:] # get rid of the prompt
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]

# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
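For context on where the memory saving comes from (a standalone sketch, not the trainer's code): a fully vectorized version materializes the log-softmax of the entire (batch, sequence, vocab) logits tensor at once, while the row-by-row loop introduced here holds only one (sequence, vocab) log-softmax at a time before gathering each selected token's log-probability.

import torch

def per_token_logps_vectorized(logits, input_ids):
    # Allocates an extra (B, L-1, V) tensor for the log-softmax output.
    logps = logits[:, :-1, :].log_softmax(dim=-1)
    return torch.gather(logps, dim=2, index=input_ids[:, 1:].unsqueeze(2)).squeeze(2)

def per_token_logps_looped(logits, input_ids):
    # Allocates only one extra (L-1, V) row at a time, mirroring the loop added in this commit.
    per_token_logps = []
    for logits_row, ids_row in zip(logits[:, :-1, :], input_ids[:, 1:]):
        log_probs = logits_row.log_softmax(dim=-1)
        per_token_logps.append(torch.gather(log_probs, dim=1, index=ids_row.unsqueeze(1)).squeeze(1))
    return torch.stack(per_token_logps)

# Both paths return the same values on toy inputs; only the transient memory differs.
B, L, V = 2, 5, 11
logits, input_ids = torch.randn(B, L, V), torch.randint(V, (B, L))
assert torch.allclose(per_token_logps_vectorized(logits, input_ids),
                      per_token_logps_looped(logits, input_ids), atol=1e-6)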
@@ -287,9 +299,9 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
logs = {**logs, **metrics}
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super().log(logs, start_time)
super().log(logs, start_time)
else: # transformers<=4.46
return super().log(logs)
super().log(logs)
self._metrics = {key: [] for key in self._metrics}

def create_model_card(