docs: Update TRL README with GRPO example details and usage instructions #76

Merged
merged 1 commit on Feb 7, 2025
31 changes: 29 additions & 2 deletions examples/trl/README.md
@@ -1,5 +1,32 @@
1. Install the requirements in the txt file
# TRL Examples

This directory contains examples using the [TRL (Transformer Reinforcement Learning) library](https://github.com/huggingface/trl) to fine-tune language models with reinforcement learning techniques.

## GRPO Example

The main example demonstrates using GRPO (Group Relative Policy Optimization) to fine-tune a language model on reasoning tasks from reasoning-gym. It includes:

- Custom reward functions for answer accuracy and format compliance
- Integration with reasoning-gym datasets
- Configurable training parameters via YAML config
- Wandb logging and model checkpointing
- Evaluation on held-out test sets

## Setup

1. Install the required dependencies:

```bash
pip install -r requirements.txt
```

## Usage

1. Configure the training parameters in `config/grpo.yaml`
2. Run the training script:

```bash
python main_grpo_reward.py
```

The model will be trained using GRPO on the specified reasoning-gym dataset, and evaluation metrics will be logged to Weights & Biases.
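
For orientation before the code diff below, here is a minimal sketch (not part of this PR) of how a reasoning-gym procedural dataset can be turned into chat prompts, mirroring what the `ReasoningGymDataset` class in `main_grpo_reward.py` does. The model name, task name, and sizes are illustrative assumptions; the system prompt and chat-template call follow the diff.

```python
import reasoning_gym
from reasoning_gym.utils import SYSTEM_PROMPTS
from transformers import AutoTokenizer

# Illustrative choices; the actual script takes these from config/grpo.yaml.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")  # hypothetical model
data = reasoning_gym.create_dataset("leg_counting", seed=42, size=1000)  # hypothetical task/size
developer_prompt = SYSTEM_PROMPTS["DeepSeekZero"]  # same system prompt used in the diff

def to_prompt(item: dict) -> str:
    # Same recipe as ReasoningGymDataset.__getitem__: system prompt plus question,
    # rendered through the tokenizer's chat template with a generation prompt.
    chat = [
        {"role": "system", "content": developer_prompt},
        {"role": "user", "content": item["question"]},
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

prompts = [to_prompt(x) for x in data]
```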
14 changes: 7 additions & 7 deletions examples/trl/main_grpo_reward.py
@@ -33,8 +33,8 @@ def __len__(self):
return len(self.data)

def __getitem__(self, idx):
metadata = self.data[idx]
question = metadata["question"]
item = self.data[idx]
question = item["question"]

chat = []

@@ -43,7 +43,7 @@ def __getitem__(self, idx):
chat.append({"role": "user", "content": question})

prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
return {"prompt": prompt, "metadata": metadata}
return {"prompt": prompt, "metadata": item}


class GRPOTrainerCustom(GRPOTrainer):
@@ -54,7 +54,7 @@ def __init__(
args: GRPOConfig,
tokenizer,
peft_config,
seed1,
seed,
size,
developer_role="system",
):
@@ -66,7 +66,7 @@ def __init__(
peft_config=peft_config,
)
developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
self.train_dataset = ReasoningGymDataset(dataset_name, seed1, size, tokenizer, developer_prompt, developer_role)
self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)

def _format_reward(self, completions, **kwargs):
regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
@@ -128,7 +128,7 @@ def main(script_args, training_args, model_args):
args=training_args,
tokenizer=tokenizer,
peft_config=peft_config,
seed1=training_args.seed,
seed=training_args.seed,
size=script_args.train_size,
)

@@ -154,7 +154,7 @@ def main(script_args, training_args, model_args):
"finetuned_from": model_args.model_name_or_path,
"dataset": list(script_args.dataset_name),
"dataset_tags": list(script_args.dataset_name),
"tags": ["open-r1"],
"tags": ["reasoning-gym"],
}

if trainer.accelerator.is_main_process:
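The hunks above include the format-compliance regex used by `GRPOTrainerCustom._format_reward`. As a standalone sketch of that check, assuming completions arrive as plain strings and that a matching completion scores 1.0 (the exact reward weighting is collapsed in this view):

```python
import re

# Regex copied from _format_reward in this diff: a completion must be exactly
# <think>...</think> followed on the next line by <answer>...</answer>.
FORMAT_REGEX = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"

def format_reward(completions: list[str]) -> list[float]:
    # 1.0 for template-compliant completions, 0.0 otherwise (assumed weighting).
    return [1.0 if re.match(FORMAT_REGEX, completion) else 0.0 for completion in completions]
```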
4 changes: 1 addition & 3 deletions examples/trl/requirements.txt
@@ -1,6 +1,4 @@
torch --index-url https://download.pytorch.org/whl/cu124
torchvision --index-url https://download.pytorch.org/whl/cu124
torchaudio --index-url https://download.pytorch.org/whl/cu124
torch>=2.6.0
datasets
peft
transformers
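The README added in this PR also mentions evaluation on held-out test sets. As a rough, hypothetical sketch of such an evaluation using reasoning-gym's own answer scoring (the task name, seed, and `generate_answer` callable are placeholders; the PR's actual evaluation code is not shown in this view):

```python
import reasoning_gym

# Held-out split: same task family as training but a different seed (values assumed).
test_data = reasoning_gym.create_dataset("leg_counting", seed=123, size=100)

def evaluate(generate_answer) -> float:
    # generate_answer: placeholder callable mapping a question string to the model's answer string.
    scores = [test_data.score_answer(answer=generate_answer(x["question"]), entry=x) for x in test_data]
    return sum(scores) / len(scores)
```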