diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py index 4721844d..38a60c4f 100644 --- a/reasoning_gym/games/countdown.py +++ b/reasoning_gym/games/countdown.py @@ -1,9 +1,10 @@ from dataclasses import dataclass from random import Random -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import sympy from sympy import Symbol, symbols +from sympy.parsing.sympy_parser import parse_expr from ..factory import ProceduralDataset, register_dataset @@ -158,6 +159,23 @@ def _generate_expression(self, rng: Random) -> Tuple[str, List[int], int]: raise ValueError(f"Failed to generate valid expression after {max_attempts} attempts") + def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: + """Determine if the solution provided solves the problem""" + reward = 0.0 + if answer is not None: + try: + user_answer = int(parse_expr(answer)) + solved = user_answer == metadata["target"] + if solved: + reward = 1.0 + elif len(answer.strip()) > 0: # encourage partial solutions + reward = 0.05 + else: + reward = 0.01 + except: + reward = 0.01 + return reward + # Register the dataset register_dataset("countdown", CountdownDataset, CountdownConfig) diff --git a/tests/test_countdown.py b/tests/test_countdown.py index e426caf2..e78a69ab 100644 --- a/tests/test_countdown.py +++ b/tests/test_countdown.py @@ -64,6 +64,16 @@ def test_countdown_game_items(): # Verify expression evaluates correctly expr = item["metadata"]["expression"] + + # check score + assert dataset.score_answer(answer=expr, metadata=item["metadata"]) == 1.0 # correct answer + assert dataset.score_answer(answer="45+2", metadata=item["metadata"]) == 0.05 # wrong answer but an attempt + assert ( + dataset.score_answer(answer="a wrong solution", metadata=item["metadata"]) == 0.01 + ) # wrong answer but incorrectly formatted + assert dataset.score_answer(answer="", metadata=item["metadata"]) == 0.01 # wrong answer but empty string + assert dataset.score_answer(answer=None, metadata=item["metadata"]) == 0.0 # no answer + try: result = eval(expr) # Safe here since we control expression generation assert result == item["metadata"]["target"]