add rlvr coding support
vwxyzjn committed Feb 14, 2025
1 parent 861ed92 commit 13e4dd7
Showing 4 changed files with 129 additions and 0 deletions.
36 changes: 36 additions & 0 deletions open_instruct/ground_truth_utils.py
@@ -6,7 +6,9 @@
import json
import re
import string
from typing import List

from open_instruct.code_utils import get_successful_tests_fast
from open_instruct.if_functions import IF_FUNCTIONS_MAP
from open_instruct.math_utils import (
get_unnormalized_answer,
@@ -119,6 +121,30 @@ def verify_ifeval_sample(model_output, constraint):
    return func(answer, **non_none_args)


def extract_python_code(model_output: str) -> str:
"""Extract the first code block between ``` markers from the model output."""
# Find content between ``` markers
pattern = r"```(?:python)?(.*?)```"
matches = re.findall(pattern, model_output, re.DOTALL)

if not matches:
return model_output

# Return the first match, stripped of whitespace
return matches[0].strip()


def verify_ace_coder_sample(model_output, tests: List[str], max_execution_time: float = 1.0):
"""First extract the python code from the model output, then run it against the test cases. See example below.
"""
# Extract the python code from the model output
python_code = extract_python_code(model_output)
# Run the python code against the test cases
passes = get_successful_tests_fast(python_code, tests, max_execution_time)
pass_rate = sum(passes) / len(passes)
return pass_rate


def normalize_answer(s):
"""
Lower text and remove punctuation, articles and extra whitespace.
@@ -163,3 +189,13 @@ def soft_format_reward_func(responses: list[str], reward_scale: float = 1.0) ->
for sample in ds["train"]:
    print(sample)
    verify_ifeval_sample(test_model_output, sample["ground_truth"])


test_model_output = "<|assistant|>\nHere is a python code that solves the problem: ```python\ndef add(a, b):\n return a + b\n```"
test_cases = [
"assert add(1, 2) == 3",
"assert add(-1, 1) == 0",
"assert add(0, 0) == 1" # This test will fail but just for testing
]
print(verify_ace_coder_sample(test_model_output, test_cases))
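
With the three test cases above, the call should print a pass rate of 2/3 (about 0.667), since the last assertion fails. The verifier relies on get_successful_tests_fast from open_instruct/code_utils.py, which is not shown in this diff. Below is a minimal sketch of the contract that verify_ace_coder_sample assumes: each test string is executed against the candidate program under a timeout, and one boolean is returned per test. The helper name and implementation here are hypothetical; the real code_utils implementation may differ.

# Illustrative sketch only; not the implementation in open_instruct/code_utils.py.
# Assumed contract: run each test snippet against `program` in a separate process,
# enforce a per-test timeout, and return one boolean per test, in order.
import multiprocessing
from typing import List


def _run_single_test(program: str, test: str, result_queue) -> None:
    try:
        exec(program + "\n" + test, {})
        result_queue.put(True)
    except Exception:
        result_queue.put(False)


def get_successful_tests_fast_sketch(program: str, tests: List[str], max_execution_time: float = 1.0) -> List[bool]:
    results: List[bool] = []
    for test in tests:
        queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=_run_single_test, args=(program, test, queue))
        proc.start()
        proc.join(max_execution_time)
        if proc.is_alive():  # treat a timeout as a failed test
            proc.terminate()
            proc.join()
            results.append(False)
        else:
            results.append(queue.get() if not queue.empty() else False)
    return results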

55 changes: 55 additions & 0 deletions scripts/data/rlvr/rlvr_acecoder.py
@@ -0,0 +1,55 @@
"""
This script converts the AceCode-87K dataset to the standard RLVR format.
The test cases are passed through unchanged and serve as the ground truth
for verifiable rewards; the questions become user prompts for generation.
Usage:
python scripts/data/rlvr/rlvr_acecoder.py --push_to_hub
python scripts/data/rlvr/rlvr_acecoder.py --push_to_hub --hf_entity ai2-adapt-dev
"""

from dataclasses import dataclass
from typing import Optional

import datasets
from huggingface_hub import HfApi
from transformers import HfArgumentParser

@dataclass
class Args:
    push_to_hub: bool = False
    hf_entity: Optional[str] = None

def main(args: Args):
    dataset = datasets.load_dataset("TIGER-Lab/AceCode-87K", split="train")

    def process(example):
        example["messages"] = [
            {"role": "user", "content": example["question"]},
        ]
        example["ground_truth"] = example["test_cases"]
        example["dataset"] = "ace_coder"
        return example

    dataset = dataset.map(process)
    # reorder columns
    dataset = dataset.select_columns(["messages", "ground_truth", "dataset"])

    if args.push_to_hub:
        api = HfApi()
        if not args.hf_entity:
            args.hf_entity = HfApi().whoami()["name"]
        repo_id = f"{args.hf_entity}/rlvr_acecoder"
        print(f"Pushing dataset to Hub: {repo_id}")
        dataset.push_to_hub(repo_id)
        api.upload_file(
            path_or_fileobj=__file__,
            path_in_repo="create_dataset.py",
            repo_type="dataset",
            repo_id=repo_id,
        )

if __name__ == "__main__":
    parser = HfArgumentParser((Args))
    main(*parser.parse_args_into_dataclasses())
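
Each AceCode record becomes a single-turn user prompt with its raw test cases carried along as the ground truth. Below is a sketch of the record shape produced by process() above, assuming the source dataset exposes "question" and "test_cases" columns (which is what the map function reads); the field values are made up for illustration.

# Hypothetical example of one converted record; values are illustrative only.
record = {
    "messages": [
        {"role": "user", "content": "Write a function add(a, b) that returns the sum of a and b."},
    ],
    "ground_truth": [
        "assert add(1, 2) == 3",
        "assert add(-1, 1) == 0",
    ],
    "dataset": "ace_coder",
}

The "ground_truth" list is what the new verify_ace_coder_sample expects as its tests argument, and the "dataset" field names which verifier should be applied.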
1 change: 1 addition & 0 deletions scripts/train/rlvr/grpo_llama3.1-8b.sh
@@ -1,3 +1,4 @@
exp_name=grpo_llama3.1-8b_${RANDOM}
python open_instruct/grpo_vllm_thread_ray_gtrl.py \
--exp_name $exp_name \
--output_dir /weka/oe-adapt-default/costah/models/$exp_name \
37 changes: 37 additions & 0 deletions scripts/train/rlvr/grpo_mini_code.sh
@@ -0,0 +1,37 @@
python open_instruct/grpo_vllm_thread_ray_gtrl.py \
--dataset_mixer_list vwxyzjn/rlvr_acecoder 1.0 \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list vwxyzjn/rlvr_acecoder 1.0 \
--dataset_mixer_eval_list_splits train \
--max_token_length 1023 \
--max_prompt_token_length 1024 \
--response_length 1024 \
--number_samples_per_prompt 4 \
--model_name_or_path HuggingFaceTB/SmolLM-135M-Instruct \
--non_stop_penalty \
--stop_token eos \
--temperature 1.0 \
--ground_truths_key ground_truth \
--chat_template_name tulu \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 10000 \
--penalty_reward_value -10.0 \
--deepspeed_stage 3 \
--per_device_train_batch_size 1 \
--local_rollout_forward_batch_size 1 \
--local_mini_batch_size 4 \
--local_rollout_batch_size 4 \
--num_epochs 1 \
--actor_num_gpus_per_node 1 \
--vllm_tensor_parallel_size 1 \
--beta 0.05 \
--apply_verifiable_reward true \
--output_dir output/rlvr_1b \
--seed 3 \
--num_evals 3 \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--gradient_checkpointing \
--vllm_enforce_eager \
--with_tracking
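
This run enables verifiable rewards (--apply_verifiable_reward true) and zeroes out the learned reward model (--reward_model_multiplier 0.0), so the training signal comes from the ground-truth verifiers, with --ground_truths_key pointing at the test cases produced by the dataset script. The routing from the "dataset" column to a verifier lives in grpo_vllm_thread_ray_gtrl.py, which is not part of this diff; the snippet below is a purely illustrative sketch of that kind of dispatch, not the repository's actual code.

# Illustrative only; not the dispatch code in grpo_vllm_thread_ray_gtrl.py.
from open_instruct.ground_truth_utils import verify_ace_coder_sample

VERIFIERS = {
    "ace_coder": verify_ace_coder_sample,  # returns the fraction of tests passed
}


def verifiable_reward(model_output: str, ground_truth, dataset_name: str) -> float:
    """Return a scalar reward for one rollout, or 0.0 if no verifier is registered."""
    verifier = VERIFIERS.get(dataset_name)
    if verifier is None:
        return 0.0
    return float(verifier(model_output, ground_truth))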
