From 9d0343ef2a40dc0014456ef3348eb9f24460ee4c Mon Sep 17 00:00:00 2001 From: Andreas Koepf Date: Tue, 4 Mar 2025 19:48:59 +0100 Subject: [PATCH] further reward simplifications --- reasoning_gym/arc/arc_agi.py | 2 +- reasoning_gym/arc/rearc.py | 2 +- reasoning_gym/games/knight_swap.py | 5 ++- reasoning_gym/games/tower_of_hanoi.py | 15 +++----- reasoning_gym/graphs/quantum_lock.py | 12 ++----- tests/test_arc_agi.py | 2 +- tests/test_needle_haystack.py | 2 +- tests/test_polynomial_multiplication.py | 48 +++---------------------- tests/test_pool_matrix.py | 2 +- tests/test_quantum_lock.py | 9 ++--- tests/test_rearc.py | 2 +- tests/test_simple_integration.py | 14 ++++---- tests/test_tower_of_hanoi.py | 19 +++++----- 13 files changed, 40 insertions(+), 94 deletions(-) diff --git a/reasoning_gym/arc/arc_agi.py b/reasoning_gym/arc/arc_agi.py index 98c3f000..a19a6507 100644 --- a/reasoning_gym/arc/arc_agi.py +++ b/reasoning_gym/arc/arc_agi.py @@ -199,7 +199,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: else: reward = 0.05 except: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/arc/rearc.py b/reasoning_gym/arc/rearc.py index fe4a9aa5..dabbe808 100644 --- a/reasoning_gym/arc/rearc.py +++ b/reasoning_gym/arc/rearc.py @@ -106,7 +106,7 @@ def score_answer(self, answer: str, entry: dict[str, Any]) -> float: else: reward = 0.05 except: - reward = 0.01 + reward = 0.0 return reward diff --git a/reasoning_gym/games/knight_swap.py b/reasoning_gym/games/knight_swap.py index 2a33886e..2a54289d 100644 --- a/reasoning_gym/games/knight_swap.py +++ b/reasoning_gym/games/knight_swap.py @@ -314,8 +314,7 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float: - 1.0 for correct answer (either "No" for impossible puzzles or valid solution of optimal length) - A proportional score for correct but longer solutions - 0.05 for valid moves that don't solve the puzzle - - 0.01 for invalid format - - 0.0 for None + - 0.0 for 
invalid format or None """ if not isinstance(answer, str): return 0.0 @@ -326,7 +325,7 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float: # Handle impossible puzzles if not entry["metadata"]["is_possible"]: - return 1.0 if answer.lower() == "no" else 0.01 + return 1.0 if answer.lower() == "no" else 0.0 # Handle "No" answer for possible puzzles if answer.lower() == "no": diff --git a/reasoning_gym/games/tower_of_hanoi.py b/reasoning_gym/games/tower_of_hanoi.py index 66462f61..9ab50cb8 100644 --- a/reasoning_gym/games/tower_of_hanoi.py +++ b/reasoning_gym/games/tower_of_hanoi.py @@ -266,7 +266,7 @@ def __getitem__(self, idx: int) -> dict: start_peg=peg_labels[start_peg], target_peg=peg_labels[target_peg], ), - "answer": solution, + "answer": "\n".join(solution), "metadata": { "num_disks": num_disks, "num_pegs": num_pegs, @@ -383,21 +383,14 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: Expected behavior: - Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0. - A correct solution that is suboptimal length gives a proportional reward of optimal_move_count/user_move_count - - A badly formatted answer gives a minimal reward (0.01). - An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05). - - An empty string gives 0.01. - - None gives 0.0. + - A badly formatted or empty answer gives 0.0 """ if not isinstance(answer, str) or len(answer) == 0: return 0.0 - # If answer is a string, split it into lines; if it's already a list, use it directly. 
- if isinstance(answer, str): - moves = [line.strip() for line in answer.strip().splitlines() if line.strip()] - elif isinstance(answer, list): - moves = [line.strip() for line in answer if isinstance(line, str) and line.strip()] - else: - return 0.0 + # Split the answer string into lines + moves = [line.strip() for line in answer.strip().splitlines() if line.strip()] # Build the initial peg state from metadata. metadata = entry["metadata"] diff --git a/reasoning_gym/graphs/quantum_lock.py b/reasoning_gym/graphs/quantum_lock.py index 9d0f276b..e2196906 100644 --- a/reasoning_gym/graphs/quantum_lock.py +++ b/reasoning_gym/graphs/quantum_lock.py @@ -172,18 +172,12 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: if not isinstance(answer, str): return 0.0 - # Get correct solution from metadata - correct_solution = entry["metadata"].get("solution_path", []) - # Normalize both answers - def normalize_seq(seq): - """Handle both string and list inputs by converting to string first""" - # Convert sequence to string representation if it's a list - input_str = "".join(seq) if isinstance(seq, list) else str(seq or "") - return [c.upper() for c in re.findall(r"[A-C]", input_str.upper())] + def normalize_seq(seq: str) -> list[str]: + return [c.upper() for c in re.findall(r"[A-C]", seq.upper())] user_sequence = normalize_seq(answer) - target_sequence = normalize_seq("".join(correct_solution)) + target_sequence = normalize_seq(entry["answer"]) # Exact sequence match required if user_sequence == target_sequence: diff --git a/tests/test_arc_agi.py b/tests/test_arc_agi.py index cfeecf21..ca7a92e3 100644 --- a/tests/test_arc_agi.py +++ b/tests/test_arc_agi.py @@ -110,7 +110,7 @@ def test_arc_agi_scoring(): assert dataset.score_answer(item["answer"], entry=item) == 1.0 # Test invalid format - assert dataset.score_answer("invalid grid format", entry=item) == 0.01 + assert dataset.score_answer("invalid grid format", entry=item) == 0.0 # Test None answer 
assert dataset.score_answer(None, entry=item) == 0.0 diff --git a/tests/test_needle_haystack.py b/tests/test_needle_haystack.py index fb279da2..672d16b3 100644 --- a/tests/test_needle_haystack.py +++ b/tests/test_needle_haystack.py @@ -16,7 +16,7 @@ def test_needle_haystack(): # Test the scoring assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0 - assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.01 + assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 config = NeedleHaystackConfig(seed=42, size=1, num_statements=500) diff --git a/tests/test_polynomial_multiplication.py b/tests/test_polynomial_multiplication.py index 10b404d9..0ddf334a 100644 --- a/tests/test_polynomial_multiplication.py +++ b/tests/test_polynomial_multiplication.py @@ -92,7 +92,6 @@ def test_polynomial_equations_dataset_items(): # Check metadata assert isinstance(item["metadata"]["polynomial_expr"], str) - assert isinstance(item["metadata"]["result"], str) assert isinstance(item["metadata"]["variables"], list) # Check polynomial_expr existence @@ -127,42 +126,6 @@ def test_cross_polynomial_equations_dataset_items(): # Check metadata assert isinstance(item["metadata"]["polynomial_expr"], str) - assert isinstance(item["metadata"]["result"], str) - assert isinstance(item["metadata"]["variables"], list) - - # Check polynomial_expr existence - poly_str = item["metadata"]["polynomial_expr"] - # Ensure it can parse with sympy - sp.sympify(poly_str) - - -def test_cross_polynomial_equations_dataset_items(): - """Test that generated items have correct structure""" - ds = create_dataset( - "polynomial_multiplication", - min_terms=2, - max_terms=3, - min_value=1, - max_value=5, - min_degree=1, - max_degree=2, - min_polynomials=2, - max_polynomials=5, - variables=tuple("xyz"), - allow_cross_variable_product=True, - allow_multivariate_polynomials=False, - size=3, - seed=100, - ) - - for 
item in ds: - assert "question" in item - assert "answer" in item - assert "metadata" in item - - # Check metadata - assert isinstance(item["metadata"]["polynomial_expr"], str) - assert isinstance(item["metadata"]["result"], str) assert isinstance(item["metadata"]["variables"], list) # Check polynomial_expr existence @@ -197,7 +160,6 @@ def test_multivariate_polynomial_equations_dataset_items(): # Check metadata assert isinstance(item["metadata"]["polynomial_expr"], str) - assert isinstance(item["metadata"]["result"], str) assert isinstance(item["metadata"]["variables"], list) # Check polynomial_expr existence @@ -242,7 +204,7 @@ def test_polynomial_solutions_evaluation(): poly_expr = sp.expand(poly_str) # Verify that each solution satisfies the polynomial - assert poly_expr == item["answer"] + assert str(poly_expr) == item["answer"] def test_score_function(): @@ -266,11 +228,11 @@ def test_score_function(): for item in ds: poly_str = item["metadata"]["polynomial_expr"] - assert ds.score_answer(poly_str, item) == 0.05 + assert ds.score_answer(poly_str, item) == 0.0 poly_expr = str(sp.expand(poly_str)) assert ds.score_answer(poly_expr, item) == 1.0 - assert ds.score_answer(None, item) == 0.00 - assert ds.score_answer("Not a polynomial", item) == 0.01 - assert ds.score_answer("x**4", item) == 0.05 + assert ds.score_answer(None, item) == 0.0 + assert ds.score_answer("Not a polynomial", item) == 0.0 + assert ds.score_answer("x**4", item) == 0.0 diff --git a/tests/test_pool_matrix.py b/tests/test_pool_matrix.py index b27282b2..a529c05f 100644 --- a/tests/test_pool_matrix.py +++ b/tests/test_pool_matrix.py @@ -143,7 +143,7 @@ def test_pool_matrix_score_answer(): dataset = PoolMatrixDataset(config) for entry in dataset: assert dataset.score_answer(entry["answer"], entry=entry) == 1 - assert dataset.score_answer("1 2.0\n3.0 4", entry=entry) == 0.0 + assert dataset.score_answer("1 2.0\n3.0 4", entry=entry) in [0.0, 0.1] assert dataset.score_answer("one two three", 
entry=entry) == 0.0 assert dataset.score_answer("", entry=entry) == 0.0 assert dataset.score_answer(None, entry=entry) == 0.0 diff --git a/tests/test_quantum_lock.py b/tests/test_quantum_lock.py index 69162750..cf9693e6 100644 --- a/tests/test_quantum_lock.py +++ b/tests/test_quantum_lock.py @@ -43,7 +43,8 @@ def test_quantumlock_items(): assert "target_value" in item["metadata"] # Verify solution works - assert dataset.score_answer(answer=item["metadata"]["solution_path"], entry=item) == 1.0 + answer = "".join(item["metadata"]["solution_path"]) + assert dataset.score_answer(answer=answer, entry=item) == 1.0 assert dataset.score_answer(answer=None, entry=item) == 0.0 @@ -98,17 +99,17 @@ def test_quantumlock_scoring(): dataset = QuantumLockDataset(config) for item in dataset: - solution = item["metadata"]["solution_path"] + solution = item["answer"] # Test correct solution assert dataset.score_answer(solution, item) == 1.0 # Test empty/None answers assert dataset.score_answer(None, item) == 0.0 - assert dataset.score_answer("", item) == 0.1 + assert dataset.score_answer("", item) == 0.0 # Test invalid buttons - assert dataset.score_answer("XYZ", item) == 0.1 + assert dataset.score_answer("XYZ", item) == 0.0 # Test case insensitivity if solution: diff --git a/tests/test_rearc.py b/tests/test_rearc.py index 17de2241..a23d8000 100644 --- a/tests/test_rearc.py +++ b/tests/test_rearc.py @@ -80,7 +80,7 @@ def test_rearc_scoring_edge_cases(): assert 0.0 < dataset.score_answer(partial, entry=item) < 1.0 # Malformed answer - assert dataset.score_answer("[[invalid", entry=item) == 0.01 + assert dataset.score_answer("[[invalid", entry=item) == 0.0 # Case sensitivity answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower() diff --git a/tests/test_simple_integration.py b/tests/test_simple_integration.py index 726b14be..994bf20f 100644 --- a/tests/test_simple_integration.py +++ b/tests/test_simple_integration.py @@ -94,16 +94,16 @@ def 
test_score_answer_cases(): ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0), ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0), # Incorrect but properly formatted - ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05), - ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05), + ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.0), + ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.0), # Malformed expressions - ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01), - ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01), + ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.0), + ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.0), # Empty answer - ("", {"variable": "x", "integrand": "2*x"}, 0.01), + ("", {"variable": "x", "integrand": "2*x"}, 0.0), # Case sensitivity - ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05), - ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05), + ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.0), + ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.0), # Alternative constant notation ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0), ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0), diff --git a/tests/test_tower_of_hanoi.py b/tests/test_tower_of_hanoi.py index a7b5add3..01ae5d1d 100644 --- a/tests/test_tower_of_hanoi.py +++ b/tests/test_tower_of_hanoi.py @@ -78,7 +78,7 @@ def test_toh_dataset_items(): # Verify solution_length consistency assert solution_length == len( - item["answer"] + item["answer"].splitlines() ), f"Item {i} metadata 'solution_length' does not match actual number of moves." 
# Optional: Additional checks like verifying that start and target pegs are distinct @@ -94,8 +94,6 @@ def test_toh_move_validity(): num_disks = instance["metadata"]["num_disks"] num_pegs = instance["metadata"]["num_pegs"] start_peg = instance["metadata"]["start_peg"] - target_peg = instance["metadata"]["target_peg"] - auxiliary_pegs = instance["metadata"]["auxiliary_pegs"] pegs = list(range(1, num_pegs + 1)) # Initialize pegs_state: all disks start on the start peg @@ -104,7 +102,8 @@ def test_toh_move_validity(): pegs_state[start_peg].append(disk) # Iterate over each move and validate - for move_num, move in enumerate(instance["answer"], start=1): + moves = instance["answer"].splitlines() + for move_num, move in enumerate(moves, start=1): disk, from_peg, to_peg = parse_move(move) # Check that from_peg exists @@ -157,7 +156,7 @@ def test_toh_final_state_correct(): pegs_state[start_peg].append(disk) # Perform all moves - for move in instance["answer"]: + for move in instance["answer"].splitlines(): disk, from_peg, to_peg = parse_move(move) pegs_state[from_peg].pop() pegs_state[to_peg].append(disk) @@ -235,10 +234,8 @@ def test_toh_score_answer(): Expected behavior: - Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0. - A correct solution that is suboptimal length gives a proportional reward of optimal_move_count/user_move_count - - A badly formatted answer gives a minimal reward (0.01). - An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05). - - An empty string gives 0.01. - - None gives 0.0. + - A badly formatted or empty answer gives a minimal reward (0.0). """ # Create a dataset instance using the default configuration. config = HanoiConfig(min_disks=3, max_disks=5, min_pegs=3, max_pegs=4, size=5, seed=42) @@ -253,17 +250,17 @@ def test_toh_score_answer(): # 2. A badly formatted answer should yield minimal reward (0.01). 
score_bad_format = dataset.score_answer(answer="a wrong solution", entry=item) - assert score_bad_format == 0.01, f"Badly formatted answer score {score_bad_format} is not 0.01." + assert score_bad_format == 0.0, f"Badly formatted answer score {score_bad_format} is not 0.0" # 3. An answer that is validly formatted but unsolved. # For example, remove the last move from the correct answer. - unfinished_answer = correct_answer[:-1] + unfinished_answer = "\n".join(correct_answer.splitlines()[:-1]) score_unsolved = dataset.score_answer(answer=unfinished_answer, entry=item) assert score_unsolved == 0.05, f"Unsolved answer score {score_unsolved} is not 0.05." # 4. An empty answer should yield 0.01. score_empty = dataset.score_answer(answer="", entry=item) - assert score_empty == 0.01, f"Empty answer score {score_empty} is not 0.01." + assert score_empty == 0.0, f"Empty answer score {score_empty} is not 0.0." # 5. A None answer should yield 0.0. score_none = dataset.score_answer(answer=None, entry=item)