From 7b72c3470bff68cecf276e35e8c41d08fdd80ee0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andreas=20K=C3=B6pf?=
Date: Fri, 7 Feb 2025 07:56:22 +0100
Subject: [PATCH] docs: Update TRL README with GRPO example details and usage
 instructions (#76)

---
 examples/trl/README.md           | 31 +++++++++++++++++++++++++++++--
 examples/trl/main_grpo_reward.py | 14 +++++++-------
 examples/trl/requirements.txt    |  4 +---
 3 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/examples/trl/README.md b/examples/trl/README.md
index ae2818e2..a9ef439a 100644
--- a/examples/trl/README.md
+++ b/examples/trl/README.md
@@ -1,5 +1,32 @@
-1. Install the requirements in the txt file
+# TRL Examples
 
-```
+This directory contains examples using the [TRL (Transformer Reinforcement Learning) library](https://github.com/huggingface/trl) to fine-tune language models with reinforcement learning techniques.
+
+## GRPO Example
+
+The main example demonstrates using GRPO (Group Relative Policy Optimization) to fine-tune a language model on reasoning tasks from reasoning-gym. It includes:
+
+- Custom reward functions for answer accuracy and format compliance
+- Integration with reasoning-gym datasets
+- Configurable training parameters via YAML config
+- Wandb logging and model checkpointing
+- Evaluation on held-out test sets
+
+## Setup
+
+1. Install the required dependencies:
+
+```bash
 pip install -r requirements.txt
 ```
+
+## Usage
+
+1. Configure the training parameters in `config/grpo.yaml`
+2. Run the training script:
+
+```bash
+python main_grpo_reward.py
+```
+
+The model will be trained using GRPO with the specified reasoning-gym dataset and evaluation metrics will be logged to Weights & Biases.
diff --git a/examples/trl/main_grpo_reward.py b/examples/trl/main_grpo_reward.py
index ae8c33cd..119b1e1e 100644
--- a/examples/trl/main_grpo_reward.py
+++ b/examples/trl/main_grpo_reward.py
@@ -33,8 +33,8 @@ def __len__(self):
         return len(self.data)
 
     def __getitem__(self, idx):
-        metadata = self.data[idx]
-        question = metadata["question"]
+        item = self.data[idx]
+        question = item["question"]
 
         chat = []
 
@@ -43,7 +43,7 @@ def __getitem__(self, idx):
         chat.append({"role": "user", "content": question})
 
         prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-        return {"prompt": prompt, "metadata": metadata}
+        return {"prompt": prompt, "metadata": item}
 
 
 class GRPOTrainerCustom(GRPOTrainer):
@@ -54,7 +54,7 @@ def __init__(
         args: GRPOConfig,
         tokenizer,
         peft_config,
-        seed1,
+        seed,
         size,
         developer_role="system",
     ):
@@ -66,7 +66,7 @@ def __init__(
             peft_config=peft_config,
         )
         developer_prompt = reasoning_gym.utils.SYSTEM_PROMPTS["DeepSeekZero"]
-        self.train_dataset = ReasoningGymDataset(dataset_name, seed1, size, tokenizer, developer_prompt, developer_role)
+        self.train_dataset = ReasoningGymDataset(dataset_name, seed, size, tokenizer, developer_prompt, developer_role)
 
     def _format_reward(self, completions, **kwargs):
         regex = r"^([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n([\s\S]*?)<\/answer>$"
@@ -128,7 +128,7 @@ def main(script_args, training_args, model_args):
         args=training_args,
         tokenizer=tokenizer,
         peft_config=peft_config,
-        seed1=training_args.seed,
+        seed=training_args.seed,
         size=script_args.train_size,
     )
 
@@ -154,7 +154,7 @@ def main(script_args, training_args, model_args):
         "finetuned_from": model_args.model_name_or_path,
         "dataset": list(script_args.dataset_name),
         "dataset_tags": list(script_args.dataset_name),
-        "tags": ["open-r1"],
+        "tags": ["reasoning-gym"],
     }
 
     if trainer.accelerator.is_main_process:
diff --git a/examples/trl/requirements.txt b/examples/trl/requirements.txt
index 474c8107..1353b8b3 100644
--- a/examples/trl/requirements.txt
+++ b/examples/trl/requirements.txt
@@ -1,6 +1,4 @@
-torch --index-url https://download.pytorch.org/whl/cu124
-torchvision --index-url https://download.pytorch.org/whl/cu124
-torchaudio --index-url https://download.pytorch.org/whl/cu124
+torch>=2.6.0
 datasets
 peft
 transformers
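
For readers skimming the patch: the `_format_reward` method touched above scores completions on whether they follow the `<think>…</think>` / `<answer>…</answer>` layout used by the DeepSeekZero system prompt. The sketch below is illustrative only; the regex is copied from the diff, but the standalone function name `format_reward` and the 1.0/0.0 scores are assumptions, since the patch truncates the method body.

```python
import re

# Regex taken from `_format_reward` in examples/trl/main_grpo_reward.py (see diff above):
# the completion must close the <think> block opened by the prompt, then end with </answer>.
FORMAT_REGEX = r"^([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n([\s\S]*?)<\/answer>$"


def format_reward(completions: list[str], **kwargs) -> list[float]:
    """Score each completion 1.0 if it matches the expected layout, else 0.0.

    The 1.0/0.0 values are illustrative; the patch does not show the scores
    actually returned by `_format_reward`.
    """
    return [1.0 if re.match(FORMAT_REGEX, completion) else 0.0 for completion in completions]
```

A callable with this `(completions, **kwargs)` signature matches the shape TRL's `GRPOTrainer` accepts for custom reward functions passed via `reward_funcs`; the example instead subclasses the trainer, so treat this as a sketch of the idea rather than the patch's exact wiring.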