Skip to content

Commit

Permalink
Added puzzle24 closes #208 (#268)
Browse files Browse the repository at this point in the history
* added puzzle24
  • Loading branch information
joesharratt1229 authored Mar 5, 2025
1 parent d1e505a commit f3ee9a9
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 0 deletions.
3 changes: 3 additions & 0 deletions reasoning_gym/games/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .maze import MazeConfig, MazeDataset
from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset
from .puzzle24 import Puzzle24Config, Puzzle24Dataset
from .rush_hour import RushHourConfig, RushHourDataset
from .sokoban import SokobanConfig, SokobanDataset
from .sudoku import SudokuConfig, SudokuDataset
Expand All @@ -29,6 +30,8 @@
"FutoshikiDataset",
"MiniSudokuConfig",
"MiniSudokuDataset",
"Puzzle24Config",
"Puzzle24Dataset",
"SudokuConfig",
"SudokuDataset",
"SokobanConfig",
Expand Down
130 changes: 130 additions & 0 deletions reasoning_gym/games/puzzle24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import re
from dataclasses import dataclass
from random import Random
from typing import Any, Optional

import sympy
from sympy import Symbol, symbols
from sympy.parsing.sympy_parser import parse_expr

from ..factory import ProceduralDataset, register_dataset

# Prompt shown to the solver. The `{numbers}` and `{operators}` placeholders are
# filled in via `str.format` with comma-separated lists when an item is built.
# NOTE(review): "a arithmetic expression" should read "an arithmetic expression";
# left untouched here because the text is part of every generated question.
QUESTION_TEMPLATE = """Make 24 using {numbers}. You can only use each number once. You can use the operators {operators}.
Final answer format instructions:
1. Provide your final answer as a arithmetic expression (no '=' sign).
2. Do not include the target number in the expression.
3. Use '*' for multiplication.
4. Use '/' for division.
"""


@dataclass
class Puzzle24Config:
    """Settings for the puzzle-24 dataset.

    Attributes:
        operators: Operator symbols the puzzle may use; must be non-empty.
        min_value: Smallest number that may appear in a puzzle (at least 1).
        max_value: Largest number that may appear in a puzzle (at most 10).
        seed: Seed handed to the dataset; ``None`` leaves it unspecified.
        size: Number of items in the dataset; must be positive.
    """

    operators: tuple = ("+", "-", "*", "/")
    min_value: int = 1
    max_value: int = 10
    seed: Optional[int] = None
    size: int = 500

    def validate(self):
        """Check the configuration, raising ``AssertionError`` on the first violated rule."""
        operator_count = len(self.operators)
        assert operator_count > 0, "At least one operator is required"
        assert self.max_value >= self.min_value, "Minimum value must be less than or equal to maximum value"
        assert 1 <= self.min_value, "Minimum value must be at least 1"
        assert 10 >= self.max_value, "Maximum value must be at most 10"
        assert 0 < self.size, "Size must be greater than 0"


class Puzzle24Dataset(ProceduralDataset):
    """Procedurally generated "make 24" puzzles.

    Every item presents four integers and asks for an arithmetic expression
    over them that evaluates to 24.
    """

    def __init__(self, config: Puzzle24Config):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def _generate_candidate_expression(self, rng: Random, num_terms: int) -> tuple[sympy.Expr, list[int], list[Symbol]]:
        """Generate a candidate expression with random numbers and operators

        The expression is built left to right over the symbols ``x0..x{n-1}``;
        it is NOT guaranteed to evaluate to 24 -- the caller filters candidates.

        Args:
            rng: Random number generator
            num_terms: Number of terms to include
        Returns:
            Tuple of (sympy expression, list of numbers, list of symbols)
        """
        # Generate random numbers
        numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)]

        # Create symbols for building expression
        syms = symbols(f"x:{num_terms}")

        # Build random expression
        expr = syms[0]

        for i in range(1, num_terms):
            op = rng.choice(self.config.operators)
            if op == "+":
                expr = expr + syms[i]
            elif op == "-":
                expr = expr - syms[i]
            elif op == "*":
                expr = expr * syms[i]
            else:
                # Any operator other than "+", "-", "*" is handled as division.
                if numbers[i] != 0:
                    # Value of the expression built so far, used to pick a divisor
                    # that keeps the running result an integer.
                    current = int(expr.subs(dict(zip(syms[:i], numbers[:i]))))
                    remaining = [n for n in numbers[i:] if n != 0]
                    rng.shuffle(remaining)
                    found_divisor = False
                    for div in remaining:
                        if current % div == 0:
                            # Overwrite this term's number with the exact divisor.
                            numbers[i] = div
                            expr = expr / syms[i]
                            found_divisor = True
                            break
                    if not found_divisor:
                        # No exact divisor available: fall back to subtraction.
                        expr = expr - syms[i]
                else:
                    # Never divide by zero: fall back to addition.
                    expr = expr + syms[i]

        return expr, numbers, syms

    def __getitem__(self, idx: int) -> dict:
        """Build the puzzle at position ``idx``.

        A fresh ``Random(self.seed + idx)`` drives generation, so the same index
        always produces the same item for a given seed.

        Returns:
            Dict with the formatted ``question``, a reference ``answer``
            expression string, and ``metadata`` holding the ``numbers`` and the
            symbolic ``expression``.
        """
        rng = Random(self.seed + idx)
        # NOTE(review): there is no retry cap; a config under which 24 cannot be
        # reached would keep this loop spinning forever.
        while True:
            expr, numbers, syms = self._generate_candidate_expression(rng, 4)
            if expr.subs(dict(zip(syms, numbers))) == 24:
                break

        # Turn the symbolic expression into a concrete answer string by
        # substituting each symbol name with its number.
        expr_str = str(expr)
        for i, sym in enumerate(syms):
            expr_str = expr_str.replace(str(sym), str(numbers[i]))

        question = QUESTION_TEMPLATE.format(
            numbers=", ".join(map(str, numbers)), operators=", ".join(self.config.operators)
        )
        return {
            "question": question,
            "answer": expr_str,
            "metadata": {
                "numbers": numbers,
                "expression": expr,
            },
        }

    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        """Score a proposed expression for ``entry``.

        Full reward requires that the expression evaluates to exactly 24 and
        that the integers appearing in it are exactly the puzzle's numbers
        (``entry["metadata"]["numbers"]``), each used once and each inside the
        configured ``min_value``..``max_value`` range.

        Args:
            answer: Candidate arithmetic expression, or ``None``.
            entry: Dataset item the answer is scored against.
        Returns:
            ``1.0`` for a valid solution, otherwise ``0.01``.
        """
        reward = 0.01
        if answer is None:
            return reward

        try:
            answer = answer.strip()
            # NOTE(review): sympy's parse_expr evaluates the string with eval;
            # answers are untrusted text, so consider a restricted parser.
            value = parse_expr(answer)
            used_numbers = sorted(int(num) for num in re.findall(r"\b\d+\b", answer))
            expected_numbers = sorted(entry["metadata"]["numbers"])

            in_range = all(self.config.min_value <= num <= self.config.max_value for num in used_numbers)
            # Compare the exact value: truncating with int() would let
            # non-integer results such as 49/2 pass as 24.
            if in_range and used_numbers == expected_numbers and value == 24:
                reward = 1.0
        except Exception:
            # Unparseable or otherwise malformed answers get the minimum reward;
            # parse_expr can raise many different exception types.
            reward = 0.01
        return reward


