Skip to content

Commit

Permalink
test: Add test case for compute_decimal_reward with sign and zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Mar 2, 2025
1 parent 4427053 commit 06c3434
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
4 changes: 4 additions & 0 deletions reasoning_gym/arithmetic/chain_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max
expression = " ".join(expression_parts)
return expression, result

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
# tolerate sign, leading zeros and trailing decimals, strip commas "+01,000.00" == "1000"
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"], strip_commas=True)


class ChainSumCurriculum(BaseCurriculum):
def __init__(self):
Expand Down
3 changes: 2 additions & 1 deletion reasoning_gym/arithmetic/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max
return expression, result

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])
# tolerate sign, leading zeros and trailing decimals, strip commas "+01,000.00" == "1000"
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"], strip_commas=True)


class ProductsCurriculum(BaseCurriculum):
Expand Down
2 changes: 1 addition & 1 deletion reasoning_gym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm
oracle_answer = oracle_answer.replace(",", "")

if Decimal(answer) == Decimal(oracle_answer):
reward = 1.0
return 1.0
except:
pass

Expand Down
3 changes: 3 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def test_compute_decimal_reward():
# Test with commas
assert compute_decimal_reward("1,000", "1000") == 1.0
assert compute_decimal_reward("1,000", "1000", strip_commas=False) < 1.0

# Test with sign, leading zeros, and trailing decimals
assert compute_decimal_reward("+0001,000.00", "1000") == 1.0

# Test partial matches
assert compute_decimal_reward("The answer is 42", "42") < 1.0
Expand Down

0 comments on commit 06c3434

Please sign in to comment.