-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
3 changed files
with
242 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import re | ||
from dataclasses import dataclass | ||
from random import Random | ||
from typing import Any, Optional | ||
|
||
import sympy | ||
from sympy import Symbol, symbols | ||
from sympy.parsing.sympy_parser import parse_expr | ||
|
||
from ..factory import ProceduralDataset, register_dataset | ||
|
||
QUESTION_TEMPLATE = """Make 24 using {numbers}. You can only use each number once. You can use the operators {operators}. | ||
Final answer format instructions: | ||
1. Provide your final answer as a arithmetic expression (no '=' sign). | ||
2. Do not include the target number in the expression. | ||
3. Use '*' for multiplication. | ||
4. Use '/' for division. | ||
""" | ||
|
||
|
||
@dataclass | ||
class Puzzle24Config: | ||
operators: tuple = ("+", "-", "*", "/") | ||
min_value: int = 1 | ||
max_value: int = 10 | ||
seed: Optional[int] = None | ||
size: int = 500 | ||
|
||
def validate(self): | ||
assert len(self.operators) > 0, "At least one operator is required" | ||
assert self.min_value <= self.max_value, "Minimum value must be less than or equal to maximum value" | ||
assert self.min_value >= 1, "Minimum value must be at least 1" | ||
assert self.max_value <= 10, "Maximum value must be at most 10" | ||
assert self.size > 0, "Size must be greater than 0" | ||
|
||
|
||
class Puzzle24Dataset(ProceduralDataset): | ||
def __init__(self, config: Puzzle24Config): | ||
super().__init__(config=config, seed=config.seed, size=config.size) | ||
|
||
def _generate_candidate_expression(self, rng: Random, num_terms: int) -> tuple[sympy.Expr, list[int], list[Symbol]]: | ||
"""Generate a candidate expression with random numbers and operators | ||
Args: | ||
rng: Random number generator | ||
num_terms: Number of terms to include | ||
Returns: | ||
Tuple of (sympy expression, list of numbers, list of symbols) | ||
""" | ||
# Generate random numbers | ||
numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)] | ||
|
||
# Create symbols for building expression | ||
syms = symbols(f"x:{num_terms}") | ||
|
||
# Build random expression | ||
expr = syms[0] | ||
|
||
for i in range(1, num_terms): | ||
op = rng.choice(self.config.operators) | ||
if op == "+": | ||
expr = expr + syms[i] | ||
elif op == "-": | ||
expr = expr - syms[i] | ||
elif op == "*": | ||
expr = expr * syms[i] | ||
else: | ||
if numbers[i] != 0: | ||
current = int(expr.subs({sym: num for sym, num in zip(syms[:i], numbers[:i])})) | ||
remaining = [n for n in numbers[i:] if n != 0] | ||
rng.shuffle(remaining) | ||
found_divisor = False | ||
for div in remaining: | ||
if current % div == 0: | ||
numbers[i] = div | ||
expr = expr / syms[i] | ||
found_divisor = True | ||
break | ||
if not found_divisor: | ||
expr = expr - syms[i] | ||
else: | ||
expr = expr + syms[i] | ||
|
||
return expr, numbers, syms | ||
|
||
def __getitem__(self, idx: int) -> dict: | ||
rng = Random(self.seed + idx) | ||
while True: | ||
expr, numbers, syms = self._generate_candidate_expression(rng, 4) | ||
if expr.subs({sym: num for sym, num in zip(syms, numbers)}) == 24: | ||
break | ||
expr_str = str(expr) | ||
for i, sym in enumerate(syms): | ||
expr_str = expr_str.replace(str(sym), str(numbers[i])) | ||
|
||
question = QUESTION_TEMPLATE.format( | ||
numbers=", ".join(map(str, numbers)), operators=", ".join(self.config.operators) | ||
) | ||
return { | ||
"question": question, | ||
"answer": expr_str, | ||
"metadata": { | ||
"numbers": numbers, | ||
"expression": expr, | ||
}, | ||
} | ||
|
||
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: | ||
reward = 0.01 | ||
if answer is not None: | ||
try: | ||
answer = answer.strip() | ||
user_answer = int(parse_expr(answer)) | ||
solved = user_answer == 24 | ||
used_numbers = [int(num) for num in re.findall(r"\b\d+\b", answer)] | ||
if len(used_numbers) != 4: | ||
reward = 0.01 | ||
elif any(num > self.config.max_value or num < self.config.min_value for num in used_numbers): | ||
reward = 0.01 | ||
elif not solved: | ||
reward = 0.01 | ||
else: | ||
reward = 1.0 | ||
except Exception as e: | ||
reward = 0.01 | ||
return reward | ||
|
||
|
||
register_dataset("puzzle24", Puzzle24Dataset, Puzzle24Config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import re | ||
|
||
import pytest | ||
from sympy.parsing.sympy_parser import parse_expr | ||
|
||
from reasoning_gym.games.puzzle24 import Puzzle24Config, Puzzle24Dataset | ||
|
||
|
||
def test_puzzle24_config_validation(): | ||
"""Test that invalid configs raise appropriate errors""" | ||
# Min value greater than max value | ||
with pytest.raises(AssertionError): | ||
config = Puzzle24Config(min_value=10, max_value=5) | ||
config.validate() | ||
|
||
# Invalid operators | ||
with pytest.raises(AssertionError): | ||
config = Puzzle24Config(operators=()) # Empty operators | ||
config.validate() | ||
|
||
# Negative min value | ||
with pytest.raises(AssertionError): | ||
config = Puzzle24Config(min_value=-1) | ||
config.validate() | ||
|
||
|
||
def test_puzzle24_deterministic(): | ||
"""Test that dataset generates same items with same seed""" | ||
config = Puzzle24Config(seed=42, size=10) | ||
dataset1 = Puzzle24Dataset(config) | ||
dataset2 = Puzzle24Dataset(config) | ||
|
||
for i in range(len(dataset1)): | ||
assert dataset1[i] == dataset2[i] | ||
|
||
|
||
def test_puzzle24_basic_properties(): | ||
"""Test basic properties of generated items""" | ||
config = Puzzle24Config(seed=42, size=10) | ||
dataset = Puzzle24Dataset(config) | ||
|
||
for item in dataset: | ||
assert isinstance(item, dict) | ||
assert "question" in item | ||
assert "answer" in item | ||
assert "metadata" in item | ||
|
||
# Check metadata contains required fields | ||
assert "numbers" in item["metadata"] | ||
assert "expression" in item["metadata"] | ||
|
||
# Check question format | ||
assert "Make 24 using" in item["question"] | ||
assert "Final answer format instructions" in item["question"] | ||
|
||
# Verify the numbers in question match metadata | ||
numbers = item["metadata"]["numbers"] | ||
num_str = ", ".join(map(str, numbers)) | ||
assert num_str in item["question"] | ||
|
||
|
||
def test_puzzle24_score_answer_correct(): | ||
"""Test the score_answer method for correct answers""" | ||
config = Puzzle24Config(seed=42) | ||
dataset = Puzzle24Dataset(config) | ||
for item in dataset: | ||
answer = item["answer"] | ||
print(item) | ||
assert dataset.score_answer(answer, item) == 1.0 | ||
|
||
|
||
def test_puzzle24_score_answer_individual(): | ||
"""Test the score_answer method""" | ||
config = Puzzle24Config(seed=42) | ||
dataset = Puzzle24Dataset(config) | ||
|
||
# Create a test entry | ||
entry = {"metadata": {"numbers": [4, 5, 7, 8], "expression": parse_expr("x0 + x1 + x2 + x3")}} | ||
|
||
# Test correct answer (evaluates to 24) | ||
assert dataset.score_answer("4 + 5 + 7 + 8", entry) == 1.0 | ||
|
||
# Test incorrect answers | ||
assert dataset.score_answer(None, entry) == 0.01 # None answer | ||
assert dataset.score_answer("", entry) == 0.01 # Empty answer | ||
assert dataset.score_answer("1+2+3", entry) == 0.01 # Wrong numbers | ||
assert dataset.score_answer("4*5*7*8", entry) == 0.01 # Doesn't equal 24 | ||
assert dataset.score_answer("not a valid expression", entry) == 0.01 # Invalid expression | ||
|
||
|
||
def test_puzzle24_iteration(): | ||
"""Test that iteration respects dataset size""" | ||
config = Puzzle24Config(size=5, seed=42) # Small size for testing | ||
dataset = Puzzle24Dataset(config) | ||
|
||
# Test manual iteration | ||
items = [] | ||
for item in dataset: | ||
items.append(item) | ||
assert len(items) == config.size, "Iterator should yield exactly size items" | ||
|
||
# Test list conversion | ||
items = list(dataset) | ||
assert len(items) == config.size, "Iterator should yield exactly size items" | ||
|
||
# Test multiple iterations | ||
first_items = list(dataset) | ||
second_items = list(dataset) | ||
assert first_items == second_items, "Multiple iterations should yield same items" |