Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
kcz358 committed Feb 8, 2025
1 parent c782ab7 commit ebc741f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
14 changes: 7 additions & 7 deletions src/open_r1/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

from datasets import load_dataset
Expand Down Expand Up @@ -69,19 +69,19 @@ def accuracy_reward(completions, solution, **kwargs):
if reward == 0.0:
try:
# Extract answer from solution if it has think/answer tags
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
sol_match = re.search(r"<answer>(.*?)</answer>", sol)
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

# Extract answer from content if it has think/answer tags
content_match = re.search(r'<answer>(.*?)</answer>', content)
content_match = re.search(r"<answer>(.*?)</answer>", content)
student_answer = content_match.group(1).strip() if content_match else content.strip()

# Compare the extracted answers
if student_answer == ground_truth:
reward = 1.0
except Exception:
pass # Keep reward as 0.0 if both methods fail

rewards.append(reward)
if os.getenv("DEBUG_MODE") == "true":
log_path = os.getenv("LOG_PATH")
Expand Down Expand Up @@ -130,6 +130,7 @@ def make_conversation(example):
}

QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."

def make_conversation_image(example):
return {
"prompt": [
Expand All @@ -149,7 +150,6 @@ def make_conversation_image(example):
dataset = dataset.map(make_conversation)
dataset = dataset.remove_columns("messages")


trainer_cls = Qwen2VLGRPOTrainer

# Initialize the GRPO trainer
Expand Down
23 changes: 12 additions & 11 deletions src/open_r1/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import textwrap
from collections import defaultdict
Expand Down Expand Up @@ -45,8 +46,6 @@
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url

import copy


if is_peft_available():
from peft import PeftConfig, get_peft_model
Expand Down Expand Up @@ -275,8 +274,8 @@ def data_collator(features): # No data collation is needed in GRPO
self.num_generations = args.num_generations # = G in the GRPO paper
self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
do_sample=True,
temperature=1, # HACK
do_sample=True,
temperature=1, # HACK
num_return_sequences=self.num_generations,
pad_token_id=pad_token_id,
)
Expand Down Expand Up @@ -355,7 +354,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N

# Generate completions
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
#prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)
# prompt_completion_ids = unwrapped_model.generate(**prompt_inputs, generation_config=self.generation_config)

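# Call generate() num_generations times, producing one completion per call with temp_generation_config, then concatenate the output ids into prompt_completion_ids. Shorter sequences are right-padded with the tokenizer's pad token id (the original note cites 151613 — TODO confirm it matches pad_token_id).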
num_generations = self.generation_config.num_return_sequences
Expand All @@ -367,25 +366,27 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
for i in range(num_generations):  # one generate() call per requested completion (no "-1": nothing has been generated before this loop)
completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
all_completions.append(completion)

# Stack all completions and pad if needed
max_length = max(completion.size(1) for completion in all_completions)
padded_completions = []

for completion in all_completions:
if completion.size(1) < max_length:
padding = torch.full((completion.size(0), max_length - completion.size(1)),
self.processing_class.tokenizer.pad_token_id,
dtype=completion.dtype,
device=completion.device)
padding = torch.full(
(completion.size(0), max_length - completion.size(1)),
self.processing_class.tokenizer.pad_token_id,
dtype=completion.dtype,
device=completion.device,
)
padded_completion = torch.cat([completion, padding], dim=1)
else:
padded_completion = completion
padded_completions.append(padded_completion)

# Stack all padded completions
prompt_completion_ids = torch.cat(padded_completions, dim=0)

prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]

Expand Down

0 comments on commit ebc741f

Please sign in to comment.