Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added puzzle24 closes #208 #268

Merged
merged 3 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions reasoning_gym/games/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .maze import MazeConfig, MazeDataset
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset
from .puzzle24 import Puzzle24Config, Puzzle24Dataset
from .rush_hour import RushHourConfig, RushHourDataset
from .sokoban import SokobanConfig, SokobanDataset
from .sudoku import SudokuConfig, SudokuDataset
Expand All @@ -29,6 +30,8 @@
"FutoshikiDataset",
"MiniSudokuConfig",
"MiniSudokuDataset",
"Puzzle24Config",
"Puzzle24Dataset",
"SudokuConfig",
"SudokuDataset",
"SokobanConfig",
Expand Down
130 changes: 130 additions & 0 deletions reasoning_gym/games/puzzle24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import re
from dataclasses import dataclass
from random import Random
from typing import Any, Optional

import sympy
from sympy import Symbol, symbols
from sympy.parsing.sympy_parser import parse_expr

from ..factory import ProceduralDataset, register_dataset

# Prompt shown to the solver. The ``{numbers}`` and ``{operators}`` placeholders
# are filled in with comma-separated lists via ``str.format`` when a dataset
# item is generated.
QUESTION_TEMPLATE = """Make 24 using {numbers}. You can only use each number once. You can use the operators {operators}.
Final answer format instructions:
1. Provide your final answer as a arithmetic expression (no '=' sign).
2. Do not include the target number in the expression.
3. Use '*' for multiplication.
4. Use '/' for division.
"""


@dataclass
class Puzzle24Config:
    """Configuration for the "make 24" puzzle dataset."""

    operators: tuple = ("+", "-", "*", "/")  # operator symbols the generator may pick from
    min_value: int = 1  # smallest number that may appear in a puzzle (inclusive)
    max_value: int = 10  # largest number that may appear in a puzzle (inclusive)
    seed: Optional[int] = None  # base seed handed to the dataset
    size: int = 500  # number of items in the dataset

    def validate(self):
        """Check that the configuration is usable.

        Raises:
            AssertionError: if no operators are given, an operator is not one
                of ``+ - * /``, the value range is empty or falls outside
                ``[1, 10]``, or ``size`` is not positive.
        """
        # Assertions (rather than ValueError) are kept on purpose: callers and
        # the existing tests expect AssertionError from validate().
        assert len(self.operators) > 0, "At least one operator is required"
        # The expression generator treats any operator other than "+", "-" and
        # "*" as division, so an unknown symbol would silently behave like "/".
        unsupported = [op for op in self.operators if op not in ("+", "-", "*", "/")]
        assert not unsupported, f"Unsupported operators: {unsupported}"
        assert self.min_value <= self.max_value, "Minimum value must be less than or equal to maximum value"
        assert self.min_value >= 1, "Minimum value must be at least 1"
        assert self.max_value <= 10, "Maximum value must be at most 10"
        assert self.size > 0, "Size must be greater than 0"


class Puzzle24Dataset(ProceduralDataset):
    """Procedurally generated "make 24" puzzles built from four random numbers."""

    # Characters an answer may contain: digits, whitespace, parentheses and the
    # four arithmetic operators. Anything else is rejected before parsing.
    _ANSWER_CHARS = re.compile(r"[\d\s()+\-*/]+")
    # Whole-number tokens inside an answer.
    _NUMBER_TOKEN = re.compile(r"\b\d+\b")

    def __init__(self, config: Puzzle24Config):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _generate_candidate_expression(
        self, rng: Random, num_terms: int
    ) -> tuple[sympy.Expr, list[int], tuple[Symbol, ...]]:
        """Generate a candidate expression with random numbers and operators

        The expression is built left to right over the symbols ``x0..x{n-1}``;
        ``numbers[i]`` is the value meant for ``syms[i]``. The result is NOT
        guaranteed to evaluate to 24 - the caller filters candidates.

        Args:
            rng: Random number generator
            num_terms: Number of terms to include

        Returns:
            Tuple of (sympy expression, list of numbers, sequence of symbols)
        """
        # Generate random numbers
        numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)]

        # Create symbols for building expression
        syms = symbols(f"x:{num_terms}")

        # Build random expression
        expr = syms[0]

        for i in range(1, num_terms):
            op = rng.choice(self.config.operators)
            if op == "+":
                expr = expr + syms[i]
            elif op == "-":
                expr = expr - syms[i]
            elif op == "*":
                expr = expr * syms[i]
            else:
                # Division branch (also reached by any operator other than
                # "+", "-", "*"). Only divide when the running value is an
                # exact multiple of the divisor, so intermediate values stay
                # integral.
                if numbers[i] != 0:
                    current = int(expr.subs(dict(zip(syms[:i], numbers[:i]))))
                    remaining = [n for n in numbers[i:] if n != 0]
                    rng.shuffle(remaining)
                    found_divisor = False
                    for div in remaining:
                        if current % div == 0:
                            # Overwrite this position with the divisor found
                            # among the not-yet-used numbers.
                            numbers[i] = div
                            expr = expr / syms[i]
                            found_divisor = True
                            break
                    if not found_divisor:
                        # No exact divisor available: fall back to subtraction.
                        expr = expr - syms[i]
                else:
                    expr = expr + syms[i]

        return expr, numbers, syms

    def __getitem__(self, idx: int) -> dict:
        """Build the puzzle at position ``idx``.

        Candidates are drawn from an RNG seeded with ``self.seed + idx`` until
        one evaluates to exactly 24.

        NOTE: the loop has no attempt limit, so it does not terminate if the
        configured operators and value range can never produce 24.

        Returns:
            Dict with ``question``, ``answer`` (the expression with numbers
            substituted in) and ``metadata`` (``numbers`` and the symbolic
            ``expression``).
        """
        rng = Random(self.seed + idx)
        while True:
            expr, numbers, syms = self._generate_candidate_expression(rng, 4)
            if expr.subs(dict(zip(syms, numbers))) == 24:
                break

        # Render the symbolic expression with the concrete numbers.
        expr_str = str(expr)
        for sym, num in zip(syms, numbers):
            expr_str = expr_str.replace(str(sym), str(num))

        question = QUESTION_TEMPLATE.format(
            numbers=", ".join(map(str, numbers)), operators=", ".join(self.config.operators)
        )
        return {
            "question": question,
            "answer": expr_str,
            "metadata": {
                "numbers": numbers,
                "expression": expr,
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score an answer expression for the puzzle described by ``entry``.

        An answer earns 1.0 only if it is a plain arithmetic expression that
        uses exactly the puzzle's four numbers (each as often as it appears in
        ``entry["metadata"]["numbers"]``) and evaluates to exactly 24.
        Everything else - including ``None``, unparsable text and a malformed
        ``entry`` - earns 0.01.

        Args:
            answer: Candidate expression, e.g. ``"4 + 5 + 7 + 8"``.
            entry: Dataset item as returned by ``__getitem__``.

        Returns:
            1.0 for a correct answer, otherwise 0.01.
        """
        reward = 0.01
        if answer is None:
            return reward

        answer = answer.strip()
        # SECURITY: parse_expr evaluates its input with eval(), and answers
        # are untrusted text. Only digits, whitespace, parentheses and
        # + - * / are let through; "**" and "//" are refused as well, since
        # they are not puzzle operators and powers can make evaluation
        # arbitrarily expensive.
        if not self._ANSWER_CHARS.fullmatch(answer) or "**" in answer or "//" in answer:
            return reward

        try:
            used_numbers = [int(num) for num in self._NUMBER_TOKEN.findall(answer)]
            if len(used_numbers) != 4:
                return reward
            if any(num > self.config.max_value or num < self.config.min_value for num in used_numbers):
                return reward
            # The answer must use the puzzle's own numbers, not just any four
            # numbers inside the configured range.
            if sorted(used_numbers) != sorted(entry["metadata"]["numbers"]):
                return reward
            # Compare exactly: int() would truncate e.g. 49/2 down to 24.
            if parse_expr(answer) == 24:
                reward = 1.0
        except Exception:
            # parse_expr can raise a wide range of errors on malformed input;
            # any failure simply means the answer earns the minimum reward.
            reward = 0.01
        return reward


# Register the dataset class and its config class under the name "puzzle24".
register_dataset("puzzle24", Puzzle24Dataset, Puzzle24Config)
109 changes: 109 additions & 0 deletions tests/test_puzzle24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import re

import pytest
from sympy.parsing.sympy_parser import parse_expr

from reasoning_gym.games.puzzle24 import Puzzle24Config, Puzzle24Dataset


def test_puzzle24_config_validation():
    """Every invalid configuration must be rejected by ``validate``."""
    invalid_kwargs = [
        {"min_value": 10, "max_value": 5},  # min value greater than max value
        {"operators": ()},  # empty operators
        {"min_value": -1},  # negative min value
    ]
    for kwargs in invalid_kwargs:
        with pytest.raises(AssertionError):
            Puzzle24Config(**kwargs).validate()


def test_puzzle24_deterministic():
    """Two datasets built from the same config must produce identical items."""
    cfg = Puzzle24Config(seed=42, size=10)
    first, second = Puzzle24Dataset(cfg), Puzzle24Dataset(cfg)

    for position in range(len(first)):
        assert first[position] == second[position]


def test_puzzle24_basic_properties():
    """Each generated item has the expected keys and a well-formed question."""
    dataset = Puzzle24Dataset(Puzzle24Config(seed=42, size=10))

    for item in dataset:
        assert isinstance(item, dict)
        for key in ("question", "answer", "metadata"):
            assert key in item

        # Metadata must carry the puzzle numbers and the symbolic expression.
        metadata = item["metadata"]
        for key in ("numbers", "expression"):
            assert key in metadata

        # The prompt must contain the fixed wording ...
        question = item["question"]
        for fragment in ("Make 24 using", "Final answer format instructions"):
            assert fragment in question

        # ... and list the numbers exactly as stored in the metadata.
        expected_listing = ", ".join(str(n) for n in metadata["numbers"])
        assert expected_listing in question


def test_puzzle24_score_answer_correct():
    """The dataset's own reference answer must receive the full score."""
    config = Puzzle24Config(seed=42)
    dataset = Puzzle24Dataset(config)
    for item in dataset:
        answer = item["answer"]
        # Report the offending item through the assertion message instead of
        # printing every item on every run.
        assert dataset.score_answer(answer, item) == 1.0, f"reference answer rejected: {item}"


def test_puzzle24_score_answer_individual():
    """Score a hand-written entry against one correct and several wrong answers."""
    dataset = Puzzle24Dataset(Puzzle24Config(seed=42))

    # Hand-built entry whose numbers sum to 24.
    entry = {
        "metadata": {
            "numbers": [4, 5, 7, 8],
            "expression": parse_expr("x0 + x1 + x2 + x3"),
        }
    }

    # Correct answer (evaluates to 24)
    assert dataset.score_answer("4 + 5 + 7 + 8", entry) == 1.0

    # Incorrect answers all receive the minimum reward
    rejected = [
        None,  # None answer
        "",  # Empty answer
        "1+2+3",  # Wrong numbers
        "4*5*7*8",  # Doesn't equal 24
        "not a valid expression",  # Invalid expression
    ]
    for candidate in rejected:
        assert dataset.score_answer(candidate, entry) == 0.01


def test_puzzle24_iteration():
    """Iterating the dataset yields exactly ``size`` items, repeatably."""
    config = Puzzle24Config(size=5, seed=42)  # small size keeps the test fast
    dataset = Puzzle24Dataset(config)

    # Manual iteration
    collected = [item for item in dataset]
    assert len(collected) == config.size, "Iterator should yield exactly size items"

    # List conversion
    as_list = list(dataset)
    assert len(as_list) == config.size, "Iterator should yield exactly size items"

    # Repeated iteration must give the same items
    first_pass, second_pass = list(dataset), list(dataset)
    assert first_pass == second_pass, "Multiple iterations should yield same items"