From 131e0d8f19ca9d1329a924240c8131777ac0e717 Mon Sep 17 00:00:00 2001
From: EC2 Default User
Date: Fri, 31 Jan 2025 06:42:25 +0000
Subject: [PATCH 1/3] added countdown score answer impl

---
 reasoning_gym/games/countdown.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py
index 4721844d..87d1a793 100644
--- a/reasoning_gym/games/countdown.py
+++ b/reasoning_gym/games/countdown.py
@@ -1,9 +1,10 @@
 from dataclasses import dataclass
 from random import Random
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict, Any
 
 import sympy
 from sympy import Symbol, symbols
+from sympy.parsing.sympy_parser import parse_expr
 
 from ..factory import ProceduralDataset, register_dataset
 
@@ -157,6 +158,23 @@ def _generate_expression(self, rng: Random) -> Tuple[str, List[int], int]:
                 continue
 
         raise ValueError(f"Failed to generate valid expression after {max_attempts} attempts")
+
+    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
+        """Determine if the solution provided solves the problem"""
+        reward = 0.0
+        if answer is not None:
+            try:
+                user_answer = int(parse_expr(answer))
+                solved = user_answer == metadata["target"]
+                if solved:
+                    reward = 1.0
+                elif (len(answer.strip()) > 0):  # encourage partial solutions
+                    reward = 0.05
+                else:
+                    reward = 0.01
+            except:
+                reward = 0.01
+        return reward
 
 
 # Register the dataset

From 4fea3c3378f90e2298c56a3b85ae8c8b87dec626 Mon Sep 17 00:00:00 2001
From: joesharratt1229
Date: Fri, 31 Jan 2025 06:46:18 +0000
Subject: [PATCH 2/3] added testing of score answer method

---
 tests/test_countdown.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tests/test_countdown.py b/tests/test_countdown.py
index e426caf2..04bf1a8b 100644
--- a/tests/test_countdown.py
+++ b/tests/test_countdown.py
@@ -64,6 +64,14 @@ def test_countdown_game_items():
 
         # Verify expression evaluates correctly
         expr = item["metadata"]["expression"]
+
+        #check score
+        assert dataset.score_answer(answer=expr, metadata=item["metadata"]) == 1.0 #correct answer
+        assert dataset.score_answer(answer="45+2", metadata=item["metadata"]) == 0.05 #wrong answer but an attempt
+        assert dataset.score_answer(answer="a wrong solution", metadata=item["metadata"]) == 0.01 #wrong answer but incorrectly formatted
+        assert dataset.score_answer(answer="", metadata=item["metadata"]) == 0.01 #wrong answer but empty string
+        assert dataset.score_answer(answer=None, metadata=item["metadata"]) == 0.0 #no answer
+
         try:
             result = eval(expr)  # Safe here since we control expression generation
             assert result == item["metadata"]["target"]

From 37375f08a95a41cd9e804b185cb1522a66041b37 Mon Sep 17 00:00:00 2001
From: joesharratt1229
Date: Fri, 31 Jan 2025 07:19:55 +0000
Subject: [PATCH 3/3] added linting checks

---
 reasoning_gym/games/countdown.py |  6 +++---
 tests/test_countdown.py          | 18 ++++++++++--------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/reasoning_gym/games/countdown.py b/reasoning_gym/games/countdown.py
index 87d1a793..38a60c4f 100644
--- a/reasoning_gym/games/countdown.py
+++ b/reasoning_gym/games/countdown.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from random import Random
-from typing import List, Optional, Tuple, Dict, Any
+from typing import Any, Dict, List, Optional, Tuple
 
 import sympy
 from sympy import Symbol, symbols
@@ -158,7 +158,7 @@ def _generate_expression(self, rng: Random) -> Tuple[str, List[int], int]:
                 continue
 
         raise ValueError(f"Failed to generate valid expression after {max_attempts} attempts")
-
+
     def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
         """Determine if the solution provided solves the problem"""
         reward = 0.0
@@ -168,7 +168,7 @@ def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
                 solved = user_answer == metadata["target"]
                 if solved:
                     reward = 1.0
-                elif (len(answer.strip()) > 0):  # encourage partial solutions
+                elif len(answer.strip()) > 0:  # encourage partial solutions
                     reward = 0.05
                 else:
                     reward = 0.01
diff --git a/tests/test_countdown.py b/tests/test_countdown.py
index 04bf1a8b..e78a69ab 100644
--- a/tests/test_countdown.py
+++ b/tests/test_countdown.py
@@ -64,14 +64,16 @@ def test_countdown_game_items():
 
         # Verify expression evaluates correctly
         expr = item["metadata"]["expression"]
-
-        #check score
-        assert dataset.score_answer(answer=expr, metadata=item["metadata"]) == 1.0 #correct answer
-        assert dataset.score_answer(answer="45+2", metadata=item["metadata"]) == 0.05 #wrong answer but an attempt
-        assert dataset.score_answer(answer="a wrong solution", metadata=item["metadata"]) == 0.01 #wrong answer but incorrectly formatted
-        assert dataset.score_answer(answer="", metadata=item["metadata"]) == 0.01 #wrong answer but empty string
-        assert dataset.score_answer(answer=None, metadata=item["metadata"]) == 0.0 #no answer
-
+
+        # check score
+        assert dataset.score_answer(answer=expr, metadata=item["metadata"]) == 1.0  # correct answer
+        assert dataset.score_answer(answer="45+2", metadata=item["metadata"]) == 0.05  # wrong answer but an attempt
+        assert (
+            dataset.score_answer(answer="a wrong solution", metadata=item["metadata"]) == 0.01
+        )  # wrong answer but incorrectly formatted
+        assert dataset.score_answer(answer="", metadata=item["metadata"]) == 0.01  # wrong answer but empty string
+        assert dataset.score_answer(answer=None, metadata=item["metadata"]) == 0.0  # no answer
+
         try:
             result = eval(expr)  # Safe here since we control expression generation
             assert result == item["metadata"]["target"]
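
A usage sketch, not part of the patch series: the snippet below exercises the reward
tiers that score_answer implements. It assumes the patches are applied, that the
dataset is registered under the name "countdown", and that reasoning_gym exposes the
create_dataset factory used elsewhere in the repository; treat those names as
illustrative rather than confirmed.

import reasoning_gym

# Build one deterministic item; the seed keeps the generated expression reproducible.
dataset = reasoning_gym.create_dataset("countdown", size=1, seed=42)
item = dataset[0]
metadata = item["metadata"]

# The reference expression evaluates to the target, so it earns the full reward.
assert dataset.score_answer(answer=metadata["expression"], metadata=metadata) == 1.0

# A parseable but wrong expression gets partial credit (0.05), an unparseable or
# empty string gets 0.01, and a missing answer gets 0.0 (values mirror the tests).
assert dataset.score_answer(answer="45+2", metadata=metadata) == 0.05
assert dataset.score_answer(answer="not an expression", metadata=metadata) == 0.01
assert dataset.score_answer(answer="", metadata=metadata) == 0.01
assert dataset.score_answer(answer=None, metadata=metadata) == 0.0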