further reward simplifications
andreaskoepf committed Mar 4, 2025
1 parent acfd892 commit 9d0343e
Showing 13 changed files with 40 additions and 94 deletions.
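
Every hunk below makes the same move: the 0.01 consolation reward for malformed, empty, or None answers drops to 0.0, while tasks that keep partial credit still award 0.05 for well-formed but unsuccessful attempts. A minimal sketch of the resulting convention; the parse helper is a made-up stand-in, not code from the repository:

from typing import Optional

def parse(answer: str) -> str:
    # Stand-in for a task-specific parser; the real scorers raise on malformed input.
    if not answer.strip():
        raise ValueError("empty answer")
    return answer.strip()

def score_answer(answer: Optional[str], expected: str) -> float:
    if not isinstance(answer, str):
        return 0.0  # None or wrong type
    try:
        parsed = parse(answer)
    except Exception:
        return 0.0  # malformed or empty: was 0.01 before this commit
    return 1.0 if parsed == expected else 0.05  # well-formed but wrong keeps partial credit here

assert score_answer(None, "42") == 0.0
assert score_answer("   ", "42") == 0.0
assert score_answer("41", "42") == 0.05
assert score_answer("42", "42") == 1.0
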
reasoning_gym/arc/arc_agi.py (2 changes: 1 addition & 1 deletion)

@@ -199,7 +199,7 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
            else:
                reward = 0.05
        except:
-            reward = 0.01
+            reward = 0.0
        return reward

reasoning_gym/arc/rearc.py (2 changes: 1 addition & 1 deletion)

@@ -106,7 +106,7 @@ def score_answer(self, answer: str, entry: dict[str, Any]) -> float:
            else:
                reward = 0.05
        except:
-            reward = 0.01
+            reward = 0.0
        return reward

reasoning_gym/games/knight_swap.py (5 changes: 2 additions & 3 deletions)

@@ -314,8 +314,7 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float:
        - 1.0 for correct answer (either "No" for impossible puzzles or valid solution of optimal length)
        - A proportional score for correct but longer solutions
        - 0.05 for valid moves that don't solve the puzzle
-        - 0.01 for invalid format
-        - 0.0 for None
+        - 0.0 for invalid format or None
        """
        if not isinstance(answer, str):
            return 0.0
@@ -326,7 +325,7 @@ def score_answer(self, answer: Optional[str], entry: dict) -> float:

        # Handle impossible puzzles
        if not entry["metadata"]["is_possible"]:
-            return 1.0 if answer.lower() == "no" else 0.01
+            return 1.0 if answer.lower() == "no" else 0.0

        # Handle "No" answer for possible puzzles
        if answer.lower() == "no":

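The impossible-puzzle branch above is the new rule in miniature. Restated as a self-contained snippet (the entry layout comes from the hunk; the sample move string is invented):

def score_impossible(answer: str, entry: dict) -> float:
    # Mirrors the branch above: only an explicit "No" earns reward on impossible boards.
    assert not entry["metadata"]["is_possible"]
    return 1.0 if answer.lower() == "no" else 0.0

entry = {"metadata": {"is_possible": False}}
assert score_impossible("No", entry) == 1.0
assert score_impossible("w1,A1,B3", entry) == 0.0  # well-formed moves, wrong claim: was 0.01
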
reasoning_gym/games/tower_of_hanoi.py (15 changes: 4 additions & 11 deletions)

@@ -266,7 +266,7 @@ def __getitem__(self, idx: int) -> dict:
                start_peg=peg_labels[start_peg],
                target_peg=peg_labels[target_peg],
            ),
-            "answer": solution,
+            "answer": "\n".join(solution),
            "metadata": {
                "num_disks": num_disks,
                "num_pegs": num_pegs,
@@ -383,21 +383,14 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        Expected behavior:
        - Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0.
        - A correct solution that is suboptimal length gives a proportional reward of optimal_move_count/user_move_count
-        - A badly formatted answer gives a minimal reward (0.01).
        - An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05).
-        - An empty string gives 0.01.
-        - None gives 0.0.
+        - A badly formatted or empty answer gives 0.0.
        """
        if not isinstance(answer, str) or len(answer) == 0:
            return 0.0

-        # If answer is a string, split it into lines; if it's already a list, use it directly.
-        if isinstance(answer, str):
-            moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]
-        elif isinstance(answer, list):
-            moves = [line.strip() for line in answer if isinstance(line, str) and line.strip()]
-        else:
-            return 0.0
+        # Split the answer string into lines
+        moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]

        # Build the initial peg state from metadata.
        metadata = entry["metadata"]

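The __getitem__ change stores the solution as one newline-joined string, and score_answer now assumes exactly that encoding. The round trip, with invented move strings:

solution = ["Move disk 1 from Peg 1 to Peg 3", "Move disk 2 from Peg 1 to Peg 2"]

answer = "\n".join(solution)  # what the dataset item now stores under "answer"

# What score_answer recovers from it:
moves = [line.strip() for line in answer.strip().splitlines() if line.strip()]
assert moves == solution
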
reasoning_gym/graphs/quantum_lock.py (12 changes: 3 additions & 9 deletions)

@@ -172,18 +172,12 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
        if not isinstance(answer, str):
            return 0.0

-        # Get correct solution from metadata
-        correct_solution = entry["metadata"].get("solution_path", [])

-        # Normalize both answers
-        def normalize_seq(seq):
-            """Handle both string and list inputs by converting to string first"""
-            # Convert sequence to string representation if it's a list
-            input_str = "".join(seq) if isinstance(seq, list) else str(seq or "")
-            return [c.upper() for c in re.findall(r"[A-C]", input_str.upper())]
+        def normalize_seq(seq: str) -> list[str]:
+            return [c.upper() for c in re.findall(r"[A-C]", seq.upper())]

        user_sequence = normalize_seq(answer)
-        target_sequence = normalize_seq("".join(correct_solution))
+        target_sequence = normalize_seq(entry["answer"])

        # Exact sequence match required
        if user_sequence == target_sequence:

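With list inputs gone, normalize_seq reduces to a regex filter over the button alphabet. Its behavior in isolation, copied from the hunk above:

import re

def normalize_seq(seq: str) -> list[str]:
    # Keep only the A/B/C button presses, case-insensitively.
    return [c.upper() for c in re.findall(r"[A-C]", seq.upper())]

assert normalize_seq("a -> b -> c") == ["A", "B", "C"]
assert normalize_seq("XYZ") == []  # invalid buttons vanish, so the exact-match check above fails
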
tests/test_arc_agi.py (2 changes: 1 addition & 1 deletion)

@@ -110,7 +110,7 @@ def test_arc_agi_scoring():
    assert dataset.score_answer(item["answer"], entry=item) == 1.0

    # Test invalid format
-    assert dataset.score_answer("invalid grid format", entry=item) == 0.01
+    assert dataset.score_answer("invalid grid format", entry=item) == 0.0

    # Test None answer
    assert dataset.score_answer(None, entry=item) == 0.0

tests/test_needle_haystack.py (2 changes: 1 addition & 1 deletion)

@@ -16,7 +16,7 @@ def test_needle_haystack():

    # Test the scoring
    assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
-    assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.01
+    assert dataset.score_answer(answer="david bowie rules", entry=item) == 0.0
    assert dataset.score_answer(answer=None, entry=item) == 0.0

    config = NeedleHaystackConfig(seed=42, size=1, num_statements=500)

tests/test_polynomial_multiplication.py (48 changes: 5 additions & 43 deletions)

@@ -92,7 +92,6 @@ def test_polynomial_equations_dataset_items():

        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
-        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)

        # Check polynomial_expr existence
@@ -127,42 +126,6 @@ def test_cross_polynomial_equations_dataset_items():

        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
-        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)

        # Check polynomial_expr existence
        poly_str = item["metadata"]["polynomial_expr"]
        # Ensure it can parse with sympy
        sp.sympify(poly_str)


-def test_cross_polynomial_equations_dataset_items():
-    """Test that generated items have correct structure"""
-    ds = create_dataset(
-        "polynomial_multiplication",
-        min_terms=2,
-        max_terms=3,
-        min_value=1,
-        max_value=5,
-        min_degree=1,
-        max_degree=2,
-        min_polynomials=2,
-        max_polynomials=5,
-        variables=tuple("xyz"),
-        allow_cross_variable_product=True,
-        allow_multivariate_polynomials=False,
-        size=3,
-        seed=100,
-    )
-
-    for item in ds:
-        assert "question" in item
-        assert "answer" in item
-        assert "metadata" in item
-
-        # Check metadata
-        assert isinstance(item["metadata"]["polynomial_expr"], str)
-        assert isinstance(item["metadata"]["result"], str)
-        assert isinstance(item["metadata"]["variables"], list)

        # Check polynomial_expr existence
@@ -197,7 +160,6 @@ def test_multivariate_polynomial_equations_dataset_items():

        # Check metadata
        assert isinstance(item["metadata"]["polynomial_expr"], str)
-        assert isinstance(item["metadata"]["result"], str)
        assert isinstance(item["metadata"]["variables"], list)

        # Check polynomial_expr existence
@@ -242,7 +204,7 @@ def test_polynomial_solutions_evaluation():
        poly_expr = sp.expand(poly_str)

        # Verify that each solution satisfies the polynomial
-        assert poly_expr == item["answer"]
+        assert str(poly_expr) == item["answer"]


def test_score_function():
@@ -266,11 +228,11 @@

    for item in ds:
        poly_str = item["metadata"]["polynomial_expr"]
-        assert ds.score_answer(poly_str, item) == 0.05
+        assert ds.score_answer(poly_str, item) == 0.0

        poly_expr = str(sp.expand(poly_str))
        assert ds.score_answer(poly_expr, item) == 1.0

-        assert ds.score_answer(None, item) == 0.00
-        assert ds.score_answer("Not a polynomial", item) == 0.01
-        assert ds.score_answer("x**4", item) == 0.05
+        assert ds.score_answer(None, item) == 0.0
+        assert ds.score_answer("Not a polynomial", item) == 0.0
+        assert ds.score_answer("x**4", item) == 0.0

tests/test_pool_matrix.py (2 changes: 1 addition & 1 deletion)

@@ -143,7 +143,7 @@ def test_pool_matrix_score_answer():
    dataset = PoolMatrixDataset(config)
    for entry in dataset:
        assert dataset.score_answer(entry["answer"], entry=entry) == 1
-        assert dataset.score_answer("1 2.0\n3.0 4", entry=entry) == 0.0
+        assert dataset.score_answer("1 2.0\n3.0 4", entry=entry) in [0.0, 0.1]
        assert dataset.score_answer("one two three", entry=entry) == 0.0
        assert dataset.score_answer("", entry=entry) == 0.0
        assert dataset.score_answer(None, entry=entry) == 0.0

tests/test_quantum_lock.py (9 changes: 5 additions & 4 deletions)

@@ -43,7 +43,8 @@ def test_quantumlock_items():
        assert "target_value" in item["metadata"]

        # Verify solution works
-        assert dataset.score_answer(answer=item["metadata"]["solution_path"], entry=item) == 1.0
+        answer = "".join(item["metadata"]["solution_path"])
+        assert dataset.score_answer(answer=answer, entry=item) == 1.0
        assert dataset.score_answer(answer=None, entry=item) == 0.0


@@ -98,17 +99,17 @@ def test_quantumlock_scoring():
    dataset = QuantumLockDataset(config)

    for item in dataset:
-        solution = item["metadata"]["solution_path"]
+        solution = item["answer"]

        # Test correct solution
        assert dataset.score_answer(solution, item) == 1.0

        # Test empty/None answers
        assert dataset.score_answer(None, item) == 0.0
-        assert dataset.score_answer("", item) == 0.1
+        assert dataset.score_answer("", item) == 0.0

        # Test invalid buttons
-        assert dataset.score_answer("XYZ", item) == 0.1
+        assert dataset.score_answer("XYZ", item) == 0.0

        # Test case insensitivity
        if solution:

tests/test_rearc.py (2 changes: 1 addition & 1 deletion)

@@ -80,7 +80,7 @@ def test_rearc_scoring_edge_cases():
    assert 0.0 < dataset.score_answer(partial, entry=item) < 1.0

    # Malformed answer
-    assert dataset.score_answer("[[invalid", entry=item) == 0.01
+    assert dataset.score_answer("[[invalid", entry=item) == 0.0

    # Case sensitivity
    answer = format_board(item["metadata"]["output"], dataset.board_format_opts).lower()

tests/test_simple_integration.py (14 changes: 7 additions & 7 deletions)

@@ -94,16 +94,16 @@ def test_score_answer_cases():
        ("x**2", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("log(x)", {"variable": "x", "integrand": "1/x"}, 1.0),
        # Incorrect but properly formatted
-        ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
-        ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.05),
+        ("x**3 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
+        ("cos(X)", {"variable": "X", "integrand": "sin(X)"}, 0.0),
        # Malformed expressions
-        ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.01),
-        ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.01),
+        ("x**2 +", {"variable": "x", "integrand": "2*x"}, 0.0),
+        ("sin(x", {"variable": "x", "integrand": "cos(x)"}, 0.0),
        # Empty answer
-        ("", {"variable": "x", "integrand": "2*x"}, 0.01),
+        ("", {"variable": "x", "integrand": "2*x"}, 0.0),
        # Case sensitivity
-        ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.05),
-        ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.05),
+        ("x**2 + C", {"variable": "X", "integrand": "2*X"}, 0.0),
+        ("X**2 + C", {"variable": "x", "integrand": "2*x"}, 0.0),
        # Alternative constant notation
        ("x**2 + K", {"variable": "x", "integrand": "2*x"}, 1.0),
        ("sin(x) + D", {"variable": "x", "integrand": "cos(x)"}, 1.0),

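The updated table is consistent with a checker that differentiates the candidate and compares against the integrand, scoring any parse failure or mismatch as 0.0. A plausible sketch only; the real scorer's internals are not part of this diff:

import sympy as sp

def check_integral(answer: str, variable: str, integrand: str) -> float:
    # Constants of integration (C, K, D, ...) parse as free symbols and
    # vanish under differentiation, so they need no special-casing.
    x = sp.Symbol(variable)
    try:
        expr = sp.sympify(answer)
    except (sp.SympifyError, TypeError):
        return 0.0  # malformed or empty: 0.0 after this commit, previously 0.01
    return 1.0 if sp.simplify(sp.diff(expr, x) - sp.sympify(integrand)) == 0 else 0.0

assert check_integral("x**2 + C", "x", "2*x") == 1.0  # constant notation accepted
assert check_integral("x**2 +", "x", "2*x") == 0.0    # malformed
assert check_integral("X**2 + C", "x", "2*x") == 0.0  # wrong-case variable is a different symbol
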
tests/test_tower_of_hanoi.py (19 changes: 8 additions & 11 deletions)

@@ -78,7 +78,7 @@ def test_toh_dataset_items():

        # Verify solution_length consistency
        assert solution_length == len(
-            item["answer"]
+            item["answer"].splitlines()
        ), f"Item {i} metadata 'solution_length' does not match actual number of moves."

        # Optional: Additional checks like verifying that start and target pegs are distinct
@@ -94,8 +94,6 @@ def test_toh_move_validity():
        num_disks = instance["metadata"]["num_disks"]
        num_pegs = instance["metadata"]["num_pegs"]
        start_peg = instance["metadata"]["start_peg"]
-        target_peg = instance["metadata"]["target_peg"]
-        auxiliary_pegs = instance["metadata"]["auxiliary_pegs"]
        pegs = list(range(1, num_pegs + 1))

        # Initialize pegs_state: all disks start on the start peg
@@ -104,7 +102,8 @@ def test_toh_move_validity():
            pegs_state[start_peg].append(disk)

        # Iterate over each move and validate
-        for move_num, move in enumerate(instance["answer"], start=1):
+        moves = instance["answer"].splitlines()
+        for move_num, move in enumerate(moves, start=1):
            disk, from_peg, to_peg = parse_move(move)

            # Check that from_peg exists
@@ -157,7 +156,7 @@ def test_toh_final_state_correct():
            pegs_state[start_peg].append(disk)

        # Perform all moves
-        for move in instance["answer"]:
+        for move in instance["answer"].splitlines():
            disk, from_peg, to_peg = parse_move(move)
            pegs_state[from_peg].pop()
            pegs_state[to_peg].append(disk)
@@ -235,10 +234,8 @@ def test_toh_score_answer():
    Expected behavior:
    - Correct answer (i.e. equivalent in length, or better, than the one provided in the dataset item) gives 1.0.
    - A correct solution that is suboptimal length gives a proportional reward of optimal_move_count/user_move_count
-    - A badly formatted answer gives a minimal reward (0.01).
    - An answer that is syntactically valid but does not solve the puzzle gives a partial reward (0.05).
-    - An empty string gives 0.01.
-    - None gives 0.0.
+    - A badly formatted or empty answer gives a minimal reward (0.0).
    """
    # Create a dataset instance using the default configuration.
    config = HanoiConfig(min_disks=3, max_disks=5, min_pegs=3, max_pegs=4, size=5, seed=42)
@@ -253,17 +250,17 @@ def test_toh_score_answer():

    # 2. A badly formatted answer should yield minimal reward (0.0).
    score_bad_format = dataset.score_answer(answer="a wrong solution", entry=item)
-    assert score_bad_format == 0.01, f"Badly formatted answer score {score_bad_format} is not 0.01."
+    assert score_bad_format == 0.0, f"Badly formatted answer score {score_bad_format} is not 0.0"

    # 3. An answer that is validly formatted but unsolved.
    # For example, remove the last move from the correct answer.
-    unfinished_answer = correct_answer[:-1]
+    unfinished_answer = "\n".join(correct_answer.splitlines()[:-1])
    score_unsolved = dataset.score_answer(answer=unfinished_answer, entry=item)
    assert score_unsolved == 0.05, f"Unsolved answer score {score_unsolved} is not 0.05."

    # 4. An empty answer should yield 0.0.
    score_empty = dataset.score_answer(answer="", entry=item)
-    assert score_empty == 0.01, f"Empty answer score {score_empty} is not 0.01."
+    assert score_empty == 0.0, f"Empty answer score {score_empty} is not 0.0."

    # 5. A None answer should yield 0.0.
    score_none = dataset.score_answer(answer=None, entry=item)

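Both move-replay tests depend on a parse_move helper that sits outside this diff. A hypothetical version, assuming moves read like "Move disk N from Peg X to Peg Y" (the exact phrasing is an assumption, not taken from the repository):

import re

def parse_move(move: str) -> tuple[int, int, int]:
    # Hypothetical parser for one move line; returns (disk, from_peg, to_peg).
    m = re.match(r"Move disk (\d+) from Peg (\d+) to Peg (\d+)", move)
    if m is None:
        raise ValueError(f"unrecognized move: {move!r}")
    disk, from_peg, to_peg = (int(g) for g in m.groups())
    return disk, from_peg, to_peg

assert parse_move("Move disk 2 from Peg 1 to Peg 3") == (2, 1, 3)
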
