From 6307adc1c34b80db8b56b2461c558bf09798d28b Mon Sep 17 00:00:00 2001
From: joesharratt1229
Date: Tue, 4 Mar 2025 21:41:04 +0000
Subject: [PATCH] added puzzle24

---
 reasoning_gym/games/__init__.py |   3 +
 reasoning_gym/games/puzzle24.py | 149 ++++++++++++++++++++++++++++++++
 tests/test_puzzle24.py          | 115 +++++++++++++++++++++++++
 3 files changed, 267 insertions(+)
 create mode 100644 reasoning_gym/games/puzzle24.py
 create mode 100644 tests/test_puzzle24.py

diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py
index 4b450134..326f825d 100644
--- a/reasoning_gym/games/__init__.py
+++ b/reasoning_gym/games/__init__.py
@@ -14,6 +14,7 @@
 from .maze import MazeConfig, MazeDataset
 from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset
 from .n_queens import NQueensDataset
+from .puzzle24 import Puzzle24Config, Puzzle24Dataset
 from .rush_hour import RushHourConfig, RushHourDataset
 from .sokoban import SokobanConfig, SokobanDataset
 from .sudoku import SudokuConfig, SudokuDataset
@@ -29,6 +30,8 @@
     "FutoshikiDataset",
     "MiniSudokuConfig",
     "MiniSudokuDataset",
+    "Puzzle24Config",
+    "Puzzle24Dataset",
     "SudokuConfig",
     "SudokuDataset",
     "SokobanConfig",
diff --git a/reasoning_gym/games/puzzle24.py b/reasoning_gym/games/puzzle24.py
new file mode 100644
index 00000000..908e587d
--- /dev/null
+++ b/reasoning_gym/games/puzzle24.py
@@ -0,0 +1,149 @@
+import re
+from dataclasses import dataclass
+from random import Random
+from typing import Any, Optional
+
+import sympy
+from sympy import Symbol, symbols
+from sympy.parsing.sympy_parser import parse_expr
+
+from ..factory import ProceduralDataset, register_dataset
+
+QUESTION_TEMPLATE = """Make 24 using {numbers}. You can only use each number once. You can use the operators {operators}.
+Final answer format instructions:
+1. Provide your final answer as a arithmetic expression (no '=' sign).
+2. Do not include the target number in the expression.
+3. Use '*' for multiplication.
+4. Use '/' for division.
+"""
+
+# Operators the expression generator knows how to apply.
+SUPPORTED_OPERATORS = ("+", "-", "*", "/")
+
+# Characters permitted in a submitted answer: digits, whitespace, parentheses and the basic operators.
+_ANSWER_PATTERN = re.compile(r"[\d\s+\-*/()]+")
+
+
+@dataclass
+class Puzzle24Config:
+    """Configuration for the "make 24" puzzle dataset."""
+
+    operators: tuple = ("+", "-", "*", "/")  # Operators offered to the solver
+    min_value: int = 1  # Smallest number that can appear in a puzzle
+    max_value: int = 10  # Largest number that can appear in a puzzle
+    seed: Optional[int] = None
+    size: int = 500  # Number of puzzles in the dataset
+
+    def validate(self):
+        """Check the settings, raising AssertionError on the first violation."""
+        assert len(self.operators) > 0, "At least one operator is required"
+        assert all(op in SUPPORTED_OPERATORS for op in self.operators), "Operators must be chosen from +, -, *, /"
+        assert self.min_value <= self.max_value, "Minimum value must be less than or equal to maximum value"
+        assert self.min_value >= 1, "Minimum value must be at least 1"
+        assert self.max_value <= 10, "Maximum value must be at most 10"
+        assert self.size > 0, "Size must be greater than 0"
+
+
+class Puzzle24Dataset(ProceduralDataset):
+    """Puzzles asking for an arithmetic expression over four numbers that equals 24."""
+
+    def __init__(self, config: Puzzle24Config):
+        super().__init__(config=config, seed=config.seed, size=config.size)
+
+    def _generate_candidate_expression(self, rng: Random, num_terms: int) -> tuple[sympy.Expr, list[int], list[Symbol]]:
+        """Generate a candidate expression with random numbers and operators
+
+        Args:
+            rng: Random number generator
+            num_terms: Number of terms to include
+
+        Returns:
+            Tuple of (sympy expression, list of numbers, list of symbols)
+        """
+        # Generate random numbers
+        numbers = [rng.randint(self.config.min_value, self.config.max_value) for _ in range(num_terms)]
+
+        # Create symbols for building expression
+        syms = symbols(f"x:{num_terms}")
+
+        # Build random expression
+        expr = syms[0]
+
+        for i in range(1, num_terms):
+            op = rng.choice(self.config.operators)
+            if op == "+":
+                expr = expr + syms[i]
+            elif op == "-":
+                expr = expr - syms[i]
+            elif op == "*":
+                expr = expr * syms[i]
+            else:
+                if numbers[i] != 0:  # Avoid division by zero
+                    current = int(expr.subs({sym: num for sym, num in zip(syms[:i], numbers[:i])}))
+                    remaining = [n for n in numbers[i:] if n != 0]
+                    rng.shuffle(remaining)
+                    found_divisor = False
+                    for div in remaining:
+                        if current % div == 0:  # Only divide when the result stays an integer
+                            numbers[i] = div
+                            expr = expr / syms[i]
+                            found_divisor = True
+                            break
+                    if not found_divisor:
+                        expr = expr - syms[i]  # No usable divisor: subtract instead
+                else:
+                    expr = expr + syms[i]
+
+        return expr, numbers, syms
+
+    def __getitem__(self, idx: int) -> dict:
+        """Return the puzzle at ``idx``, sampling candidates until one evaluates to exactly 24."""
+        rng = Random(self.seed + idx)
+        while True:
+            expr, numbers, syms = self._generate_candidate_expression(rng, 4)
+            if expr.subs({sym: num for sym, num in zip(syms, numbers)}) == 24:
+                break
+
+        # str(expr) is written in terms of the symbols x0..x3; substitute the numbers so that the
+        # reference answer is a plain arithmetic expression, as the question asks for.
+        values = {str(sym): str(num) for sym, num in zip(syms, numbers)}
+        answer = re.sub(r"x\d+", lambda match: values[match.group(0)], str(expr))
+
+        question = QUESTION_TEMPLATE.format(
+            numbers=", ".join(map(str, numbers)), operators=", ".join(self.config.operators)
+        )
+        return {
+            "question": question,
+            "answer": answer,
+            "metadata": {
+                "numbers": numbers,
+                "expression": expr,
+            },
+        }
+
+    def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
+        """Score an answer: 1.0 if it uses exactly the puzzle's numbers and equals 24, else 0.01."""
+        if answer is None:
+            return 0.01
+
+        answer = answer.strip()
+        # parse_expr() evaluates its input with eval(), so only plain integer arithmetic is let through.
+        # '**' and '//' are built from permitted characters but are not operators the puzzle offers.
+        if not _ANSWER_PATTERN.fullmatch(answer) or "**" in answer or "//" in answer:
+            return 0.01
+
+        # Each of the puzzle's numbers must be used exactly once.
+        used_numbers = sorted(int(num) for num in re.findall(r"\d+", answer))
+        if used_numbers != sorted(entry["metadata"]["numbers"]):
+            return 0.01
+
+        try:
+            # Compare the exact value; int() would truncate e.g. 172/7 to 24.
+            solved = parse_expr(answer) == 24
+        except Exception:
+            # Not a well-formed expression (unbalanced parentheses, dangling operator, ...)
+            return 0.01
+        return 1.0 if solved else 0.01
+
+
+register_dataset("puzzle24", Puzzle24Dataset, Puzzle24Config)
diff --git a/tests/test_puzzle24.py b/tests/test_puzzle24.py
new file mode 100644
index 00000000..8517c73c
--- /dev/null
+++ b/tests/test_puzzle24.py
@@ -0,0 +1,115 @@
+import re
+
+import pytest
+from sympy.parsing.sympy_parser import parse_expr
+
+from reasoning_gym.games.puzzle24 import Puzzle24Config, Puzzle24Dataset
+
+
+def test_puzzle24_config_validation():
+    """Test that invalid configs raise appropriate errors"""
+    # Min value greater than max value
+    with pytest.raises(AssertionError):
+        config = Puzzle24Config(min_value=10, max_value=5)
+        config.validate()
+
+    # Invalid operators
+    with pytest.raises(AssertionError):
+        config = Puzzle24Config(operators=())  # Empty operators
+        config.validate()
+
+    # Unsupported operator
+    with pytest.raises(AssertionError):
+        config = Puzzle24Config(operators=("+", "^"))
+        config.validate()
+
+    # Negative min value
+    with pytest.raises(AssertionError):
+        config = Puzzle24Config(min_value=-1)
+        config.validate()
+
+
+def test_puzzle24_deterministic():
+    """Test that dataset generates same items with same seed"""
+    config = Puzzle24Config(seed=42, size=10)
+    dataset1 = Puzzle24Dataset(config)
+    dataset2 = Puzzle24Dataset(config)
+
+    for i in range(len(dataset1)):
+        assert dataset1[i] == dataset2[i]
+
+
+def test_puzzle24_basic_properties():
+    """Test basic properties of generated items"""
+    config = Puzzle24Config(seed=42, size=10)
+    dataset = Puzzle24Dataset(config)
+
+    for item in dataset:
+        assert isinstance(item, dict)
+        assert "question" in item
+        assert "answer" in item
+        assert "metadata" in item
+
+        # Check metadata contains required fields
+        assert "numbers" in item["metadata"]
+        assert "expression" in item["metadata"]
+
+        # Check question format
+        assert "Make 24 using" in item["question"]
+        assert "Final answer format instructions" in item["question"]
+
+        # Verify the numbers in question match metadata
+        numbers = item["metadata"]["numbers"]
+        num_str = ", ".join(map(str, numbers))
+        assert num_str in item["question"]
+
+
+def test_puzzle24_score_answer():
+    """Test the score_answer method"""
+    config = Puzzle24Config(seed=42)
+    dataset = Puzzle24Dataset(config)
+
+    # Create a test entry
+    entry = {"metadata": {"numbers": [4, 5, 7, 8], "expression": parse_expr("x0 + x1 + x2 + x3")}}
+
+    # Test correct answer (evaluates to 24)
+    assert dataset.score_answer("4 + 5 + 7 + 8", entry) == 1.0
+
+    # Test incorrect answers
+    assert dataset.score_answer(None, entry) == 0.01  # None answer
+    assert dataset.score_answer("", entry) == 0.01  # Empty answer
+    assert dataset.score_answer("1+2+3", entry) == 0.01  # Wrong numbers
+    assert dataset.score_answer("4*5*7*8", entry) == 0.01  # Doesn't equal 24
+    assert dataset.score_answer("not a valid expression", entry) == 0.01  # Invalid expression
+    assert dataset.score_answer("3 + 6 + 7 + 8", entry) == 0.01  # Equals 24 but uses other numbers
+    assert dataset.score_answer("4 * (5 + 8 / 7)", entry) == 0.01  # 172/7 only truncates to 24
+
+
+def test_puzzle24_reference_answer_scores_full():
+    """Test that the generated answer is accepted by score_answer"""
+    config = Puzzle24Config(seed=42, size=10)
+    dataset = Puzzle24Dataset(config)
+
+    for item in dataset:
+        assert dataset.score_answer(item["answer"], item) == 1.0
+
+
+def test_puzzle24_iteration():
+    """Test that iteration respects dataset size"""
+    config = Puzzle24Config(size=5, seed=42)  # Small size for testing
+    dataset = Puzzle24Dataset(config)
+
+    # Test manual iteration
+    items = []
+    for item in dataset:
+        items.append(item)
+    assert len(items) == config.size, "Iterator should yield exactly size items"
+
+    # Test list conversion
+    items = list(dataset)
+    assert len(items) == config.size, "Iterator should yield exactly size items"
+
+    # Test multiple iterations
+    first_items = list(dataset)
+    second_items = list(dataset)
+    assert first_items == second_items, "Multiple iterations should yield same items"