From e30be066ec85038e5f82917006986cd91bd7b8b7 Mon Sep 17 00:00:00 2001 From: joesharratt1229 <118444587+joesharratt1229@users.noreply.github.com> Date: Wed, 5 Mar 2025 22:30:12 +0100 Subject: [PATCH] Fixed `countdown` `score_answer` (#265) * fixed countdown score ans * checked solution uses all numbers --- reasoning_gym/games/countdown.py | 39 +++++++++++++++++--------------- tests/test_countdown.py | 21 ++++++++++++++++- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py index 751c67d7..270476be 100644 --- a/reasoning_gym/games/countdown.py +++ b/reasoning_gym/games/countdown.py @@ -1,3 +1,4 @@ +import re from dataclasses import dataclass from random import Random from typing import Any, Optional @@ -51,9 +52,9 @@ class CountdownDataset(ProceduralDataset): def __init__(self, config: CountdownConfig): self._prompt_templates = [ - "Using the numbers {numbers}, create an expression that equals {target}.\nYou can only use each number once.", - "Find a way to make {target} using some or all of these numbers: {numbers}.\nEach number can only be used once.", - "Calculate {target} using the numbers {numbers}.\nEach number may be used at most once.", + "Using all the numbers {numbers}, create an expression that equals {target}.\nYou can only use each number once.", + "Find a way to make {target} using all of these numbers: {numbers}.\nEach number can only be used once.", + "Calculate {target} using all of these numbers: {numbers}.\nEach number may be used at most once.", ] super().__init__(config=config, seed=config.seed, size=config.size) @@ -174,21 +175,23 @@ def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Determine if the solution provided solves the problem""" - reward = 0.0 - metadata = entry["metadata"] - if answer is not None: - try: - user_answer = int(parse_expr(answer)) - solved = user_answer == metadata["target"] - if solved: - reward = 1.0 - elif len(answer.strip()) > 0: # encourage partial solutions - reward = 0.05 - else: - reward = 0.01 - except: - reward = 0.01 - return reward + reward = 0.01 # Default reward + + if answer is None or not answer.strip(): + return reward + + try: + answer = answer.strip() + user_answer = int(parse_expr(answer)) + used_numbers = [int(num) for num in re.findall(r"\b\d+\b", answer)] + target_numbers = set(entry["metadata"]["numbers"]) + + if (user_answer == entry["metadata"]["target"]) and (set(used_numbers) == target_numbers): + return 1.0 + + return 0.05 if answer else 0.01 + except Exception: + return 0.01 # Register the dataset diff --git a/tests/test_countdown.py b/tests/test_countdown.py index 2bd20f4f..6bfc0cbf 100644 --- a/tests/test_countdown.py +++ b/tests/test_countdown.py @@ -72,7 +72,7 @@ def test_countdown_game_items(): dataset.score_answer(answer="a wrong solution", entry=item) == 0.01 ) # wrong answer but incorrectly formatted assert dataset.score_answer(answer="", entry=item) == 0.01 # wrong answer but empty string - assert dataset.score_answer(answer=None, entry=item) == 0.0 # no answer + assert dataset.score_answer(answer=None, entry=item) == 0.01 # no answer try: result = eval(expr) # Safe here since we control expression generation @@ -81,6 +81,25 @@ def test_countdown_game_items(): pytest.fail(f"Invalid expression generated: {expr}") +def test_answer_with_incorrect_numbers(): + dataset = CountdownDataset(CountdownConfig(size=10, seed=42)) + answer = "45+2" + item = { + "metadata": { + "numbers": [44, 3], + "target": 47, + } + } + assert dataset.score_answer(answer=answer, entry=item) == 0.05 + + +def test_answer_without_all_numbers(): + dataset = CountdownDataset(CountdownConfig(size=10, seed=42)) + answer = "45+2+3" + item = {"metadata": {"numbers": [1, 45, 2, 3], "target": 50}} + assert dataset.score_answer(answer=answer, entry=item) == 0.05 + + def test_countdown_game_randomization(): """Test number randomization configuration""" config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing