Skip to content

Commit

Permalink
Remove strip from ProceduralDataset::core score_answer() (#250)
Browse files Browse the repository at this point in the history
* remove strip from ProceduralDataset::core score_answer(), strip in extract answer (optional, default=True)
* test: Move test_extract_answer() from test_dataset.py to test_utils.py
* refactor: Improve decimal reward computation with more flexible comparison
* fix: Implement rounding for format_number when round_if_needed is True
* test: Add test case for compute_decimal_reward with sign and zeros
  • Loading branch information
andreaskoepf authored Mar 2, 2025
1 parent a66a7e7 commit 24828e1
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 26 deletions.
3 changes: 2 additions & 1 deletion reasoning_gym/arithmetic/chain_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max
return expression, result

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])
# tolerate sign, leading zeros and trailing decimals, strip commas "+01,000.00" == "1000"
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"], strip_commas=True)


class ChainSumCurriculum(BaseCurriculum):
Expand Down
3 changes: 2 additions & 1 deletion reasoning_gym/arithmetic/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max
return expression, result

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])
# tolerate sign, leading zeros and trailing decimals, strip commas "+01,000.00" == "1000"
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"], strip_commas=True)


class ProductsCurriculum(BaseCurriculum):
Expand Down
5 changes: 2 additions & 3 deletions reasoning_gym/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]:

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()
oracle_answer = entry["answer"]
reward = 0.0
if answer is not None and len(answer) > 0:
answer = answer.strip()
if isinstance(answer, str) and len(answer) > 0:
if answer == oracle_answer:
reward = 1.0
elif oracle_answer in answer:
Expand Down
34 changes: 20 additions & 14 deletions reasoning_gym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
}


def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
def extract_answer(completion: str, tag_name: str = "answer", strip: bool = True) -> Optional[str]:
regex = f"<{tag_name}>\\s?(.*?)\\s?</{tag_name}>"
matches = list(
re.finditer(
Expand All @@ -30,21 +30,25 @@ def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
)
if not matches:
return None
return matches[-1].group(1)
answer = matches[-1].group(1)
if strip:
answer = answer.strip()
return answer


def format_number(num: Union[int, float], max_decimals: int = 2) -> str:
def format_number(num: Union[int, float], max_decimals: int = 2, round_if_needed: bool = False) -> str:
"""Convert a number to string representation with controlled decimal places.
Args:
num: Number to format
max_decimals: Maximum allowed decimal places
round_if_needed: If True, round the number to max_decimals instead of raising an error
Returns:
String representation of the number
Raises:
ValueError: If number requires more decimal places than allowed
ValueError: If number requires more decimal places than allowed and round_if_needed is False
"""
if isinstance(num, int) or num.is_integer():
return str(int(num))
Expand All @@ -57,19 +61,20 @@ def format_number(num: Union[int, float], max_decimals: int = 2) -> str:
str_val = str_val.rstrip("0").rstrip(".")
if "." in str_val:
required_decimals = len(str_val.split(".")[1])
if required_decimals > max_decimals:
if required_decimals > max_decimals and not round_if_needed:
raise ValueError(f"Number {num} requires {required_decimals} decimals but only {max_decimals} allowed")

# Format with required decimals
# Format with required decimals (will round if needed)
result = f"{num:.{max_decimals}f}".rstrip("0").rstrip(".")

# Verify result parses back to original value
try:
parsed = float(result)
if not math.isclose(parsed, num, rel_tol=1e-9):
raise ValueError(f"String representation {result} does not match original value {num}")
except (ValueError, InvalidOperation) as e:
raise ValueError(f"Failed to verify string representation: {e}")
# Verify result parses back to original value (skip verification if rounding was applied)
if not round_if_needed:
try:
parsed = float(result)
if not math.isclose(parsed, num, rel_tol=1e-9):
raise ValueError(f"String representation {result} does not match original value {num}")
except (ValueError, InvalidOperation) as e:
raise ValueError(f"Failed to verify string representation: {e}")

return result

Expand All @@ -84,6 +89,7 @@ def is_integer(obj: Any) -> bool:

def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_commas: bool = True) -> float:
"""Compute the reward for a given answer compared to the oracle answer.
Tolerate sign, leading zeros and trailing decimals, optionally strip commas ("+01,000.00" == "1000")
Args:
answer: Answer provided by model
Expand All @@ -102,7 +108,7 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm
oracle_answer = oracle_answer.replace(",", "")

if Decimal(answer) == Decimal(oracle_answer):
reward = 1.0
return 1.0
except:
pass

Expand Down
8 changes: 1 addition & 7 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
from reasoning_gym.dataset import ReseedingDataset
from reasoning_gym.utils import extract_answer


def test_reseeding_dataset_iteration():
Expand Down Expand Up @@ -41,12 +40,7 @@ def test_reseeding_dataset_iteration():
assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0


def test_extract_answer():
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"

# ignore single whitespae
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer") == "1234"

def test_basic_arithmetic_score_answer():
config = BasicArithmeticDatasetConfig(
min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10
)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from reasoning_gym.utils import compute_decimal_reward, extract_answer, format_number


def test_extract_answer():
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"

# ignore whitespaces
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer", strip=True) == "1234"


def test_format_number():
# Test integers
assert format_number(42) == "42"
assert format_number(42.0) == "42"

# Test decimals
assert format_number(3.14) == "3.14"
assert format_number(3.10) == "3.1"
assert format_number(3.00) == "3"

# Test with max_decimals (rounding)
assert format_number(3.14159, max_decimals=4, round_if_needed=True) == "3.1416"

# Test with trailing zeros
assert format_number(5.5000) == "5.5"

# Test error cases
with pytest.raises(ValueError):
format_number(3.14159, max_decimals=2)


def test_compute_decimal_reward():
# Test exact matches
assert compute_decimal_reward("42", "42") == 1.0
assert compute_decimal_reward("3.14", "3.14") == 1.0

# Test with commas
assert compute_decimal_reward("1,000", "1000") == 1.0
assert compute_decimal_reward("1,000", "1000", strip_commas=False) < 1.0

# Test with sign, leading zeros, and trailing decimals
assert compute_decimal_reward("+0001,000.00", "1000") == 1.0

# Test partial matches
assert compute_decimal_reward("The answer is 42", "42") < 1.0
assert compute_decimal_reward("The answer is 42", "42") > 0.01

# Test invalid answers
assert compute_decimal_reward(None, "42") == 0.0
assert compute_decimal_reward("", "42") == 0.0
assert compute_decimal_reward("not a number", "42") == 0.01

0 comments on commit 24828e1

Please sign in to comment.