
Commit

docs: Update TRL README with GRPO example details and usage instructions (#76)
andreaskoepf authored Feb 7, 2025
1 parent a8e11e7 commit 7b72c34
Showing 3 changed files with 37 additions and 12 deletions.
examples/trl/README.md (29 additions, 2 deletions)
@@ -1,5 +1,32 @@
-1. Install the requirements in the txt file
+# TRL Examples

-```
+This directory contains examples using the [TRL (Transformer Reinforcement Learning) library](https://github.com/huggingface/trl) to fine-tune language models with reinforcement learning techniques.

+## GRPO Example

+The main example demonstrates using GRPO (Group Relative Policy Optimization) to fine-tune a language model on reasoning tasks from reasoning-gym. It includes:

+- Custom reward functions for answer accuracy and format compliance
+- Integration with reasoning-gym datasets
+- Configurable training parameters via YAML config
+- Wandb logging and model checkpointing
+- Evaluation on held-out test sets

+## Setup

+1. Install the required dependencies:

+```bash
 pip install -r requirements.txt
 ```

+## Usage

+1. Configure the training parameters in `config/grpo.yaml`
+2. Run the training script:

+```bash
+python main_grpo_reward.py
+```

+The model will be trained using GRPO with the specified reasoning-gym dataset and evaluation metrics will be logged to Weights & Biases.
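
The "custom reward functions" mentioned in the README combine an answer-accuracy score with a format check. Below is a minimal sketch of the format side, reusing the regex that `_format_reward` applies in `main_grpo_reward.py` further down; the function name, signature, and the 1.0/0.0 scoring are illustrative assumptions rather than the commit's exact code:

```python
import re

# Completions are expected to be a single <think>...</think> block followed by
# a single <answer>...</answer> block (same pattern as _format_reward below).
FORMAT_REGEX = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"


def format_reward(completions: list[str]) -> list[float]:
    """Give 1.0 to completions that follow the think/answer format, 0.0 otherwise."""
    return [1.0 if re.match(FORMAT_REGEX, c) else 0.0 for c in completions]


# Illustrative check:
good = "<think>7 * 6 = 42</think>\n<answer>42</answer>"
bad = "42"
assert format_reward([good, bad]) == [1.0, 0.0]
```
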
examples/trl/main_grpo_reward.py (7 additions, 7 deletions)
@@ -33,8 +33,8 @@ def __len__(self):
         return len(self.data)

     def __getitem__(self, idx):
-        metadata = self.data[idx]
-        question = metadata["question"]
+        item = self.data[idx]
+        question = item["question"]

         chat = []

@@ -43,7 +43,7 @@ def __getitem__(self, idx):
         chat.append({"role": "user", "content": question})

         prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-        return {"prompt": prompt, "metadata": metadata}
+        return {"prompt": prompt, "metadata": item}


class GRPOTrainerCustom(GRPOTrainer):
@@ -54,7 +54,7 @@ def __init__(
         args: GRPOConfig,
         tokenizer,
         peft_config,
-        seed1,
+        seed,
         size,
         developer_role="system",
     ):
@@ -66,7 +66,7 @@ def __init__(
             peft_config=peft_config,
         )
         developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
-        self.train_dataset = ReasoningGymDataset(dataset_name, seed1, size, tokenizer, developer_prompt, developer_role)
+        self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)

     def _format_reward(self, completions, **kwargs):
         regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
@@ -128,7 +128,7 @@ def main(script_args, training_args, model_args):
         args=training_args,
         tokenizer=tokenizer,
         peft_config=peft_config,
-        seed1=training_args.seed,
+        seed=training_args.seed,
         size=script_args.train_size,
     )

@@ -154,7 +154,7 @@ def main(script_args, training_args, model_args):
         "finetuned_from": model_args.model_name_or_path,
         "dataset": list(script_args.dataset_name),
         "dataset_tags": list(script_args.dataset_name),
-        "tags": ["open-r1"],
+        "tags": ["reasoning-gym"],
     }

     if trainer.accelerator.is_main_process:
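
For context on the `ReasoningGymDataset.__getitem__` lines above: the prompt is simply the developer prompt plus the task question rendered through the tokenizer's chat template. A standalone sketch, where the checkpoint name, developer prompt text, and question are placeholders rather than values from this repository:

```python
from transformers import AutoTokenizer

# Placeholder checkpoint; any chat model with a chat template behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

developer_prompt = "Think step by step inside <think> tags, then answer inside <answer> tags."  # placeholder
question = "What is 13 * 7?"  # stands in for item["question"] from a reasoning-gym dataset

chat = [
    {"role": "system", "content": developer_prompt},  # developer_role defaults to "system"
    {"role": "user", "content": question},
]

# Same call as in __getitem__: render the conversation to a single prompt string
# with the assistant generation prefix appended.
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
print(prompt)
```
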
examples/trl/requirements.txt (1 addition, 3 deletions)
@@ -1,6 +1,4 @@
-torch --index-url https://download.pytorch.org/whl/cu124
-torchvision --index-url https://download.pytorch.org/whl/cu124
-torchaudio --index-url https://download.pytorch.org/whl/cu124
+torch>=2.6.0
datasets
peft
transformers
