From 5d7fbac0ad8f926894d4397bddaaf1c8d412cfc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Tue, 4 Mar 2025 21:55:09 +0100 Subject: [PATCH] Minor question template & score_answer improvements (#261) * math prompt improvements * ignore brackets in complex_arithmetic results * improve additional instruction in prompt of polynomial_equations * more strict tests for score_answer in polynomial_equations * simplify special reward handling * fix test_intermediate_integration * fix sokoban dataset * add common dataset score_answer consistency test --- reasoning_gym/algebra/complex_arithmetic.py | 4 ++ .../algebra/intermediate_integration.py | 15 +++--- reasoning_gym/algebra/polynomial_equations.py | 22 +++++---- .../algebra/polynomial_multiplication.py | 18 +++---- reasoning_gym/algebra/simple_integration.py | 12 ++--- reasoning_gym/algorithmic/ab.py | 7 +-- reasoning_gym/algorithmic/binary_matrix.py | 6 +-- reasoning_gym/algorithmic/cryptarithm.py | 2 +- reasoning_gym/algorithmic/game_of_life.py | 2 +- reasoning_gym/algorithmic/graph_color.py | 9 ++-- reasoning_gym/algorithmic/group_anagrams.py | 2 +- reasoning_gym/algorithmic/jugs.py | 4 +- reasoning_gym/algorithmic/letter_jumble.py | 18 ++++--- .../algorithmic/manipulate_matrix.py | 2 - .../algorithmic/palindrome_generation.py | 4 +- .../algorithmic/palindrome_partitioning.py | 3 +- reasoning_gym/algorithmic/pool_matrix.py | 4 +- reasoning_gym/algorithmic/ransom_note.py | 12 ++--- .../algorithmic/sentence_reordering.py | 2 +- reasoning_gym/algorithmic/spell_backward.py | 4 +- reasoning_gym/algorithmic/spiral_matrix.py | 8 ++-- reasoning_gym/algorithmic/string_insertion.py | 8 ++-- reasoning_gym/algorithmic/word_ladder.py | 12 ++--- reasoning_gym/algorithmic/word_sorting.py | 2 - reasoning_gym/arc/arc_agi.py | 2 +- reasoning_gym/arc/rearc.py | 2 +- .../arithmetic/bitwise_arithmetic.py | 20 ++++---- .../arithmetic/calendar_arithmetic.py | 5 +- .../arithmetic/decimal_arithmetic.py | 10 ++-- reasoning_gym/arithmetic/decimal_chain_sum.py | 19 +++++--- reasoning_gym/arithmetic/dice.py | 11 ++--- reasoning_gym/arithmetic/number_format.py | 5 +- reasoning_gym/arithmetic/power_function.py | 6 +-- reasoning_gym/arithmetic/time_intervals.py | 2 +- reasoning_gym/code/bf.py | 25 +++++----- reasoning_gym/cognition/figlet_fonts.py | 2 +- reasoning_gym/cognition/needle_haystack.py | 18 ++++--- reasoning_gym/cognition/rectangle_count.py | 10 ++-- reasoning_gym/dataset.py | 3 -- .../games/contrib/sokoban/levels/lvl0.dat | 10 ---- .../games/contrib/sokoban/levels/lvl1.dat | 5 -- .../games/contrib/sokoban/levels/lvl2.dat | 6 --- .../games/contrib/sokoban/levels/lvl3.dat | 7 --- .../games/contrib/sokoban/levels/lvl4.dat | 7 --- .../games/contrib/sokoban/levels/lvl5.dat | 7 --- .../games/contrib/sokoban/levels/lvl6.dat | 9 ---- .../games/contrib/sokoban/levels/lvl7.dat | 6 --- .../games/contrib/sokoban/src/astar.py | 11 +++-- .../games/contrib/sokoban/src/game.py | 13 +++-- .../games/contrib/sokoban/src/generator.py | 38 +++++++++------ reasoning_gym/games/futoshiki.py | 2 +- reasoning_gym/games/knight_swap.py | 27 +++++------ reasoning_gym/games/mini_sudoku.py | 2 +- reasoning_gym/games/n_queens.py | 6 +-- reasoning_gym/games/rush_hour.py | 2 +- reasoning_gym/games/sokoban.py | 40 +++++++++++----- reasoning_gym/games/sudoku.py | 2 +- reasoning_gym/games/tower_of_hanoi.py | 24 +++------- reasoning_gym/graphs/family_relationships.py | 10 ++-- reasoning_gym/graphs/quantum_lock.py | 16 ++----- reasoning_gym/graphs/shortest_path.py | 6 +-- reasoning_gym/logic/circuit_logic.py | 18 ++++--- reasoning_gym/logic/knights_knaves.py | 8 ++-- reasoning_gym/logic/propositional_logic.py | 6 +-- reasoning_gym/logic/self_reference.py | 14 ++---- reasoning_gym/logic/zebra_puzzles.py | 10 ++-- reasoning_gym/utils.py | 1 - tests/test_ab.py | 2 +- tests/test_arc_1d.py | 2 +- tests/test_arc_agi.py | 2 +- tests/test_bf.py | 2 +- tests/test_binary_matrix.py | 2 +- tests/test_bitwise_arithmetic.py | 2 +- tests/test_complex_arithmetic.py | 1 + tests/test_dataset.py | 2 +- tests/test_dataset_common.py | 17 +++++++ tests/test_decimal_chain_sum.py | 8 ++-- tests/test_game_of_life.py | 2 +- tests/test_intermediate_integration.py | 14 +++--- tests/test_knight_swap.py | 4 +- tests/test_knights_knaves.py | 3 +- tests/test_manipulate_matrix.py | 2 +- tests/test_n_queens.py | 2 +- tests/test_needle_haystack.py | 2 +- tests/test_number_format.py | 2 +- tests/test_palindrome.py | 4 +- tests/test_palindrome_partitioning.py | 12 ++--- tests/test_polynomial_equations.py | 10 ++-- tests/test_polynomial_multiplication.py | 48 ++----------------- tests/test_pool_matrix.py | 2 +- tests/test_power_function.py | 2 +- tests/test_products.py | 2 +- tests/test_propositional_logic.py | 2 +- tests/test_quantum_lock.py | 9 ++-- tests/test_ransom_note.py | 2 +- tests/test_rearc.py | 2 +- tests/test_self_reference.py | 12 ++--- tests/test_shortest_path.py | 2 +- tests/test_simple_integration.py | 14 +++--- tests/test_sokoban.py | 18 +++++-- tests/test_spiral_matrix.py | 4 +- tests/test_string_insertion.py | 2 +- tests/test_tower_of_hanoi.py | 21 ++++---- tests/test_utils.py | 2 +- tests/test_word_ladder.py | 12 ++--- tests/test_word_sorting.py | 4 +- 106 files changed, 394 insertions(+), 498 deletions(-) delete mode 100644 reasoning_gym/games/contrib/sokoban/levels/lvl0.dat delete mode 100644 reasoning_gym/games/contrib/sokoban/levels/lvl1.dat delete mode 100644 reasoning_gym/games/contrib/sokoban/levels/lvl2.dat delete mode 100644 reasoning_gym/games/contrib/sokoban/levels/lvl3.dat delete mode 100644 reasoning_gym/games/contrib/sokoban/levels/lvl4.dat delete mode 100644 reasoning_gym/games/contrib/sokoban/levels/lvl5.dat delete mode 100644 reasoning_gym/games/contrib/sokoban/levels/lvl6.dat delete mode 100644 reasoning_gym/games/contrib/sokoban/levels/lvl7.dat create mode 100644 tests/test_dataset_common.py diff --git a/reasoning_gym/algebra/complex_arithmetic.py b/reasoning_gym/algebra/complex_arithmetic.py index 30ffe972..903cdb99 100644 --- a/reasoning_gym/algebra/complex_arithmetic.py +++ b/reasoning_gym/algebra/complex_arithmetic.py @@ -103,6 +103,10 @@ def parse_string_to_complex(answer: str) -> complex: # Normalize the answer string by removing spaces and converting to lowercase answer = answer.replace(" ", "").lower() + # remove brackets + while len(answer) > 1 and answer[0] == "(" and answer[-1] == ")": + answer = answer[1:-1] + # Convert mathematical notation 'i' to Python's 'j' for complex numbers answer = answer.replace("i", "j") diff --git a/reasoning_gym/algebra/intermediate_integration.py b/reasoning_gym/algebra/intermediate_integration.py index c967c7be..de664450 100644 --- a/reasoning_gym/algebra/intermediate_integration.py +++ b/reasoning_gym/algebra/intermediate_integration.py @@ -77,9 +77,10 @@ def __init__(self, config: IntermediateIntegrationConfig): "Evaluate the indefinite integral: ∫ {integrand} dx", ] self.added_instruction = """ -In addition, when doing calculation, use the following instructions together with your mathematical ingenuity to solve the integral problems -## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. -## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. +When performing calculations, please follow these guidelines: +1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2. +2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`. +3. Use `exp(x)` or `E**(x)` for the exponential function (i.e. use capital E for Euler's number). """ def _get_outer_constant(self, rng: random.Random) -> int: @@ -245,7 +246,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Determine if the solution provided solves the problem""" reward = 0.0 metadata = entry["metadata"] - if answer is not None: + if isinstance(answer, str): try: var = metadata["variable"] x = sympy.Symbol(var) @@ -258,12 +259,8 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: # Check mathematical equivalence through simplification if sympy.simplify(derivative - integrand) == 0: reward = 1.0 - elif answer.strip(): - reward = 0.05 - else: - reward = 0.01 except: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/algebra/polynomial_equations.py b/reasoning_gym/algebra/polynomial_equations.py index c069c32f..ac054427 100644 --- a/reasoning_gym/algebra/polynomial_equations.py +++ b/reasoning_gym/algebra/polynomial_equations.py @@ -27,8 +27,9 @@ class PolynomialEquationsConfig: seed: Optional[int] = None size: int = 500 # reward function hyperparameters - penalty_missing_factor = 0.1 - penalty_extra_factor = 0.05 + penalty_missing_factor = 0.5 + penalty_extra_factor = 0.5 + exp_distance_factor = -10.0 def validate(self) -> None: """Validate configuration parameters.""" @@ -62,12 +63,15 @@ def __init__(self, config: PolynomialEquationsConfig): "Solve the polynomial equation for real {variable}:\n{polynomial_expanded} = 0", ] self.added_instruction = """ -In solving the equations, please abide by the following instruction: -## 1. All answers should be comma-separated. For example "-0.3773, 0.4005" etc. -## 2. In cases where your answer is b = 2 + sqrt(4560) / 172 and b = 2 - sqrt(4560) / 172. Since b can be 2 numbers, resolve your answer like this instead, "-0.3773, 0.4005". -## 3. If there are no real values of i that satisfy the equation, report your answer as empty string, "". -## 4. If there are 2 answers, resolve the answers as comma-separated floats of 2 numbers, if 3 answers, make it comma-separated floats of 3 numbers. -## 5. Resolve all numbers as floats in the string of comma-separated numbers. Round the floats higher than 4 decimal place(d.p) down to 4 d.p. +In solving equations, please follow these instructions: +1. Provide all answers as comma-separated decimal values. For example: "-0.3773, 0.4005" +2. For solutions that can be expressed in exact form (like "u = 2 + sqrt(4560)/172" and "u = 2 - sqrt(4560)/172"), convert them to decimal form in your final answer. +3. If there are no real values that satisfy the equation, report your answer as an empty string: "" +4. Format your answer based on the number of solutions: + - For 1 solution: a single decimal number + - For 2 solutions: two comma-separated decimal numbers + - For 3 or more solutions: all values as comma-separated decimal numbers +5. Round all decimal values to 4 decimal places (rounding down when the 5th decimal place is 5 or greater). """ super().__init__(config=config, seed=config.seed, size=config.size) @@ -238,7 +242,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: # Remove matched oracle solution oracle_solutions.pop(matched_distance_index) # Exponential decay reward - total_reward += math.exp(-matched_distance) + total_reward += math.exp(matched_distance * self.config.exp_distance_factor) else: # Extra predicted solution extra_solutions += 1 diff --git a/reasoning_gym/algebra/polynomial_multiplication.py b/reasoning_gym/algebra/polynomial_multiplication.py index 09e2530a..9a0d51dd 100644 --- a/reasoning_gym/algebra/polynomial_multiplication.py +++ b/reasoning_gym/algebra/polynomial_multiplication.py @@ -69,9 +69,9 @@ def __init__(self, config: PolynomialMultiplicationConfig): "Calculate the following: {polynomial_expr}", ] self.added_instruction = """ -In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems -## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. -## 2. Always use * when doing all sorts of multiplcation in your reasoning steps and even in reporting answers. +When performing calculations, please follow these guidelines: +1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2. +2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`. """ super().__init__(config=config, seed=config.seed, size=config.size) @@ -106,10 +106,9 @@ def __getitem__(self, idx: int) -> dict: return { "question": question, - "answer": product, + "answer": str(product), "metadata": { "polynomial_expr": str(polynomial_expr), - "result": str(product), "variables": list(product.free_symbols), }, } @@ -147,21 +146,16 @@ def _generate_polynomial(self, rng: random.Random, monomials: Optional[list]): def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: reward = 0.0 - metadata = entry["metadata"] if answer is not None: try: predicted_poly = sp.parse_expr(answer) - target_poly = sp.parse_expr(metadata["result"]) + target_poly = sp.parse_expr(entry["answer"]) # Check if the difference simplifies to zero (i.e. they are equivalent). if predicted_poly == target_poly: reward = 1.0 - elif answer.strip(): - reward = 0.05 - else: - reward = 0.01 except Exception: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/algebra/simple_integration.py b/reasoning_gym/algebra/simple_integration.py index a45dc3ff..321a7a9d 100644 --- a/reasoning_gym/algebra/simple_integration.py +++ b/reasoning_gym/algebra/simple_integration.py @@ -42,9 +42,9 @@ def __init__(self, config: SimpleIntegrationConfig): "Evaluate the indefinite integral: ∫ {integrand} dx", ] self.added_instruction = """ -In addition, When doing calculation, Use the following instructions together with your mathematical ingenuity to solve the integral problems -## 1. Use ** instead ^ to represent powers. For example 7*X**2 instead of 7*X^2. -## 2. Always use * when doing all sorts of multiplcation in your reasoning steps. For example Use [-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C] instead of [-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C]. +When performing calculations, please follow these guidelines: +1. Use ** instead of ^ to represent exponents. For example, write 7*X**2 instead of 7*X^2. +2. Always include the * symbol for all multiplication operations in your reasoning steps. For example, write `-3*X**3*sin(X) - 9*X**2*cos(X) + 18*X*sin(X) + 18*cos(X) + C` instead of `-3x3sin(x) - 9x2cos(x) + 18xsin(x) + 18cos(x) + C`. """ super().__init__(config=config, seed=config.seed, size=config.size) @@ -103,12 +103,8 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: # Check mathematical equivalence through simplification if sympy.simplify(derivative - integrand) == 0: reward = 1.0 - elif answer.strip(): - reward = 0.05 - else: - reward = 0.01 except: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/algorithmic/ab.py b/reasoning_gym/algorithmic/ab.py index 9ec679af..3e251d31 100644 --- a/reasoning_gym/algorithmic/ab.py +++ b/reasoning_gym/algorithmic/ab.py @@ -130,12 +130,9 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if answer == None: - return 0.0 - if answer != entry["answer"]: - return 0.01 - else: + if answer == entry["answer"]: return 1.0 # Yay + return 0.0 # Register the dataset diff --git a/reasoning_gym/algorithmic/binary_matrix.py b/reasoning_gym/algorithmic/binary_matrix.py index 4584a7fb..772cde97 100644 --- a/reasoning_gym/algorithmic/binary_matrix.py +++ b/reasoning_gym/algorithmic/binary_matrix.py @@ -108,9 +108,9 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: # check if answer is python list of lists answer = self._matrix_to_str(eval(answer)) if answer == oracle_answer: - return 0.5 - except Exception as e: - return 0.01 + return 0.1 + except Exception: + return 0.0 return 0.0 def __getitem__(self, idx: int) -> dict: diff --git a/reasoning_gym/algorithmic/cryptarithm.py b/reasoning_gym/algorithmic/cryptarithm.py index f0946278..fe62b743 100644 --- a/reasoning_gym/algorithmic/cryptarithm.py +++ b/reasoning_gym/algorithmic/cryptarithm.py @@ -200,7 +200,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: Returns: float: The computed score between 0.0 and 1.0. """ - if not answer: + if not isinstance(answer, str): return 0.0 correct_mapping = {} diff --git a/reasoning_gym/algorithmic/game_of_life.py b/reasoning_gym/algorithmic/game_of_life.py index d75f57c0..83c391e7 100644 --- a/reasoning_gym/algorithmic/game_of_life.py +++ b/reasoning_gym/algorithmic/game_of_life.py @@ -106,7 +106,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: ans_arr = json.loads(answer) correct_arr = json.loads(entry["answer"]) except Exception: - return 0.01 + return 0.0 total_cells = 0 correct_cells = 0 diff --git a/reasoning_gym/algorithmic/graph_color.py b/reasoning_gym/algorithmic/graph_color.py index 0e469bc0..fa34b63c 100644 --- a/reasoning_gym/algorithmic/graph_color.py +++ b/reasoning_gym/algorithmic/graph_color.py @@ -228,12 +228,13 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: try: danswer = json.loads(answer) solved, failure = verify_graph_coloring_solution(entry["metadata"]["puzzle"], danswer) - if not solved: - return 0.01 # json was parsable but solution incorrect - else: + if solved: return 1.0 # Yay + else: + return 0.01 # json parsable except Exception: - return 0.0 + pass + return 0.0 register_dataset("graph_color", GraphColorDataset, GraphColorConfig) diff --git a/reasoning_gym/algorithmic/group_anagrams.py b/reasoning_gym/algorithmic/group_anagrams.py index b6630ac0..6c17eae5 100644 --- a/reasoning_gym/algorithmic/group_anagrams.py +++ b/reasoning_gym/algorithmic/group_anagrams.py @@ -95,7 +95,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: if answer_str == oracle_str: reward = 1.0 else: - reward = 0.01 + reward = 0.01 # json parsable except Exception: reward = 0.0 return reward diff --git a/reasoning_gym/algorithmic/jugs.py b/reasoning_gym/algorithmic/jugs.py index 134ac068..0dc7210c 100644 --- a/reasoning_gym/algorithmic/jugs.py +++ b/reasoning_gym/algorithmic/jugs.py @@ -303,11 +303,11 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: danswer = json.loads(answer) valid, _ = verify_solution(entry["metadata"]["puzzle"], danswer) if not valid: - return 0.01 + return 0.01 # json parsable else: return 1.0 # Yay except Exception as e: - return 0.01 + return 0.0 register_dataset("jugs", JugsDataset, JugsConfig) diff --git a/reasoning_gym/algorithmic/letter_jumble.py b/reasoning_gym/algorithmic/letter_jumble.py index 5917e55c..2cb3fc08 100644 --- a/reasoning_gym/algorithmic/letter_jumble.py +++ b/reasoning_gym/algorithmic/letter_jumble.py @@ -116,7 +116,7 @@ def partial(self, expected_answer, model_answer): # Each word in the expected answer is worth an equal fraction of 1.0 total_words = len(expected_words) - score_per_word = 1.0 / total_words if total_words else 0 + score_per_word = 1.0 / total_words if total_words > 0 else 0 # Calculate scores word by word scores = [] @@ -142,18 +142,16 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if not answer: + if not isinstance(answer, str): return 0.0 oracle_answer = entry["answer"].strip().lower() - if answer: - answer = answer.strip().lower() - if answer == oracle_answer: - return 1.0 # Perfect score! - else: - partial_score = self.partial(oracle_answer, answer) - return partial_score - return 0.01 + answer = answer.strip().lower() + if answer == oracle_answer: + return 1.0 # Perfect score! + else: + partial_score = self.partial(oracle_answer, answer) + return partial_score register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig) diff --git a/reasoning_gym/algorithmic/manipulate_matrix.py b/reasoning_gym/algorithmic/manipulate_matrix.py index ab0bf592..ec964094 100644 --- a/reasoning_gym/algorithmic/manipulate_matrix.py +++ b/reasoning_gym/algorithmic/manipulate_matrix.py @@ -144,8 +144,6 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: if oracle_answer in answer: return len(oracle_answer) / len(answer) - else: - return 0.01 return 0.0 diff --git a/reasoning_gym/algorithmic/palindrome_generation.py b/reasoning_gym/algorithmic/palindrome_generation.py index 2f7b5fb5..05e566ed 100644 --- a/reasoning_gym/algorithmic/palindrome_generation.py +++ b/reasoning_gym/algorithmic/palindrome_generation.py @@ -92,14 +92,14 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0 - An answer that is a palindrome, but not with the same letters as provided, gives 0.05 - An answer that is a string, but not a palindrome gives 0.02 - - An empty string gives 0.01. + - An empty string gives 0.0 - None gives 0.0. """ if answer is None or not isinstance(answer, str): return 0.0 # No answer given if answer == "": - return 0.01 + return 0.0 metadata = entry["metadata"] answer = answer.strip().lower() diff --git a/reasoning_gym/algorithmic/palindrome_partitioning.py b/reasoning_gym/algorithmic/palindrome_partitioning.py index 8a0c07b3..067a19e7 100644 --- a/reasoning_gym/algorithmic/palindrome_partitioning.py +++ b/reasoning_gym/algorithmic/palindrome_partitioning.py @@ -95,9 +95,8 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: oracle = self.to_set_of_tuples(entry["metadata"]["solution"]) if answer == oracle: return 1.0 - return 0.01 except Exception: - return 0.0 + pass return 0.0 def _generate_palindrome_letters(self, rng: Random, length: int) -> list[str]: diff --git a/reasoning_gym/algorithmic/pool_matrix.py b/reasoning_gym/algorithmic/pool_matrix.py index bf839ad4..e22f9a92 100644 --- a/reasoning_gym/algorithmic/pool_matrix.py +++ b/reasoning_gym/algorithmic/pool_matrix.py @@ -80,7 +80,7 @@ def _average_pool(self, matrix: np.ndarray, pool_size: int) -> np.ndarray: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Score the answer based on the metadata""" - if not answer: + if not isinstance(answer, str): return 0.0 reward = 0.0 @@ -91,8 +91,6 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: reward = 1.0 elif oracle_answer.shape == answer.shape: reward = 0.1 - else: - reward = 0.01 except Exception: pass return reward diff --git a/reasoning_gym/algorithmic/ransom_note.py b/reasoning_gym/algorithmic/ransom_note.py index 2a1826a0..cf163467 100644 --- a/reasoning_gym/algorithmic/ransom_note.py +++ b/reasoning_gym/algorithmic/ransom_note.py @@ -108,14 +108,12 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if answer == None: - return 0.0 + if isinstance(answer, str): + s_answer = answer.strip() + if s_answer == str(entry["answer"]): + return 1.0 - s_answer = answer.strip() - if not s_answer == str(entry["answer"]): - return 0.01 - else: - return 1.0 + return 0.0 register_dataset("ransom_note", RansomNoteDataset, RansomNoteConfig) diff --git a/reasoning_gym/algorithmic/sentence_reordering.py b/reasoning_gym/algorithmic/sentence_reordering.py index 9caae762..0cfbaaaf 100644 --- a/reasoning_gym/algorithmic/sentence_reordering.py +++ b/reasoning_gym/algorithmic/sentence_reordering.py @@ -110,7 +110,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: else: reward = 0.05 except: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/algorithmic/spell_backward.py b/reasoning_gym/algorithmic/spell_backward.py index 57825d84..bf33441b 100644 --- a/reasoning_gym/algorithmic/spell_backward.py +++ b/reasoning_gym/algorithmic/spell_backward.py @@ -52,14 +52,14 @@ def __getitem__(self, idx: int) -> dict: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: reward = 0.0 expected_answer = entry["answer"] - if answer is not None: + if isinstance(answer, str): try: if expected_answer.lower() == answer.lower(): reward = 1.0 else: reward = 0.05 except: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/algorithmic/spiral_matrix.py b/reasoning_gym/algorithmic/spiral_matrix.py index 63492a37..f895c628 100644 --- a/reasoning_gym/algorithmic/spiral_matrix.py +++ b/reasoning_gym/algorithmic/spiral_matrix.py @@ -126,11 +126,9 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: try: answer = " ".join(str(item) for item in eval(answer)) if answer == oracle_answer: - return 0.5 - else: - return 0.01 - except Exception as e: - return 0.01 + return 0.1 + except Exception: + pass return 0.0 diff --git a/reasoning_gym/algorithmic/string_insertion.py b/reasoning_gym/algorithmic/string_insertion.py index 1d597364..0dafe8f4 100644 --- a/reasoning_gym/algorithmic/string_insertion.py +++ b/reasoning_gym/algorithmic/string_insertion.py @@ -75,7 +75,7 @@ def _get_answer(self, string: str) -> str: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Overwrite this method in derived classes if a single oracle answer is not available.""" oracle_answer = entry["answer"] - if answer is not None: + if isinstance(answer, str): if answer == oracle_answer: return 1.0 else: @@ -83,9 +83,9 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: # check if answer is python list of characters answer = "".join(eval(answer)) if answer == oracle_answer: - return 0.5 - except Exception as e: - return 0.01 + return 0.1 + except Exception: + pass return 0.0 def __getitem__(self, idx: int) -> dict: diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index 18670d4b..d15c7be6 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -221,8 +221,8 @@ def __getitem__(self, idx: int) -> dict: } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - if answer is None: - return 0 + if not isinstance(answer, str): + return 0.0 answer_words = tuple(s.strip() for s in answer.upper().split(",")) @@ -239,17 +239,17 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: # 4. all words are in our vocabulary if len(answer_words) < 2: - return 0 + return 0.0 if answer_words[0] != start_word or answer_words[-1] != end_word: - return 0.01 + return 0.0 if not all(len(w) == word_length for w in answer_words): - return 0.01 + return 0.0 for i in range(1, len(answer_words)): if sum(1 for a, b in zip(answer_words[i - 1], answer_words[i]) if a != b) != 1: - return 0.01 + return 0.0 reward = 1.0 for word in answer_words: diff --git a/reasoning_gym/algorithmic/word_sorting.py b/reasoning_gym/algorithmic/word_sorting.py index fc65c976..d246bd5a 100644 --- a/reasoning_gym/algorithmic/word_sorting.py +++ b/reasoning_gym/algorithmic/word_sorting.py @@ -121,8 +121,6 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: return 1.0 elif sorted(parsed_answer) == oracle_answer: return 0.2 - else: - return 0.01 return 0.0 diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py index 98c3f000..a19a6507 100644 --- a/reasoning_gym/arc/arc_agi.py +++ b/reasoning_gym/arc/arc_agi.py @@ -199,7 +199,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: else: reward = 0.05 except: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/arc/rearc.py b/reasoning_gym/arc/rearc.py index fe4a9aa5..dabbe808 100644 --- a/reasoning_gym/arc/rearc.py +++ b/reasoning_gym/arc/rearc.py @@ -106,7 +106,7 @@ def score_answer(self, answer: str, entry: dict[str, Any]) -> float: else: reward = 0.05 except: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/arithmetic/bitwise_arithmetic.py b/reasoning_gym/arithmetic/bitwise_arithmetic.py index a5fcf79a..97de426b 100644 --- a/reasoning_gym/arithmetic/bitwise_arithmetic.py +++ b/reasoning_gym/arithmetic/bitwise_arithmetic.py @@ -160,17 +160,15 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: Returns: float: 1.0 if the user's answer is correct; otherwise, 0.01 unless no answer is provided, in which case 0. """ - if answer is None: - return 0.0 - - try: - solved = verify_solution(entry["metadata"]["problem"], answer) - if solved: - return 1.0 - except Exception: - return 0.01 - - return 0.01 + if isinstance(answer, str): + try: + solved = verify_solution(entry["metadata"]["problem"], answer) + if solved: + return 1.0 + except Exception: + pass + + return 0.0 # Register the dataset with the factory. diff --git a/reasoning_gym/arithmetic/calendar_arithmetic.py b/reasoning_gym/arithmetic/calendar_arithmetic.py index 9b307001..7ed0e84b 100644 --- a/reasoning_gym/arithmetic/calendar_arithmetic.py +++ b/reasoning_gym/arithmetic/calendar_arithmetic.py @@ -428,7 +428,7 @@ def _random_date_between(self, rng: random.Random, start_date: date, end_date: d def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: # we suppose the answer is the last occurence of the expected answer type - if answer is None: + if not isinstance(answer, str) or len(answer) == 0: return 0.0 oracle_answer = entry["answer"] @@ -439,9 +439,6 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: CalendarTask.WEEKDAY_OF_DATE_FROM_FIRST_DATE.value, CalendarTask.WEEKDAY_OF_DATE.value, }: - if not answer: - return 0.0 - answer = answer.strip() oracle_answer = oracle_answer weekdays = {d.name.title() for d in Weekday} diff --git a/reasoning_gym/arithmetic/decimal_arithmetic.py b/reasoning_gym/arithmetic/decimal_arithmetic.py index bf00c2af..86c4f5ff 100644 --- a/reasoning_gym/arithmetic/decimal_arithmetic.py +++ b/reasoning_gym/arithmetic/decimal_arithmetic.py @@ -178,7 +178,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: + problem_str ) - return {"question": problem_str, "answer": answer, "metadata": {}} + return {"question": problem_str, "answer": str(answer), "metadata": {}} def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """ @@ -189,12 +189,12 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: Returns: float: 1.0 if the user's answer is within tolerance; otherwise, 0.01. """ - if answer is None: + if not isinstance(answer, str): return 0.0 try: user_ans: Decimal = Decimal(answer) - correct_ans: Decimal = entry["answer"] + correct_ans: Decimal = Decimal(entry["answer"]) # Determine tolerance based on the desired precision. precision: int = self.config.max_num_decimal_places @@ -202,9 +202,9 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: if abs(user_ans - correct_ans) <= tol: return 1.0 except Exception: - return 0.01 + pass - return 0.01 + return 0.0 # Register the dataset with the factory. diff --git a/reasoning_gym/arithmetic/decimal_chain_sum.py b/reasoning_gym/arithmetic/decimal_chain_sum.py index d2313789..55f9b411 100644 --- a/reasoning_gym/arithmetic/decimal_chain_sum.py +++ b/reasoning_gym/arithmetic/decimal_chain_sum.py @@ -1,6 +1,6 @@ import random from dataclasses import dataclass -from decimal import Decimal +from decimal import Decimal, InvalidOperation from typing import Any, Optional from ..factory import ProceduralDataset, register_dataset @@ -129,7 +129,11 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max result -= c expression = " ".join(expression_parts) - result = result.quantize(Decimal(f"0.{'0' * max(decimal_places)}")) + try: + q = Decimal(f"0.{'0' * max(decimal_places)}") + result = result.quantize(q) + except InvalidOperation: + pass return expression, result def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: @@ -141,16 +145,19 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: Returns: 1.0 for exact numerical match, 0.01 otherwise """ - if answer is None or len(answer.strip()) == 0: + if not isinstance(answer, str) or len(answer.strip()) == 0: return 0.0 try: student_answer = Decimal(answer.strip()) oracle_answer = Decimal(entry["answer"]) - return 1.0 if student_answer == oracle_answer else 0.01 - except (ValueError, TypeError, ArithmeticError): - return 0.01 + if student_answer == oracle_answer: + return 1.0 + except Exception: + pass + + return 0.0 register_dataset("decimal_chain_sum", DecimalChainSumDataset, DecimalChainSumConfig) diff --git a/reasoning_gym/arithmetic/dice.py b/reasoning_gym/arithmetic/dice.py index 0dcf3e44..00cc6b7d 100644 --- a/reasoning_gym/arithmetic/dice.py +++ b/reasoning_gym/arithmetic/dice.py @@ -138,12 +138,11 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if answer == None: - return 0.0 - if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""): - return 0.01 - else: - return 1.0 # Yay + if isinstance(answer, str): + if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""): + return 1.0 # Yay + + return 0.0 register_dataset("dice", DiceDataset, DiceConfig) diff --git a/reasoning_gym/arithmetic/number_format.py b/reasoning_gym/arithmetic/number_format.py index 0c2b79d5..a85bf7cb 100644 --- a/reasoning_gym/arithmetic/number_format.py +++ b/reasoning_gym/arithmetic/number_format.py @@ -65,14 +65,13 @@ def _transform_candidates(self, rng: Random, candidates: list[float]) -> list[st def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Overwrite this method in derived classes if a single oracle answer is not available.""" oracle_answer = entry["metadata"]["solution"] - if answer is not None and len(answer) > 0: + if isinstance(answer, str) and len(answer) > 0: try: answer = float(answer.strip().replace(",", "")) if abs(answer - oracle_answer) < 1e-2: return 1.0 - return 0.01 except: - return 0.0 + pass return 0.0 def __getitem__(self, idx: int) -> dict: diff --git a/reasoning_gym/arithmetic/power_function.py b/reasoning_gym/arithmetic/power_function.py index 879f4848..5d5848c0 100644 --- a/reasoning_gym/arithmetic/power_function.py +++ b/reasoning_gym/arithmetic/power_function.py @@ -44,10 +44,8 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: return 1.0 elif difference < 1e-1: return 0.5 - else: - return 0.01 - except Exception as e: - return 0.01 + except Exception: + pass return 0.0 def __getitem__(self, idx: int) -> dict: diff --git a/reasoning_gym/arithmetic/time_intervals.py b/reasoning_gym/arithmetic/time_intervals.py index 58e30cba..eb23b232 100644 --- a/reasoning_gym/arithmetic/time_intervals.py +++ b/reasoning_gym/arithmetic/time_intervals.py @@ -246,7 +246,7 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float: Returns a score between 0 and 1, with partial credit for answers that are close to correct in the appropriate units/format """ - if not answer: + if not isinstance(answer, str): return 0.0 expected = entry["answer"] diff --git a/reasoning_gym/code/bf.py b/reasoning_gym/code/bf.py index 3a7d9b0d..fe16f3c8 100644 --- a/reasoning_gym/code/bf.py +++ b/reasoning_gym/code/bf.py @@ -121,20 +121,23 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if answer == None: + if not isinstance(answer, str): return 0.0 - if answer != entry["answer"]: - if entry["answer"] in answer.splitlines(): - # We can be quite confident that the correct answer was given - # It was likely just given alongside an explanation - return max(0.9 * len(answer) / len(entry["answer"]), 0.1) - if entry["answer"] in answer: - # Since answers are English words, some risk of the response coincidentally containing the answer - return max(0.5 * len(answer) / len(entry["answer"]), 0.1) - return 0.01 - else: + + if answer == entry["answer"]: return 1.0 # Yay + if entry["answer"] in answer.splitlines(): + # We can be quite confident that the correct answer was given + # It was likely just given alongside an explanation + return max(0.9 * len(answer) / len(entry["answer"]), 0.1) + + if entry["answer"] in answer: + # Since answers are English words, some risk of the response coincidentally containing the answer + return max(0.5 * len(answer) / len(entry["answer"]), 0.1) + + return 0.0 + # Register the dataset register_dataset("bf", BFDataset, BFConfig) diff --git a/reasoning_gym/cognition/figlet_fonts.py b/reasoning_gym/cognition/figlet_fonts.py index 4c16dec2..f4150fd5 100644 --- a/reasoning_gym/cognition/figlet_fonts.py +++ b/reasoning_gym/cognition/figlet_fonts.py @@ -182,7 +182,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """ correct_word = entry["answer"] - if not answer: + if not isinstance(answer, str): return 0.0 # No answer given # Normalize case diff --git a/reasoning_gym/cognition/needle_haystack.py b/reasoning_gym/cognition/needle_haystack.py index f83318b8..e5adf741 100644 --- a/reasoning_gym/cognition/needle_haystack.py +++ b/reasoning_gym/cognition/needle_haystack.py @@ -110,19 +110,17 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: Returns: float: The computed score between 0.0 and 1.0. """ + if isinstance(answer, str): + correct_word = entry["answer"] - correct_word = entry["answer"] - if not answer: - return 0.0 # No answer given + # Normalize case + answer = answer.replace(" ", "").strip().lower() + correct_word = correct_word.strip().lower() - # Normalize case - answer = answer.replace(" ", "").strip().lower() - correct_word = correct_word.strip().lower() + if answer == correct_word: + return 1.0 # Correct! - if answer == correct_word: - return 1.0 # Correct! - - return 0.01 + return 0.0 # Register the dataset diff --git a/reasoning_gym/cognition/rectangle_count.py b/reasoning_gym/cognition/rectangle_count.py index b258b615..6f096e94 100644 --- a/reasoning_gym/cognition/rectangle_count.py +++ b/reasoning_gym/cognition/rectangle_count.py @@ -132,12 +132,10 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if answer == None: - return 0.0 - if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""): - return 0.01 - else: - return 1.0 # Yay + if isinstance(answer, str): + if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""): + return 1.0 # Yay + return 0.0 register_dataset("rectangle_count", RectangleCountDataset, RectangleCountConfig) diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index d727c863..29aea330 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -69,9 +69,6 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: reward = 1.0 elif oracle_answer in answer: reward = len(oracle_answer) / len(answer) - else: - reward = 0.01 - return reward diff --git a/reasoning_gym/games/contrib/sokoban/levels/lvl0.dat b/reasoning_gym/games/contrib/sokoban/levels/lvl0.dat deleted file mode 100644 index 867d112a..00000000 --- a/reasoning_gym/games/contrib/sokoban/levels/lvl0.dat +++ /dev/null @@ -1,10 +0,0 @@ -+ + + + + + + -+ - * - - - + -+ - - - $ - + -+ X - - @ - + -+ - - - - - + -+ $ - + - - + -+ + - - - - + -+ X @ - $ - + -+ + - - - - + -+ + + + + + + diff --git a/reasoning_gym/games/contrib/sokoban/levels/lvl1.dat b/reasoning_gym/games/contrib/sokoban/levels/lvl1.dat deleted file mode 100644 index 9ba48c31..00000000 --- a/reasoning_gym/games/contrib/sokoban/levels/lvl1.dat +++ /dev/null @@ -1,5 +0,0 @@ -+ + + + + + + -+ * - @ - X + -+ + - @ - + + -+ X - - - - + -+ + + + + + + diff --git a/reasoning_gym/games/contrib/sokoban/levels/lvl2.dat b/reasoning_gym/games/contrib/sokoban/levels/lvl2.dat deleted file mode 100644 index 46755810..00000000 --- a/reasoning_gym/games/contrib/sokoban/levels/lvl2.dat +++ /dev/null @@ -1,6 +0,0 @@ -- - + + + + + + -- + + - - - * + -+ + - - - + X + -+ X - @ - @ @ + -+ X X @ - - - + -+ + + + + + + + diff --git a/reasoning_gym/games/contrib/sokoban/levels/lvl3.dat b/reasoning_gym/games/contrib/sokoban/levels/lvl3.dat deleted file mode 100644 index 9d0bc599..00000000 --- a/reasoning_gym/games/contrib/sokoban/levels/lvl3.dat +++ /dev/null @@ -1,7 +0,0 @@ -- + + + + + + - - - -- + X - - X + - - - -+ + - @ @ + + - - - -+ - - - - + + - - - -+ - @ - - * + + + + -+ + - - - - - - X + -- + + + + + + + + + diff --git a/reasoning_gym/games/contrib/sokoban/levels/lvl4.dat b/reasoning_gym/games/contrib/sokoban/levels/lvl4.dat deleted file mode 100644 index 42fbc6eb..00000000 --- a/reasoning_gym/games/contrib/sokoban/levels/lvl4.dat +++ /dev/null @@ -1,7 +0,0 @@ -- + + + + + + - - -+ + X - @ - + + + -+ - - - - - - - + -+ - @ + + X - @ + -+ - - - @ - + - + -+ + + * - X - X + -- - + + + + + + + diff --git a/reasoning_gym/games/contrib/sokoban/levels/lvl5.dat b/reasoning_gym/games/contrib/sokoban/levels/lvl5.dat deleted file mode 100644 index 3a096d58..00000000 --- a/reasoning_gym/games/contrib/sokoban/levels/lvl5.dat +++ /dev/null @@ -1,7 +0,0 @@ -- + + + + + + + - -+ + - - + - - + + -+ - @ - - - @ - + -+ - - X * X - - + -+ + @ + + - - + + -+ - - X - - - + - -+ + + + + + + + - diff --git a/reasoning_gym/games/contrib/sokoban/levels/lvl6.dat b/reasoning_gym/games/contrib/sokoban/levels/lvl6.dat deleted file mode 100644 index 32ee5bbc..00000000 --- a/reasoning_gym/games/contrib/sokoban/levels/lvl6.dat +++ /dev/null @@ -1,9 +0,0 @@ -- - - + + + + + + + + -- - - + - - - - - - + -- - + + - - - - @ - + -- + + - - + + - + + + -+ + - - + - - X - - + -+ - - + X @ @ - - + + -+ * + X - - - - + + - -+ + - - - - - + + - - -+ + + + + + + + - - - diff --git a/reasoning_gym/games/contrib/sokoban/levels/lvl7.dat b/reasoning_gym/games/contrib/sokoban/levels/lvl7.dat deleted file mode 100644 index 9c2fe302..00000000 --- a/reasoning_gym/games/contrib/sokoban/levels/lvl7.dat +++ /dev/null @@ -1,6 +0,0 @@ -+ + + + + + + + -+ - - @ - X * + -+ - @ - - + X + -+ X X @ - @ @ + -+ X X @ - - - + -+ + + + + + + + diff --git a/reasoning_gym/games/contrib/sokoban/src/astar.py b/reasoning_gym/games/contrib/sokoban/src/astar.py index 25d1e63d..10b185f6 100644 --- a/reasoning_gym/games/contrib/sokoban/src/astar.py +++ b/reasoning_gym/games/contrib/sokoban/src/astar.py @@ -13,7 +13,7 @@ ) -def astar(matrix, player_pos, debug=False, heuristic="manhattan"): +def astar(matrix, player_pos, debug: bool = False, heuristic: str = "manhattan", max_depth: int = 100): # print(f'A* - {heuristic.title()} Heuristic') heur = "[A*]" if heuristic == "manhattan" else "[Dijkstra]" shape = matrix.shape @@ -67,15 +67,18 @@ def astar(matrix, player_pos, debug=False, heuristic="manhattan"): return (path + direction[move], depth + 1) if debug: print(f"{heur} Solution Depth: {depth + 1}\n{path + direction[move]}", 20) - print(f"{heur} Solution not found!\n") + + if depth > max_depth: + break + if debug: print(f"{heur} Solution Not Found!\nDepth {depth + 1}", 20) return (None, -1 if not heap else depth + 1) -def solve_astar(puzzle, visualizer=False, heuristic="manhattan"): +def solve_astar(puzzle, visualizer: bool = False, heuristic: str = "manhattan", max_depth: int = 100): matrix = puzzle where = np.where((matrix == "*") | (matrix == "%")) player_pos = where[0][0], where[1][0] - return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic) + return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic, max_depth=max_depth) diff --git a/reasoning_gym/games/contrib/sokoban/src/game.py b/reasoning_gym/games/contrib/sokoban/src/game.py index f01f3720..be50b52c 100644 --- a/reasoning_gym/games/contrib/sokoban/src/game.py +++ b/reasoning_gym/games/contrib/sokoban/src/game.py @@ -29,8 +29,7 @@ def __str__(self) -> str: class Game: - def __init__(self, width=19, height=10, level=None, path=None): - self.level = level + def __init__(self, width=19, height=10, path=None): self.width = width self.height = height self.puzzle = np.empty((height, width), dtype=PuzzleElement) @@ -39,7 +38,7 @@ def __init__(self, width=19, height=10, level=None, path=None): self.puzzle_size = None self.pad_x = 0 self.pad_y = 0 - self.path = path or f"levels/lvl{level}.dat" + self.path = path if path: if type(self) == Game: @@ -108,7 +107,7 @@ def _process_puzzle_data(self, data): # Calculate puzzle size and padding self.puzzle_size = (len(data), len(data[0]) if len(data) > 0 else 0) - pad_x = (self.width - self.puzzle_size[1] - 2) // 2 # -2 matches original file-based logic + pad_x = (self.width - self.puzzle_size[1]) // 2 pad_y = (self.height - self.puzzle_size[0]) // 2 self.pad_x, self.pad_y = pad_x, pad_y @@ -140,15 +139,15 @@ def _process_puzzle_data(self, data): class ReverseGame(Game): - def __init__(self, rng: Random, width=19, height=10, level=None): - super().__init__(width, height, level) + def __init__(self, rng: Random, width: int = 19, height: int = 10): + super().__init__(width, height) self.rng = rng self.pad_x = 0 self.pad_y = 0 def load_puzzle(self, puzzle): self.puzzle_size = (len(puzzle), len(puzzle[0]) if len(puzzle) > 0 else 0) - pad_x = (self.width - len(puzzle[0]) - 2) // 2 + pad_x = (self.width - len(puzzle[0])) // 2 pad_y = (self.height - len(puzzle)) // 2 self.pad_x, self.pad_y = pad_x, pad_y for i, row in enumerate(puzzle): diff --git a/reasoning_gym/games/contrib/sokoban/src/generator.py b/reasoning_gym/games/contrib/sokoban/src/generator.py index da4c954f..372a4168 100644 --- a/reasoning_gym/games/contrib/sokoban/src/generator.py +++ b/reasoning_gym/games/contrib/sokoban/src/generator.py @@ -7,6 +7,9 @@ def num_boxes(puzzle_area, min_boxes, max_boxes, min_w, min_h, max_w, max_h): + if min_w == max_w or min_h == max_h or min_boxes == max_boxes: + return max_boxes + m = (max_boxes - min_boxes) / (max_w * max_h - min_w * min_h) b = min_boxes - m * min_w * min_h return int(m * puzzle_area + b) @@ -19,31 +22,33 @@ def random_valid(rng: Random, width: int = 10, height: int = 10): def generate( rng: Random, debug: bool = False, - path: str = None, min_w: int = 6, min_h: int = 6, max_w: int = 15, max_h: int = 10, min_boxes: int = 4, max_boxes: int = 10, + max_depth: int = 100, + path: str = None, ) -> tuple[str, str, dict]: """ Generates a level with the given configuration parameters. Parameters: - rng: Random number generator for reproducibility. - visualizer: Whether to visualize the generation process. - path: Path to save the level file (default 'levels/lvl0.dat'). - min_w: Minimum width of the puzzle. - min_h: Minimum height of the puzzle. - max_w: Maximum width of the puzzle. - max_h: Maximum height of the puzzle. - min_boxes: Minimum number of boxes. - max_boxes: Maximum number of boxes. + rng: Random number generator + visualizer: Whether to visualize the generation process + min_w: Minimum width of the puzzle + min_h: Minimum height of the puzzle + max_w: Maximum width of the puzzle + max_h: Maximum height of the puzzle + min_boxes: Minimum number of boxes + max_boxes: Maximum number of boxes + max_depth: Maximum search depth + path: Path to save the level file (optional) Returns: puzzle_string, solution """ - path = path or "levels/lvl0.dat" + while True: width = rng.randint(min_w, max_w) height = rng.randint(min_h, max_h) @@ -60,7 +65,7 @@ def generate( puzzle[box_pos] = "$" boxes_created += 1 boxes_seen.add(box_pos) - reverse_game = ReverseGame(rng=rng, level=0) + reverse_game = ReverseGame(rng=rng, width=width, height=height) reverse_game.load_puzzle(puzzle) player = reverse_game.player counter = round(height * width * rng.uniform(1.8, 3.6)) @@ -79,16 +84,19 @@ def generate( out_of_place_boxes = np.sum([str(x) == "@" for x in matrix.flatten()]) if out_of_place_boxes >= boxes // 2: # Optionally save the puzzle to a file: - # np.savetxt(path, matrix, fmt='%s') + if path: + np.savetxt(path, matrix, fmt="%s") puzzle_str = player.puzzle_to_string(matrix) grid_list = [list(line) for line in puzzle_str.replace(" ", "").strip().split("\n")] grid_array = np.array(grid_list) - solution, _ = solve_astar(grid_array) + solution, depth = solve_astar(grid_array, max_depth=max_depth) + if solution is None: + continue # retry generation if debug: print(f"solution={solution}") - game = Game() + game = Game(width=width, height=height) game.load_puzzle_matrix(grid_array) for step, move in enumerate(solution): diff --git a/reasoning_gym/games/futoshiki.py b/reasoning_gym/games/futoshiki.py index ce70e14a..092087f1 100644 --- a/reasoning_gym/games/futoshiki.py +++ b/reasoning_gym/games/futoshiki.py @@ -618,7 +618,7 @@ def _try_remove(): return grid def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - if not answer: + if not isinstance(answer, str): return 0.0 oracle_answer = entry["answer"] diff --git a/reasoning_gym/games/knight_swap.py b/reasoning_gym/games/knight_swap.py index 1e0270f5..2a54289d 100644 --- a/reasoning_gym/games/knight_swap.py +++ b/reasoning_gym/games/knight_swap.py @@ -314,36 +314,35 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float: - 1.0 for correct answer (either "No" for impossible puzzles or valid solution of optimal length) - A proportional score for correct but longer solutions - 0.05 for valid moves that don't solve the puzzle - - 0.01 for invalid format - - 0.0 for None + - 0.0 for invalid format or None """ - if answer is None: + if not isinstance(answer, str): return 0.0 answer = answer.strip() - if not answer: - return 0.01 + if len(answer) == 0: + return 0.0 # Handle impossible puzzles if not entry["metadata"]["is_possible"]: - return 1.0 if answer.lower() == "no" else 0.01 + return 1.0 if answer.lower() == "no" else 0.0 # Handle "No" answer for possible puzzles if answer.lower() == "no": - return 0.01 + return 0.0 try: # Parse moves from JSON list move_list = json.loads(answer) if not isinstance(move_list, list): - return 0.01 + return 0.0 # Parse moves moves = [] for move_str in move_list: color, start, end = move_str.split(",") if color not in ("w", "B"): - return 0.01 + return 0.0 moves.append((color, start, end)) # Validate and apply moves @@ -357,13 +356,13 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float: for color, start, end in moves: if color != current_turn: - return 0.01 + return 0.0 if start not in pieces or pieces[start] != color: - return 0.01 + return 0.0 if end not in board[start]: - return 0.01 + return 0.0 if end in pieces and pieces[end] is not None: - return 0.01 + return 0.0 # Apply move pieces[end] = pieces[start] @@ -390,7 +389,7 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float: return 0.05 except Exception: - return 0.01 + return 0.0 register_dataset("knight_swap", KnightSwapDataset, KnightSwapConfig) diff --git a/reasoning_gym/games/mini_sudoku.py b/reasoning_gym/games/mini_sudoku.py index 319569ff..46df6aa2 100644 --- a/reasoning_gym/games/mini_sudoku.py +++ b/reasoning_gym/games/mini_sudoku.py @@ -195,7 +195,7 @@ def __getitem__(self, idx: int) -> dict: } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - if not answer: + if not isinstance(answer, str) or len(answer) == 0: return 0.0 oracle_answer = entry["answer"] diff --git a/reasoning_gym/games/n_queens.py b/reasoning_gym/games/n_queens.py index 8fcecfd4..3a07bb10 100644 --- a/reasoning_gym/games/n_queens.py +++ b/reasoning_gym/games/n_queens.py @@ -138,8 +138,8 @@ def __getitem__(self, idx: int) -> dict: } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - valid_solutions = entry["metadata"]["valid_answers"] - if answer is not None: + if isinstance(answer, str): + valid_solutions = entry["metadata"]["valid_answers"] if answer in valid_solutions: return 1.0 try: @@ -147,7 +147,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: if answer in valid_solutions: return 0.5 except Exception as e: - return 0.01 + pass return 0.0 diff --git a/reasoning_gym/games/rush_hour.py b/reasoning_gym/games/rush_hour.py index 1ab9489a..ba48f8a2 100644 --- a/reasoning_gym/games/rush_hour.py +++ b/reasoning_gym/games/rush_hour.py @@ -171,7 +171,7 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float: Returns: 1.0 if solution reaches goal state, 0.0 otherwise """ - if not answer: + if not isinstance(answer, str) or len(answer) == 0: return 0.0 try: diff --git a/reasoning_gym/games/sokoban.py b/reasoning_gym/games/sokoban.py index 1b12169c..0f81930d 100644 --- a/reasoning_gym/games/sokoban.py +++ b/reasoning_gym/games/sokoban.py @@ -11,20 +11,26 @@ class SokobanConfig: """Configuration for sokoban puzzle generation""" - min_w: int = 6 # Minimum width of the puzzle. - min_h: int = 6 # Minimum height of the puzzle. - max_w: int = 10 # Maximum width of the puzzle. - max_h: int = 10 # Maximum height of the puzzle. - min_boxes: int = 6 # Minimum number of boxes. - max_boxes: int = 10 # Maximum number of boxes. + min_w: int = 6 # Minimum width of the puzzle + min_h: int = 6 # Minimum height of the puzzle + max_w: int = 10 # Maximum width of the puzzle + max_h: int = 10 # Maximum height of the puzzle + min_boxes: int = 4 # Minimum number of boxes + max_boxes: int = 10 # Maximum number of boxes + max_depth: int = 80 # Maximum search depth seed: Optional[int] = None size: int = 500 def validate(self): """Validate configuration parameters""" + assert 0 < self.max_w <= 20 + assert 0 < self.max_h <= 20 + assert self.min_h > 0 + assert self.min_w > 0 assert self.min_w <= self.max_w, "min_w must be lte max_w" assert self.min_h <= self.max_h, "min_h must be lte max_h" assert self.min_boxes <= self.max_boxes, "min_boxes must be lte max_boxes" + assert self.max_depth > 1 class SokobanDataset(ProceduralDataset): @@ -58,7 +64,16 @@ def __getitem__(self, idx: int) -> dict: # Make the Sokoban! rng = Random(self.seed + idx) - gamestr, solution, difficulty = self._generate(rng=rng) + gamestr, solution, difficulty = self._generate( + rng=rng, + min_w=self.config.min_w, + min_h=self.config.min_h, + max_w=self.config.max_w, + max_h=self.config.max_h, + min_boxes=self.config.min_boxes, + max_boxes=self.config.max_boxes, + max_depth=self.config.max_depth, + ) return { "question": """You are going to solve a 'sokoban' puzzle. @@ -93,14 +108,15 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if answer == None: + if not isinstance(answer, str): return 0.0 try: grid_list = [list(line) for line in entry["metadata"]["gamestr"].replace(" ", "").strip().split("\n")] matrix = np.array(grid_list) - game = self._Game() + h, w = matrix.shape + game = self._Game(height=h, width=w) game.load_puzzle_matrix(matrix) for move in answer: @@ -108,10 +124,10 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: if self._is_solved(game.get_curr_state()): return 1.0 - except Exception as e: - return 0.01 + except: + pass - return 0.1 + return 0.0 register_dataset("sokoban", SokobanDataset, SokobanConfig) diff --git a/reasoning_gym/games/sudoku.py b/reasoning_gym/games/sudoku.py index 9b07184b..d7ecf8d2 100644 --- a/reasoning_gym/games/sudoku.py +++ b/reasoning_gym/games/sudoku.py @@ -214,7 +214,7 @@ def __getitem__(self, idx: int) -> dict: } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - if not answer: + if not isinstance(answer, str) or len(answer) == 0: return 0.0 oracle_answer = entry["answer"] diff --git a/reasoning_gym/games/tower_of_hanoi.py b/reasoning_gym/games/tower_of_hanoi.py index 052d72bf..9ab50cb8 100644 --- a/reasoning_gym/games/tower_of_hanoi.py +++ b/reasoning_gym/games/tower_of_hanoi.py @@ -266,7 +266,7 @@ def __getitem__(self, idx: int) -> dict: start_peg=peg_labels[start_peg], target_peg=peg_labels[target_peg], ), - "answer": solution, + "answer": "\n".join(solution), "metadata": { "num_disks": num_disks, "num_pegs": num_pegs, @@ -383,24 +383,14 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: Expected behavior: - Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0. - A correct solution that is suboptimal length gives a proportional reward of optimal_move_count/user_move_count - - A badly formatted answer gives a minimal reward (0.01). - An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05). - - An empty string gives 0.01. - - None gives 0.0. + - A badly formatted or empty answer gives 0.0 """ - if answer is None: + if not isinstance(answer, str) or len(answer) == 0: return 0.0 - if answer == "": - return 0.01 - - # If answer is a string, split it into lines; if it's already a list, use it directly. - if isinstance(answer, str): - moves = [line.strip() for line in answer.strip().splitlines() if line.strip()] - elif isinstance(answer, list): - moves = [line.strip() for line in answer if isinstance(line, str) and line.strip()] - else: - return 0.0 + # Spilt answer string it into lines + moves = [line.strip() for line in answer.strip().splitlines() if line.strip()] # Build the initial peg state from metadata. metadata = entry["metadata"] @@ -418,11 +408,11 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: try: disk, from_peg, to_peg = self._parse_move(move) except Exception: - return 0.01 # Invalid move format + return 0.0 # Invalid move format # Validate the move using existing _validate_move method. if not self._validate_move(peg_state, move): - return 0.01 + return 0.0 # Execute the move. peg_state[from_peg].pop() diff --git a/reasoning_gym/graphs/family_relationships.py b/reasoning_gym/graphs/family_relationships.py index 973de976..51c0055c 100644 --- a/reasoning_gym/graphs/family_relationships.py +++ b/reasoning_gym/graphs/family_relationships.py @@ -358,16 +358,14 @@ def _generate_story(self, family: set[Person]) -> str: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: reward = 0.0 - if answer is not None: + if isinstance(answer, str): try: answer_formatted = answer.strip().lower() - solved = answer_formatted == entry["answer"].strip().lower() - if solved: + oracle_answer = entry["answer"].strip().lower() + if answer_formatted == oracle_answer: reward = 1.0 - else: - reward = 0.01 except: - reward = 0.01 + pass return reward diff --git a/reasoning_gym/graphs/quantum_lock.py b/reasoning_gym/graphs/quantum_lock.py index c32085f2..e2196906 100644 --- a/reasoning_gym/graphs/quantum_lock.py +++ b/reasoning_gym/graphs/quantum_lock.py @@ -169,21 +169,15 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: The function awards 1.0 for a correct answer and less otherwise. """ - if answer == None: + if not isinstance(answer, str): return 0.0 - # Get correct solution from metadata - correct_solution = entry["metadata"].get("solution_path", []) - # Normalize both answers - def normalize_seq(seq): - """Handle both string and list inputs by converting to string first""" - # Convert sequence to string representation if it's a list - input_str = "".join(seq) if isinstance(seq, list) else str(seq or "") - return [c.upper() for c in re.findall(r"[A-C]", input_str.upper())] + def normalize_seq(seq: str) -> list[str]: + return [c.upper() for c in re.findall(r"[A-C]", seq.upper())] user_sequence = normalize_seq(answer) - target_sequence = normalize_seq("".join(correct_solution)) + target_sequence = normalize_seq(entry["answer"]) # Exact sequence match required if user_sequence == target_sequence: @@ -196,7 +190,7 @@ def normalize_seq(seq): return 1.0 # Different answer, but qually correct return 0.5 # Alternative scoring - you're correct, but not optimal - return 0.1 + return 0.0 def simulate_sequence(self, metadata: dict, sequence: list[str]) -> int: """Simulate button presses to verify solutions""" diff --git a/reasoning_gym/graphs/shortest_path.py b/reasoning_gym/graphs/shortest_path.py index d915e114..43c45f54 100644 --- a/reasoning_gym/graphs/shortest_path.py +++ b/reasoning_gym/graphs/shortest_path.py @@ -125,8 +125,8 @@ def _is_valid_path(self, matrix: list[list[str]], path: list[str]) -> bool: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Overwrite this method in derived classes if a single oracle answer is not available.""" - oracle_answer = entry["answer"].strip() - if answer is not None and len(answer) > 0: + if isinstance(answer, str) and len(answer) > 0: + oracle_answer = entry["answer"].strip() answer = answer.strip() # Exact answer @@ -145,8 +145,6 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: elif self._is_valid_path(matrix, answer): return 0.5 - return 0.01 - return 0.0 def __getitem__(self, idx: int) -> dict: diff --git a/reasoning_gym/logic/circuit_logic.py b/reasoning_gym/logic/circuit_logic.py index fd169076..7f8e69cb 100644 --- a/reasoning_gym/logic/circuit_logic.py +++ b/reasoning_gym/logic/circuit_logic.py @@ -401,16 +401,14 @@ def get_random_input() -> str: } def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - if answer is None or len(answer) == 0: - return 0.0 - - oracle_answer = entry["answer"] - if oracle_answer == answer: - return 1.0 - elif oracle_answer == answer.strip(): - return len(oracle_answer) / len(answer) - - return 0.01 + if isinstance(answer, str) and len(answer) > 0: + oracle_answer = entry["answer"] + if oracle_answer == answer: + return 1.0 + elif oracle_answer == answer.strip(): + return len(oracle_answer) / len(answer) + + return 0.0 register_dataset("circuit_logic", CircuitLogicDataset, CircuitLogicConfig) diff --git a/reasoning_gym/logic/knights_knaves.py b/reasoning_gym/logic/knights_knaves.py index 7cc16ecf..0e96d069 100644 --- a/reasoning_gym/logic/knights_knaves.py +++ b/reasoning_gym/logic/knights_knaves.py @@ -489,7 +489,7 @@ def _normalize_answer(answer: str) -> set[tuple[str, str]]: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Score an answer against the oracle answer.""" - if answer is None or len(answer) == 0: + if not isinstance(answer, str) or len(answer) == 0: return 0.0 try: @@ -506,11 +506,9 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: if matching > 0: return 0.3 + (0.7 * matching / len(oracle_assignments)) - return 0.01 - except Exception: - # If parsing fails, give minimal credit - return 0.01 + pass + return 0.0 register_dataset("knights_knaves", KnightsKnavesDataset, KnightsKnavesConfig) diff --git a/reasoning_gym/logic/propositional_logic.py b/reasoning_gym/logic/propositional_logic.py index 8732bc1b..b9387dfa 100644 --- a/reasoning_gym/logic/propositional_logic.py +++ b/reasoning_gym/logic/propositional_logic.py @@ -295,7 +295,7 @@ def _measure_complexity(self, expression: Expression) -> int: def score_answer(self, answer: str | None, entry: dict[str, Any]) -> float: """Robust scoring implementation for propositional logic answers""" - if not answer: + if not isinstance(answer, str): return 0.0 try: @@ -304,7 +304,7 @@ def score_answer(self, answer: str | None, entry: dict[str, Any]) -> float: valid_vars = set(entry["metadata"]["variables"]) answer_vars = re.findall(r"([A-Z])", cleaned_answer) if any(var not in valid_vars for var in answer_vars): - return 0.01 + return 0.0 premises = [Expression.from_string(p) for p in entry["metadata"]["premises"]] answer_expr = Expression.from_string(cleaned_answer) @@ -316,7 +316,7 @@ def score_answer(self, answer: str | None, entry: dict[str, Any]) -> float: return 1.0 return 0.05 except (ValueError, KeyError, AttributeError): - return 0.01 + return 0.0 def _is_trivial(self, expr: Expression) -> bool: """Check for trivial tautologies like P ∨ ¬P""" diff --git a/reasoning_gym/logic/self_reference.py b/reasoning_gym/logic/self_reference.py index f42ce415..a490a619 100644 --- a/reasoning_gym/logic/self_reference.py +++ b/reasoning_gym/logic/self_reference.py @@ -339,9 +339,7 @@ def __getitem__(self, idx: int) -> dict: # Solve puzzle solutions = solve_puzzle_dynamic(puzzle) - for idx, sol in enumerate(solutions, start=1): - sol_str = ["True" if s else "False" for s in sol] - answer = len(solutions) + answer = str(len(solutions)) return { "question": puzz_s, @@ -362,12 +360,10 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if answer == None: - return 0.0 - if str(answer) != str(entry["answer"]): - return 0.1 - else: - return 1.0 # Yay + if isinstance(answer, str): + if answer == str(entry["answer"]): + return 1.0 # Yay + return 0.0 register_dataset("self_reference", SelfReferenceDataset, SelfReferenceConfig) diff --git a/reasoning_gym/logic/zebra_puzzles.py b/reasoning_gym/logic/zebra_puzzles.py index c518c65d..7e424b13 100644 --- a/reasoning_gym/logic/zebra_puzzles.py +++ b/reasoning_gym/logic/zebra_puzzles.py @@ -68,12 +68,10 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: float: The computed score between 0.0 and 1.0. """ - if answer == None: - return 0.0 - if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""): - return 0.01 - else: - return 1.0 # Yay + if isinstance(answer, str): + if answer.lower().replace("\n", "") == entry["answer"].lower().replace("\n", ""): + return 1.0 # Yay + return 0.0 register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig) diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py index 2143961d..ad0a5472 100644 --- a/reasoning_gym/utils.py +++ b/reasoning_gym/utils.py @@ -103,7 +103,6 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm """ reward = 0.0 if answer is not None and len(answer) > 0: - reward = 0.01 try: if strip_commas: answer = answer.replace(",", "") diff --git a/tests/test_ab.py b/tests/test_ab.py index 489c4acf..e63a07bc 100644 --- a/tests/test_ab.py +++ b/tests/test_ab.py @@ -57,7 +57,7 @@ def test_ab_scoring(): # Test wrong answer wrong_answer = "A# B#" if item["answer"] != "A# B#" else "B# A#" - assert dataset.score_answer(answer=wrong_answer, entry=item) == 0.01 + assert dataset.score_answer(answer=wrong_answer, entry=item) == 0.0 # Test None answer assert dataset.score_answer(answer=None, entry=item) == 0.0 diff --git a/tests/test_arc_1d.py b/tests/test_arc_1d.py index 140916d6..1eeb0d09 100644 --- a/tests/test_arc_1d.py +++ b/tests/test_arc_1d.py @@ -103,7 +103,7 @@ def test_arc_1d_scoring(): assert dataset.score_answer(f"The answer is: {entry['answer']}", entry) > 0.5 # Test incorrect answer - assert dataset.score_answer("wrong answer", entry) == 0.01 + assert dataset.score_answer("wrong answer", entry) == 0.0 # Test None answer assert dataset.score_answer(None, entry) == 0.0 diff --git a/tests/test_arc_agi.py b/tests/test_arc_agi.py index cfeecf21..ca7a92e3 100644 --- a/tests/test_arc_agi.py +++ b/tests/test_arc_agi.py @@ -110,7 +110,7 @@ def test_arc_agi_scoring(): assert dataset.score_answer(item["answer"], entry=item) == 1.0 # Test invalid format - assert dataset.score_answer("invalid grid format", entry=item) == 0.01 + assert dataset.score_answer("invalid grid format", entry=item) == 0.0 # Test None answer assert dataset.score_answer(None, entry=item) == 0.0 diff --git a/tests/test_bf.py b/tests/test_bf.py index 86d2619a..0455febf 100644 --- a/tests/test_bf.py +++ b/tests/test_bf.py @@ -23,7 +23,7 @@ def test_bf(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 - assert dataset.score_answer(answer="Love is a battlefield", entry=item) == 0.01 + assert dataset.score_answer(answer="Love is a battlefield", entry=item) == 0.0 # Medium config = BFConfig(seed=43, size=20, difficulty=2) diff --git a/tests/test_binary_matrix.py b/tests/test_binary_matrix.py index 68094472..88ef9adc 100644 --- a/tests/test_binary_matrix.py +++ b/tests/test_binary_matrix.py @@ -115,7 +115,7 @@ def test_binary_matrix_answer(): # Answer is a python list (partially correct answer) answer = "[[0, 0, 0], [0, 1, 0], [1, 2, 1]]" entry = {"answer": "0 0 0\n0 1 0\n1 2 1"} - assert dataset.score_answer(answer, entry) == 0.5 + assert dataset.score_answer(answer, entry) == 0.1 # Answer is null answer = None diff --git a/tests/test_bitwise_arithmetic.py b/tests/test_bitwise_arithmetic.py index cbaff24b..854c764a 100644 --- a/tests/test_bitwise_arithmetic.py +++ b/tests/test_bitwise_arithmetic.py @@ -43,7 +43,7 @@ def test_bitwise_arithmetic_items(): # Test scoring edge cases assert dataset.score_answer(answer=None, entry=item) == 0.0 - assert dataset.score_answer(answer="invalid", entry=item) == 0.01 + assert dataset.score_answer(answer="invalid", entry=item) == 0.0 def test_bitwise_arithmetic_difficulty_levels(): diff --git a/tests/test_complex_arithmetic.py b/tests/test_complex_arithmetic.py index 994b94cc..14f67bfe 100644 --- a/tests/test_complex_arithmetic.py +++ b/tests/test_complex_arithmetic.py @@ -58,6 +58,7 @@ def test_complex_arithmetic_scoring(): assert dataset.score_answer("3 + 2i", entry) == 1.0 assert dataset.score_answer("3+2i", entry) == 1.0 assert dataset.score_answer("3.0 + 2.0i", entry) == 1.0 + assert dataset.score_answer("((3.0 + 2.0i ) )", entry) == 1.0 # Test answers with small errors (should get high but < 1.0 scores) print(dataset.score_answer("3.1 + 2i", entry)) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4373a204..88ff79f7 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -36,7 +36,7 @@ def test_reseeding_dataset_iteration(): # Test score_answer forwarding test_item = next(iter(infinite_dataset)) - assert infinite_dataset.score_answer("wrong", test_item) == 0.01 + assert infinite_dataset.score_answer("wrong", test_item) == 0.0 assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0 diff --git a/tests/test_dataset_common.py b/tests/test_dataset_common.py new file mode 100644 index 00000000..d65c4cb5 --- /dev/null +++ b/tests/test_dataset_common.py @@ -0,0 +1,17 @@ +import reasoning_gym +from reasoning_gym.factory import DATASETS + + +def test_score_answer_consistency(): + for dataset_name in DATASETS.keys(): + if dataset_name == "composite": + continue + dataset = reasoning_gym.create_dataset(dataset_name, size=10, seed=1234) + for entry in dataset: + assert entry["answer"] is None or isinstance( + entry["answer"], str + ), f"{dataset_name} answer must be str, is {type(entry['answer'])}" + if entry["answer"] is not None: + assert ( + dataset.score_answer(answer=entry["answer"], entry=entry) == 1.0 + ), f"inconsistent score_answer {dataset_name}" diff --git a/tests/test_decimal_chain_sum.py b/tests/test_decimal_chain_sum.py index 5114a7c7..e488b77c 100644 --- a/tests/test_decimal_chain_sum.py +++ b/tests/test_decimal_chain_sum.py @@ -242,11 +242,11 @@ def test_decimal_precision_scoring(): assert dataset.score_answer("0.3", {"answer": "0.300"}) == 1.0 # Test incorrect answers - assert dataset.score_answer("1.200000001", {"answer": "1.200"}) == 0.01 - assert dataset.score_answer("1.199999999", {"answer": "1.200"}) == 0.01 + assert dataset.score_answer("1.200000001", {"answer": "1.200"}) == 0.0 + assert dataset.score_answer("1.199999999", {"answer": "1.200"}) == 0.0 # Test invalid inputs assert dataset.score_answer(None, {"answer": "1.200"}) == 0.0 assert dataset.score_answer("", {"answer": "1.200"}) == 0.0 - assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.01 - assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.01 + assert dataset.score_answer("invalid", {"answer": "1.200"}) == 0.0 + assert dataset.score_answer("1.2.3", {"answer": "1.200"}) == 0.0 diff --git a/tests/test_game_of_life.py b/tests/test_game_of_life.py index 924a12de..bc3c0c61 100644 --- a/tests/test_game_of_life.py +++ b/tests/test_game_of_life.py @@ -50,7 +50,7 @@ def test_game_of_life_basic_properties(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 - assert dataset.score_answer(answer="invalid json", entry=item) == 0.01 + assert dataset.score_answer(answer="invalid json", entry=item) == 0.0 config = GameOfLifeConfig(seed=43, size=1, grid_size_x=3, grid_size_y=3, filled_cells=1, simulation_steps=1) dataset = GameOfLifeDataset(config) diff --git a/tests/test_intermediate_integration.py b/tests/test_intermediate_integration.py index f4393e2f..6cb32ec8 100644 --- a/tests/test_intermediate_integration.py +++ b/tests/test_intermediate_integration.py @@ -121,16 +121,16 @@ def test_score_answer_cases(): ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0), ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0), # Incorrect but properly formatted - ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05), - ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05), + ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.0), + ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.0), # Malformed expressions - ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01), - ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01), + ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.0), + ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.0), # Empty answer - ("", {"variable": "x", "integrand": "2*x"}, 0.01), + ("", {"variable": "x", "integrand": "2*x"}, 0.0), # Case sensitivity - ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05), - ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05), + ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.0), + ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.0), # Alternative constant notation ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0), ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0), diff --git a/tests/test_knight_swap.py b/tests/test_knight_swap.py index 24f6b417..bea90d5e 100644 --- a/tests/test_knight_swap.py +++ b/tests/test_knight_swap.py @@ -155,8 +155,8 @@ def test_score_calculation(): # Test invalid answers assert dataset.score_answer(None, puzzle) == 0.0 - assert dataset.score_answer("", puzzle) == 0.01 - assert dataset.score_answer("Invalid", puzzle) == 0.01 + assert dataset.score_answer("", puzzle) == 0.0 + assert dataset.score_answer("Invalid", puzzle) == 0.0 # Test correct answer assert dataset.score_answer(puzzle["answer"], puzzle) == 1.0 diff --git a/tests/test_knights_knaves.py b/tests/test_knights_knaves.py index 0c595c0b..bcaf1fe7 100644 --- a/tests/test_knights_knaves.py +++ b/tests/test_knights_knaves.py @@ -99,8 +99,7 @@ def test_score_answer(): assert dataset.score_answer(correct_answer, problem) == 1.0 assert abs(dataset.score_answer(half_answer, problem) - 0.65) < 1e-10 assert dataset.score_answer(modified_answer, problem) == 1.0 - assert dataset.score_answer(wrong_answer, problem) == 0.01 - print("flipped") + assert dataset.score_answer(wrong_answer, problem) == 0.0 assert dataset.score_answer(flipped_answer, problem) == 1.0 diff --git a/tests/test_manipulate_matrix.py b/tests/test_manipulate_matrix.py index 2801340f..1a31af56 100644 --- a/tests/test_manipulate_matrix.py +++ b/tests/test_manipulate_matrix.py @@ -214,7 +214,7 @@ def test_manipulate_matrix_score_answer(): # incorrect answer answer = "1 2 3\n4 5 6\n7 8 8" - assert dataset.score_answer(answer, entry) == 0.01 + assert dataset.score_answer(answer, entry) == 0.0 # answer is none answer = None diff --git a/tests/test_n_queens.py b/tests/test_n_queens.py index 91373592..6182540b 100644 --- a/tests/test_n_queens.py +++ b/tests/test_n_queens.py @@ -117,7 +117,7 @@ def test_nqueens_score_answer(): # Test invalid answer gets score 0.01 invalid_answer = "_ _ _ _\n_ _ _ _\n_ _ _ _\n_ _ _ _" - assert dataset.score_answer(invalid_answer, item) == 0.01 + assert dataset.score_answer(invalid_answer, item) == 0.0 # Test None answer gets score 0.0 assert dataset.score_answer(None, item) == 0.0 diff --git a/tests/test_needle_haystack.py b/tests/test_needle_haystack.py index fb279da2..672d16b3 100644 --- a/tests/test_needle_haystack.py +++ b/tests/test_needle_haystack.py @@ -16,7 +16,7 @@ def test_needle_haystack(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.01 + assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 config = NeedleHaystackConfig(seed=42, size=1, num_statements=500) diff --git a/tests/test_number_format.py b/tests/test_number_format.py index 882f38aa..76ae834b 100644 --- a/tests/test_number_format.py +++ b/tests/test_number_format.py @@ -110,7 +110,7 @@ def test_number_format_answer(): # Incorrect answer (diff larger than 1e-2) model_answer = "54245.9" - assert dataset.score_answer(model_answer, entry) == 0.01 + assert dataset.score_answer(model_answer, entry) == 0.0 # Answer is null model_answer = None diff --git a/tests/test_palindrome.py b/tests/test_palindrome.py index 49e23786..472390de 100644 --- a/tests/test_palindrome.py +++ b/tests/test_palindrome.py @@ -84,8 +84,8 @@ def test_score_answer(): wrong_letters = "abcd" if "abcd" != correct_answer else "efgh" assert dataset.score_answer(wrong_letters, entry=item) == 0.02 - # Empty String input should score 0.01 - assert dataset.score_answer("", entry=item) == 0.01 + # Empty String input should score 0.0 + assert dataset.score_answer("", entry=item) == 0.0 # Empty input should score 0.0 assert dataset.score_answer(None, entry=item) == 0.0 diff --git a/tests/test_palindrome_partitioning.py b/tests/test_palindrome_partitioning.py index a81c44bf..7a0d386f 100644 --- a/tests/test_palindrome_partitioning.py +++ b/tests/test_palindrome_partitioning.py @@ -95,17 +95,17 @@ def test_palindrome_partitioning_score_answer(): item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}} assert dataset.score_answer(answer, item) == 1 - # Verify the score is 0.01 when incorrect + # Verify the score is 0.0 when incorrect answer = json.dumps([["n", "o", "o", "n"], ["no", "on"]]) item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}} - assert dataset.score_answer(answer, item) == 0.01 + assert dataset.score_answer(answer, item) == 0.0 - # Verify the score is 0 when answer is None + # Verify the score is 0.0 when answer is None answer = None item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}} - assert dataset.score_answer(answer, item) == 0 + assert dataset.score_answer(answer, item) == 0.0 - # Verify the score is 0 when answer is malformed JSON + # Verify the score is 0.0 when answer is malformed JSON answer = '["n", "o", "o", "n"], ["no", "on"], ["noon"]' item = {"metadata": {"solution": [["no", "on"], ["noon"], ["n", "o", "o", "n"]]}} - assert dataset.score_answer(answer, item) == 0 + assert dataset.score_answer(answer, item) == 0.0 diff --git a/tests/test_polynomial_equations.py b/tests/test_polynomial_equations.py index e4e72b18..f99e6424 100644 --- a/tests/test_polynomial_equations.py +++ b/tests/test_polynomial_equations.py @@ -122,11 +122,11 @@ def test_polynomial_solutions_evaluation(): "oracle_answer, predicted_answer, expected_reward", [ ("4,-4.12", "4,-4.12", 1.0), # Exact match - ("4,-4.12", "4.0001,-4.120001", approx(0.9999, rel=1e-3)), # Very close match - ("4,-4.12", "4.1,-4.2", approx(0.9139, rel=1e-3)), - ("4,8", "4", approx(0.9, rel=1e-3)), # Missing an oracle solution -> missing solution penalty applies - ("4", "4,8", approx(0.95, rel=1e-3)), # extra solution -> extra solution penalty - ("-1,-2", "1,4", approx(0.06890, rel=1e-3)), # -1 matched w/ 1 and -2 matched w/ 4 + ("4,-4.12", "4.0001,-4.120001", approx(0.9994, rel=1e-3)), # Very close match + ("4,-4.12", "4.1,-4.2", approx(0.4086, rel=1e-3)), + ("4,8", "4", approx(0.5, rel=1e-3)), # Missing an oracle solution -> missing solution penalty applies + ("4", "4,8", approx(0.5, rel=1e-3)), # extra solution -> extra solution penalty + ("-1,-2", "1,4", approx(1.0305e-9, rel=1e-3)), # -1 matched w/ 1 and -2 matched w/ 4 ("", "1", approx(0, rel=1e-4)), # oracle no solution, predicted extra solution ("1", "", approx(0, rel=1e-4)), # oracle has a solution, predicted no solution ], diff --git a/tests/test_polynomial_multiplication.py b/tests/test_polynomial_multiplication.py index 10b404d9..0ddf334a 100644 --- a/tests/test_polynomial_multiplication.py +++ b/tests/test_polynomial_multiplication.py @@ -92,7 +92,6 @@ def test_polynomial_equations_dataset_items(): # Check metadata assert isinstance(item["metadata"]["polynomial_expr"], str) - assert isinstance(item["metadata"]["result"], str) assert isinstance(item["metadata"]["variables"], list) # Check polynomial_expr existence @@ -127,42 +126,6 @@ def test_cross_polynomial_equations_dataset_items(): # Check metadata assert isinstance(item["metadata"]["polynomial_expr"], str) - assert isinstance(item["metadata"]["result"], str) - assert isinstance(item["metadata"]["variables"], list) - - # Check polynomial_expr existence - poly_str = item["metadata"]["polynomial_expr"] - # Ensure it can parse with sympy - sp.sympify(poly_str) - - -def test_cross_polynomial_equations_dataset_items(): - """Test that generated items have correct structure""" - ds = create_dataset( - "polynomial_multiplication", - min_terms=2, - max_terms=3, - min_value=1, - max_value=5, - min_degree=1, - max_degree=2, - min_polynomials=2, - max_polynomials=5, - variables=tuple("xyz"), - allow_cross_variable_product=True, - allow_multivariate_polynomials=False, - size=3, - seed=100, - ) - - for item in ds: - assert "question" in item - assert "answer" in item - assert "metadata" in item - - # Check metadata - assert isinstance(item["metadata"]["polynomial_expr"], str) - assert isinstance(item["metadata"]["result"], str) assert isinstance(item["metadata"]["variables"], list) # Check polynomial_expr existence @@ -197,7 +160,6 @@ def test_multivariate_polynomial_equations_dataset_items(): # Check metadata assert isinstance(item["metadata"]["polynomial_expr"], str) - assert isinstance(item["metadata"]["result"], str) assert isinstance(item["metadata"]["variables"], list) # Check polynomial_expr existence @@ -242,7 +204,7 @@ def test_polynomial_solutions_evaluation(): poly_expr = sp.expand(poly_str) # Verify that each solution satisfies the polynomial - assert poly_expr == item["answer"] + assert str(poly_expr) == item["answer"] def test_score_function(): @@ -266,11 +228,11 @@ def test_score_function(): for item in ds: poly_str = item["metadata"]["polynomial_expr"] - assert ds.score_answer(poly_str, item) == 0.05 + assert ds.score_answer(poly_str, item) == 0.0 poly_expr = str(sp.expand(poly_str)) assert ds.score_answer(poly_expr, item) == 1.0 - assert ds.score_answer(None, item) == 0.00 - assert ds.score_answer("Not a polynomial", item) == 0.01 - assert ds.score_answer("x**4", item) == 0.05 + assert ds.score_answer(None, item) == 0.0 + assert ds.score_answer("Not a polynomial", item) == 0.0 + assert ds.score_answer("x**4", item) == 0.0 diff --git a/tests/test_pool_matrix.py b/tests/test_pool_matrix.py index c110967f..a529c05f 100644 --- a/tests/test_pool_matrix.py +++ b/tests/test_pool_matrix.py @@ -143,7 +143,7 @@ def test_pool_matrix_score_answer(): dataset = PoolMatrixDataset(config) for entry in dataset: assert dataset.score_answer(entry["answer"], entry=entry) == 1 - assert 0.0 < dataset.score_answer("1 2.0\n3.0 4", entry=entry) <= 0.1 + assert dataset.score_answer("1 2.0\n3.0 4", entry=entry) in [0.0, 0.1] assert dataset.score_answer("one two three", entry=entry) == 0.0 assert dataset.score_answer("", entry=entry) == 0.0 assert dataset.score_answer(None, entry=entry) == 0.0 diff --git a/tests/test_power_function.py b/tests/test_power_function.py index 08d15826..df2454a4 100644 --- a/tests/test_power_function.py +++ b/tests/test_power_function.py @@ -71,7 +71,7 @@ def test_power_function_score_function(): # Answer is far from solution answer = str(item["metadata"]["solution"] - 1) - assert dataset.score_answer(answer, item) == 0.01 + assert dataset.score_answer(answer, item) == 0.0 # Answer is None answer = None diff --git a/tests/test_products.py b/tests/test_products.py index 469ae5fd..1e4d8dae 100644 --- a/tests/test_products.py +++ b/tests/test_products.py @@ -139,7 +139,7 @@ def test_products_scoring(): assert dataset.score_answer(item["answer"], item) == 1.0, "Exact match should score 1.0" # Test scoring with wrong answer - assert dataset.score_answer("wrong", item) == 0.01, "Wrong answer should score 0.01" + assert dataset.score_answer("wrong", item) == 0.0, "Wrong answer should score 0.0" # Test scoring with partial match (answer contained in response) assert ( diff --git a/tests/test_propositional_logic.py b/tests/test_propositional_logic.py index c73ab1cc..d708cc4a 100644 --- a/tests/test_propositional_logic.py +++ b/tests/test_propositional_logic.py @@ -100,4 +100,4 @@ def test_propositional_logic_dataset_score_answer_incorrect(): dataset = PropositionalLogicDataset(PropositionalLogicConfig(size=100, seed=101)) for i, item in enumerate(dataset): score = dataset.score_answer("Wrong", item) - assert score == 0.01 + assert score == 0.0 diff --git a/tests/test_quantum_lock.py b/tests/test_quantum_lock.py index 69162750..cf9693e6 100644 --- a/tests/test_quantum_lock.py +++ b/tests/test_quantum_lock.py @@ -43,7 +43,8 @@ def test_quantumlock_items(): assert "target_value" in item["metadata"] # Verify solution works - assert dataset.score_answer(answer=item["metadata"]["solution_path"], entry=item) == 1.0 + answer = "".join(item["metadata"]["solution_path"]) + assert dataset.score_answer(answer=answer, entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 @@ -98,17 +99,17 @@ def test_quantumlock_scoring(): dataset = QuantumLockDataset(config) for item in dataset: - solution = item["metadata"]["solution_path"] + solution = item["answer"] # Test correct solution assert dataset.score_answer(solution, item) == 1.0 # Test empty/None answers assert dataset.score_answer(None, item) == 0.0 - assert dataset.score_answer("", item) == 0.1 + assert dataset.score_answer("", item) == 0.0 # Test invalid buttons - assert dataset.score_answer("XYZ", item) == 0.1 + assert dataset.score_answer("XYZ", item) == 0.0 # Test case insensitivity if solution: diff --git a/tests/test_ransom_note.py b/tests/test_ransom_note.py index f452ca2e..bf340424 100644 --- a/tests/test_ransom_note.py +++ b/tests/test_ransom_note.py @@ -86,7 +86,7 @@ def test_group_anagrams_dataset_items(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer="gibberish", entry=item) == 0.01 + assert dataset.score_answer(answer="gibberish", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 diff --git a/tests/test_rearc.py b/tests/test_rearc.py index 17de2241..a23d8000 100644 --- a/tests/test_rearc.py +++ b/tests/test_rearc.py @@ -80,7 +80,7 @@ def test_rearc_scoring_edge_cases(): assert 0.0 < dataset.score_answer(partial, entry=item) < 1.0 # Malformed answer - assert dataset.score_answer("[[invalid", entry=item) == 0.01 + assert dataset.score_answer("[[invalid", entry=item) == 0.0 # Case sensitivity answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower() diff --git a/tests/test_self_reference.py b/tests/test_self_reference.py index 66f15081..0c582641 100644 --- a/tests/test_self_reference.py +++ b/tests/test_self_reference.py @@ -18,8 +18,8 @@ def test_self_reference(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer=99, entry=item) == 0.1 - assert dataset.score_answer(answer="99", entry=item) == 0.1 + assert dataset.score_answer(answer=99, entry=item) == 0.0 + assert dataset.score_answer(answer="99", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 # # Medium @@ -34,8 +34,8 @@ def test_self_reference(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer=99, entry=item) == 0.1 - assert dataset.score_answer(answer="99", entry=item) == 0.1 + assert dataset.score_answer(answer=99, entry=item) == 0.0 + assert dataset.score_answer(answer="99", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 # # Hard @@ -50,6 +50,6 @@ def test_self_reference(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer=99, entry=item) == 0.1 - assert dataset.score_answer(answer="99", entry=item) == 0.1 + assert dataset.score_answer(answer=99, entry=item) == 0.0 + assert dataset.score_answer(answer="99", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 diff --git a/tests/test_shortest_path.py b/tests/test_shortest_path.py index f726b7cf..57950a1e 100644 --- a/tests/test_shortest_path.py +++ b/tests/test_shortest_path.py @@ -164,7 +164,7 @@ def test_shortest_path_answer(): ] }, } - assert dataset.score_answer("right down right down", entry) == 0.01 + assert dataset.score_answer("right down right down", entry) == 0.0 # Answer is None entry = { diff --git a/tests/test_simple_integration.py b/tests/test_simple_integration.py index 726b14be..994bf20f 100644 --- a/tests/test_simple_integration.py +++ b/tests/test_simple_integration.py @@ -94,16 +94,16 @@ def test_score_answer_cases(): ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0), ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0), # Incorrect but properly formatted - ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05), - ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05), + ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.0), + ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.0), # Malformed expressions - ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01), - ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01), + ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.0), + ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.0), # Empty answer - ("", {"variable": "x", "integrand": "2*x"}, 0.01), + ("", {"variable": "x", "integrand": "2*x"}, 0.0), # Case sensitivity - ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05), - ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05), + ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.0), + ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.0), # Alternative constant notation ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0), ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0), diff --git a/tests/test_sokoban.py b/tests/test_sokoban.py index c4d1e2b8..1e0e6502 100644 --- a/tests/test_sokoban.py +++ b/tests/test_sokoban.py @@ -6,6 +6,10 @@ def test_sokoban(): """Test basic properties and solution of generated items""" + dataset = SokobanDataset(SokobanConfig(size=10, seed=1234)) + for i, item in enumerate(dataset): + assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 + # Easy config = SokobanConfig(seed=42, size=20) dataset = SokobanDataset(config) @@ -18,11 +22,13 @@ def test_sokoban(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer="RU", entry=item) == 0.1 + assert dataset.score_answer(answer="RU", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 - # Medium - config = SokobanConfig(seed=42, min_h=40, max_h=50, min_w=40, max_w=50, min_boxes=20, max_boxes=30, size=3) + # Hard + config = SokobanConfig( + seed=42, min_h=15, max_h=20, min_w=15, max_w=20, min_boxes=10, max_boxes=15, size=3, max_depth=90 + ) dataset = SokobanDataset(config) for item in dataset: @@ -35,8 +41,10 @@ def test_sokoban(): assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 - # Hard - config = SokobanConfig(seed=42, min_h=400, max_h=500, min_w=400, max_w=500, min_boxes=50, max_boxes=50, size=1) + # min == max ranges + config = SokobanConfig( + seed=42, min_h=11, max_h=11, min_w=11, max_w=11, min_boxes=11, max_boxes=11, size=3, max_depth=60 + ) dataset = SokobanDataset(config) for item in dataset: diff --git a/tests/test_spiral_matrix.py b/tests/test_spiral_matrix.py index 333ecadd..9e5c510e 100644 --- a/tests/test_spiral_matrix.py +++ b/tests/test_spiral_matrix.py @@ -85,12 +85,12 @@ def test_spiral_matrix_answer(): # Score answer in list format (partially correct) entry = {"answer": "1 2 3 6 9 8 7 4 5"} answer = "[1, 2, 3, 6, 9, 8, 7, 4, 5]" - assert dataset.score_answer(answer, entry) == 0.5 + assert dataset.score_answer(answer, entry) == 0.1 # Answer is incorrect entry = {"answer": "1 2 3 6 9 8 7 4 5"} answer = "1 2 3" - assert dataset.score_answer(answer, entry) == 0.01 + assert dataset.score_answer(answer, entry) == 0.0 # Answer is none entry = {"answer": "1 2 3 6 9 8 7 4 5"} diff --git a/tests/test_string_insertion.py b/tests/test_string_insertion.py index 9d815b15..faff8d90 100644 --- a/tests/test_string_insertion.py +++ b/tests/test_string_insertion.py @@ -101,4 +101,4 @@ def test_string_insertion_answer(): # Test score_answer with correct answer as python list of characters (partial correct) answer = "['A', 'A', 'B', 'C', 'D', 'A', 'E', 'E', 'E', 'E', 'E', 'E', 'E', 'B', 'C', 'D', 'E', 'B', 'A', 'A', 'A', 'A', 'A']" entry = {"answer": "AABCDAEEEEEEEBCDEBAAAAA"} - assert dataset.score_answer(answer, entry) == 0.5 + assert dataset.score_answer(answer, entry) == 0.1 diff --git a/tests/test_tower_of_hanoi.py b/tests/test_tower_of_hanoi.py index 2d870078..01ae5d1d 100644 --- a/tests/test_tower_of_hanoi.py +++ b/tests/test_tower_of_hanoi.py @@ -78,7 +78,7 @@ def test_toh_dataset_items(): # Verify solution_length consistency assert solution_length == len( - item["answer"] + item["answer"].splitlines() ), f"Item {i} metadata 'solution_length' does not match actual number of moves." # Optional: Additional checks like verifying that start and target pegs are distinct @@ -94,8 +94,6 @@ def test_toh_move_validity(): num_disks = instance["metadata"]["num_disks"] num_pegs = instance["metadata"]["num_pegs"] start_peg = instance["metadata"]["start_peg"] - target_peg = instance["metadata"]["target_peg"] - auxiliary_pegs = instance["metadata"]["auxiliary_pegs"] pegs = list(range(1, num_pegs + 1)) # Initialize pegs_state: all disks start on the start peg @@ -104,7 +102,8 @@ def test_toh_move_validity(): pegs_state[start_peg].append(disk) # Iterate over each move and validate - for move_num, move in enumerate(instance["answer"], start=1): + moves = instance["answer"].splitlines() + for move_num, move in enumerate(moves, start=1): disk, from_peg, to_peg = parse_move(move) # Check that from_peg exists @@ -157,7 +156,7 @@ def test_toh_final_state_correct(): pegs_state[start_peg].append(disk) # Perform all moves - for move in instance["answer"]: + for move in instance["answer"].splitlines(): disk, from_peg, to_peg = parse_move(move) pegs_state[from_peg].pop() pegs_state[to_peg].append(disk) @@ -228,17 +227,15 @@ def is_valid_final_state(pegs_state: dict, target_peg: int, num_disks: int) -> b return target_stack == list(range(num_disks, 0, -1)) -def test_score_answer(): +def test_toh_score_answer(): """ Test that the score_answer method returns the expected reward values. Expected behavior: - Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0. - A correct solution that is suboptimal length gives a proportional reward of optimal_move_count/user_move_count - - A badly formatted answer gives a minimal reward (0.01). - An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05). - - An empty string gives 0.01. - - None gives 0.0. + - A badly formatted or empty answer gives a minimal reward (0.0). """ # Create a dataset instance using the default configuration. config = HanoiConfig(min_disks=3, max_disks=5, min_pegs=3, max_pegs=4, size=5, seed=42) @@ -253,17 +250,17 @@ def test_score_answer(): # 2. A badly formatted answer should yield minimal reward (0.01). score_bad_format = dataset.score_answer(answer="a wrong solution", entry=item) - assert score_bad_format == 0.01, f"Badly formatted answer score {score_bad_format} is not 0.01." + assert score_bad_format == 0.0, f"Badly formatted answer score {score_bad_format} is not 0.0" # 3. An answer that is validly formatted but unsolved. # For example, remove the last move from the correct answer. - unfinished_answer = correct_answer[:-1] + unfinished_answer = "\n".join(correct_answer.splitlines()[:-1]) score_unsolved = dataset.score_answer(answer=unfinished_answer, entry=item) assert score_unsolved == 0.05, f"Unsolved answer score {score_unsolved} is not 0.05." # 4. An empty answer should yield 0.01. score_empty = dataset.score_answer(answer="", entry=item) - assert score_empty == 0.01, f"Empty answer score {score_empty} is not 0.01." + assert score_empty == 0.0, f"Empty answer score {score_empty} is not 0.0." # 5. A None answer should yield 0.0. score_none = dataset.score_answer(answer=None, entry=item) diff --git a/tests/test_utils.py b/tests/test_utils.py index da3bafcd..f64544d9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -50,4 +50,4 @@ def test_compute_decimal_reward(): # Test invalid answers assert compute_decimal_reward(None, "42") == 0.0 assert compute_decimal_reward("", "42") == 0.0 - assert compute_decimal_reward("not a number", "42") == 0.01 + assert compute_decimal_reward("not a number", "42") == 0.0 diff --git a/tests/test_word_ladder.py b/tests/test_word_ladder.py index 1aba4cf3..738d7939 100644 --- a/tests/test_word_ladder.py +++ b/tests/test_word_ladder.py @@ -380,20 +380,20 @@ def test_word_ladder_score_answer(): assert dataset.score_answer("COLD", entry) == 0.0 # Test wrong start word - assert dataset.score_answer("BOLD,CORD,CARD,WARD,WARM", entry) == 0.01 + assert dataset.score_answer("BOLD,CORD,CARD,WARD,WARM", entry) == 0.0 # Test wrong end word - assert dataset.score_answer("COLD,CORD,CARD,WARD,WARP", entry) == 0.01 + assert dataset.score_answer("COLD,CORD,CARD,WARD,WARP", entry) == 0.0 # Test wrong word length - assert dataset.score_answer("COLD,CORDS,CARDS,WARD,WARM", entry) == 0.01 + assert dataset.score_answer("COLD,CORDS,CARDS,WARD,WARM", entry) == 0.0 # Test invalid transitions (more than one letter change) - assert dataset.score_answer("COLD,WARD,WARM", entry) == 0.01 + assert dataset.score_answer("COLD,WARD,WARM", entry) == 0.0 # Test case insensitivity assert dataset.score_answer("cold,cord,card,ward,warm", entry) == 1.0 # Test with unknown words (should return partial credit) - assert dataset.score_answer("COLD,COXD,CARD,WARD,WARM", entry) < 1.0 - assert dataset.score_answer("COLD,COXD,CARD,WARD,WARM", entry) > 0.0 + assert dataset.score_answer("COLD,COXD,CORD,CARD,WARD,WARM", entry) < 1.0 + assert dataset.score_answer("COLD,COXD,CORD,CARD,WARD,WARM", entry) > 0.0 diff --git a/tests/test_word_sorting.py b/tests/test_word_sorting.py index 8920e814..802c1472 100644 --- a/tests/test_word_sorting.py +++ b/tests/test_word_sorting.py @@ -102,7 +102,7 @@ def test_word_sorting_dataset_items(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer="gibberish", entry=item) == 0.01 + assert dataset.score_answer(answer="gibberish", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 @@ -143,7 +143,7 @@ def test_word_sorting_scoring(): # Garbage answer = "gibberish" - assert dataset.score_answer(answer, item) == 0.01 + assert dataset.score_answer(answer, item) == 0.0 # Empty answer answer = None