Skip to content

Commit

Permalink
Fixed countdown score_answer (#265)
Browse files Browse the repository at this point in the history
* Fixed Countdown `score_answer` scoring
* Check that the solution uses all of the given numbers
  • Loading branch information
joesharratt1229 authored Mar 5, 2025
1 parent d0a4211 commit e30be06
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
39 changes: 21 additions & 18 deletions reasoning_gym/games/countdown.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from dataclasses import dataclass
from random import Random
from typing import Any, Optional
Expand Down Expand Up @@ -51,9 +52,9 @@ class CountdownDataset(ProceduralDataset):


def __init__(self, config: CountdownConfig):
# Prompt templates; each one carries a {numbers} and a {target} placeholder.
# NOTE(review): this view lists both the earlier wording ("some or all",
# "the numbers") and the stricter "all ..." wording, which looks like removed
# and added diff lines shown together — confirm which three the file keeps.
self._prompt_templates = [
"Using the numbers {numbers}, create an expression that equals {target}.\nYou can only use each number once.",
"Find a way to make {target} using some or all of these numbers: {numbers}.\nEach number can only be used once.",
"Calculate {target} using the numbers {numbers}.\nEach number may be used at most once.",
"Using all the numbers {numbers}, create an expression that equals {target}.\nYou can only use each number once.",
"Find a way to make {target} using all of these numbers: {numbers}.\nEach number can only be used once.",
"Calculate {target} using all of these numbers: {numbers}.\nEach number may be used at most once.",
]
# Pass the config, together with its seed and size, to the parent initializer.
super().__init__(config=config, seed=config.seed, size=config.size)

Expand Down Expand Up @@ -174,21 +175,23 @@ def _generate_expression(self, rng: Random) -> tuple[str, list[int], int]:

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score a proposed Countdown expression against a dataset entry.

    Args:
        answer: The arithmetic expression proposed as a solution, or None.
        entry: Dataset item; ``entry["metadata"]["target"]`` and
            ``entry["metadata"]["numbers"]`` are read.

    Returns:
        1.0 when the expression evaluates exactly to the target and uses
        each of the given numbers exactly once; 0.05 when it evaluates to a
        number but misses the target or uses a different multiset of
        numbers; 0.01 when the answer is None, blank, or cannot be evaluated.
    """
    if answer is None or not answer.strip():
        return 0.01

    try:
        expression = answer.strip()
        metadata = entry["metadata"]

        # float() instead of int(): int() truncates, so a non-integer result
        # such as 95/2 would otherwise be accepted for a target of 47.
        value = float(parse_expr(expression))

        # Compare as sorted lists (a multiset comparison), not as sets, so a
        # number used twice or a duplicated number left out is rejected.
        used_numbers = sorted(int(token) for token in re.findall(r"\b\d+\b", expression))
        uses_each_number_once = used_numbers == sorted(metadata["numbers"])

        if value == metadata["target"] and uses_each_number_once:
            return 1.0

        # The expression evaluated but is not a valid solution: partial credit.
        return 0.05
    except Exception:
        # Free-form answers can fail in many ways while being parsed or
        # evaluated; all of them are treated as an unparseable answer.
        return 0.01


# Register the dataset
Expand Down
21 changes: 20 additions & 1 deletion tests/test_countdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_countdown_game_items():
dataset.score_answer(answer="a wrong solution", entry=item) == 0.01
) # wrong answer but incorrectly formatted
assert dataset.score_answer(answer="", entry=item) == 0.01 # wrong answer but empty string
assert dataset.score_answer(answer=None, entry=item) == 0.0 # no answer
assert dataset.score_answer(answer=None, entry=item) == 0.01 # no answer

try:
result = eval(expr) # Safe here since we control expression generation
Expand All @@ -81,6 +81,25 @@ def test_countdown_game_items():
pytest.fail(f"Invalid expression generated: {expr}")


def test_answer_with_incorrect_numbers():
    """Hitting the target with numbers outside the given set earns only 0.05."""
    # 45 + 2 equals the target 47, but neither operand is one of [44, 3].
    entry = {"metadata": {"numbers": [44, 3], "target": 47}}
    countdown = CountdownDataset(CountdownConfig(size=10, seed=42))
    score = countdown.score_answer(answer="45+2", entry=entry)
    assert score == 0.05


def test_answer_without_all_numbers():
    """Hitting the target while leaving out one given number earns only 0.05."""
    # 45 + 2 + 3 equals the target 50, but the given number 1 is never used.
    entry = {
        "metadata": {
            "numbers": [1, 45, 2, 3],
            "target": 50,
        }
    }
    countdown = CountdownDataset(CountdownConfig(size=10, seed=42))
    score = countdown.score_answer(answer="45+2+3", entry=entry)
    assert score == 0.05


def test_countdown_game_randomization():
"""Test number randomization configuration"""
config = CountdownConfig(min_numbers=4, max_numbers=4, shuffle=False, size=10, seed=42) # Fixed size for testing
Expand Down

0 comments on commit e30be06

Please sign in to comment.