# Register the dataset and its config class with the factory under the name "puzzle24".
register_dataset("puzzle24", Puzzle24Dataset, Puzzle24Config)
109 changes: 109 additions & 0 deletions tests/test_puzzle24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import re

import pytest
from sympy.parsing.sympy_parser import parse_expr

from reasoning_gym.games.puzzle24 import Puzzle24Config, Puzzle24Dataset


def test_puzzle24_config_validation():
    """Each invalid configuration must fail ``validate`` with an AssertionError."""
    invalid_settings = (
        {"min_value": 10, "max_value": 5},  # Min value greater than max value
        {"operators": ()},  # Empty operators
        {"min_value": -1},  # Negative min value
    )
    for kwargs in invalid_settings:
        with pytest.raises(AssertionError):
            config = Puzzle24Config(**kwargs)
            config.validate()


def test_puzzle24_deterministic():
    """Two datasets built from the same seeded config must agree item by item."""
    shared_config = Puzzle24Config(seed=42, size=10)
    first = Puzzle24Dataset(shared_config)
    second = Puzzle24Dataset(shared_config)

    idx = 0
    while idx < len(first):
        assert first[idx] == second[idx]
        idx += 1


def test_puzzle24_basic_properties():
    """Generated items expose the expected keys and a well-formed question."""
    dataset = Puzzle24Dataset(Puzzle24Config(seed=42, size=10))

    for item in dataset:
        assert isinstance(item, dict)

        # Top-level keys every item must carry
        for key in ("question", "answer", "metadata"):
            assert key in item

        # Metadata must contain the required fields
        metadata = item["metadata"]
        for field in ("numbers", "expression"):
            assert field in metadata

        # Question format
        question = item["question"]
        assert "Make 24 using" in question
        assert "Final answer format instructions" in question

        # The numbers listed in the question must match the metadata
        listed_numbers = ", ".join(str(n) for n in metadata["numbers"])
        assert listed_numbers in question


def test_puzzle24_score_answer_correct():
    """Every reference answer produced by the dataset must receive full reward."""
    config = Puzzle24Config(seed=42)
    dataset = Puzzle24Dataset(config)
    for item in dataset:
        answer = item["answer"]
        # Report the offending item through the assertion message instead of
        # printing every item on each run.
        assert dataset.score_answer(answer, item) == 1.0, f"reference answer rejected for item: {item}"


def test_puzzle24_score_answer_individual():
    """Spot-check ``score_answer`` against hand-written answers."""
    dataset = Puzzle24Dataset(Puzzle24Config(seed=42))

    # Hand-built entry to score against
    entry = {"metadata": {"numbers": [4, 5, 7, 8], "expression": parse_expr("x0 + x1 + x2 + x3")}}

    # Correct answer (evaluates to 24)
    assert dataset.score_answer("4 + 5 + 7 + 8", entry) == 1.0

    # Answers that must only receive the minimum reward
    rejected_answers = (
        None,  # None answer
        "",  # Empty answer
        "1+2+3",  # Wrong numbers
        "4*5*7*8",  # Doesn't equal 24
        "not a valid expression",  # Invalid expression
    )
    for candidate in rejected_answers:
        assert dataset.score_answer(candidate, entry) == 0.01


def test_puzzle24_iteration():
    """Iterating the dataset yields exactly ``size`` items, repeatably."""
    config = Puzzle24Config(size=5, seed=42)  # Small size for testing
    dataset = Puzzle24Dataset(config)

    # Manual iteration
    collected = [item for item in dataset]
    assert len(collected) == config.size, "Iterator should yield exactly size items"

    # List conversion
    collected = list(dataset)
    assert len(collected) == config.size, "Iterator should yield exactly size items"

    # Repeated iteration must reproduce the same items
    first_pass, second_pass = list(dataset), list(dataset)
    assert first_pass == second_pass, "Multiple iterations should yield same items"

0 comments on commit f3ee9a9

Please sign in to comment.