Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove strip from ProceduralDataset::core score_answer() #250

Merged
merged 9 commits into from
Mar 2, 2025
3 changes: 2 additions & 1 deletion reasoning_gym/arithmetic/chain_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max
return expression, result

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])
# tolerate sign, leading zeros and trailing decimals, strip commas "+01,000.00" == "1000"
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"], strip_commas=True)


class ChainSumCurriculum(BaseCurriculum):
Expand Down
3 changes: 2 additions & 1 deletion reasoning_gym/arithmetic/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max
return expression, result

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])
# tolerate sign, leading zeros and trailing decimals, strip commas "+01,000.00" == "1000"
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"], strip_commas=True)


class ProductsCurriculum(BaseCurriculum):
Expand Down
5 changes: 2 additions & 3 deletions reasoning_gym/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]:

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()
oracle_answer = entry["answer"]
reward = 0.0
if answer is not None and len(answer) > 0:
answer = answer.strip()
if isinstance(answer, str) and len(answer) > 0:
if answer == oracle_answer:
reward = 1.0
elif oracle_answer in answer:
Expand Down
34 changes: 20 additions & 14 deletions reasoning_gym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
}


def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
def extract_answer(completion: str, tag_name: str = "answer", strip: bool = True) -> Optional[str]:
regex = f"<{tag_name}>\\s?(.*?)\\s?</{tag_name}>"
matches = list(
re.finditer(
Expand All @@ -30,21 +30,25 @@ def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
)
if not matches:
return None
return matches[-1].group(1)
answer = matches[-1].group(1)
if strip:
answer = answer.strip()
return answer


def format_number(num: Union[int, float], max_decimals: int = 2) -> str:
def format_number(num: Union[int, float], max_decimals: int = 2, round_if_needed: bool = False) -> str:
"""Convert a number to string representation with controlled decimal places.

Args:
num: Number to format
max_decimals: Maximum allowed decimal places
round_if_needed: If True, round the number to max_decimals instead of raising an error

Returns:
String representation of the number

Raises:
ValueError: If number requires more decimal places than allowed
ValueError: If number requires more decimal places than allowed and round_if_needed is False
"""
if isinstance(num, int) or num.is_integer():
return str(int(num))
Expand All @@ -57,19 +61,20 @@ def format_number(num: Union[int, float], max_decimals: int = 2) -> str:
str_val = str_val.rstrip("0").rstrip(".")
if "." in str_val:
required_decimals = len(str_val.split(".")[1])
if required_decimals > max_decimals:
if required_decimals > max_decimals and not round_if_needed:
raise ValueError(f"Number {num} requires {required_decimals} decimals but only {max_decimals} allowed")

# Format with required decimals
# Format with required decimals (will round if needed)
result = f"{num:.{max_decimals}f}".rstrip("0").rstrip(".")

# Verify result parses back to original value
try:
parsed = float(result)
if not math.isclose(parsed, num, rel_tol=1e-9):
raise ValueError(f"String representation {result} does not match original value {num}")
except (ValueError, InvalidOperation) as e:
raise ValueError(f"Failed to verify string representation: {e}")
# Verify result parses back to original value (skip verification if rounding was applied)
if not round_if_needed:
try:
parsed = float(result)
if not math.isclose(parsed, num, rel_tol=1e-9):
raise ValueError(f"String representation {result} does not match original value {num}")
except (ValueError, InvalidOperation) as e:
raise ValueError(f"Failed to verify string representation: {e}")

return result

Expand All @@ -84,6 +89,7 @@ def is_integer(obj: Any) -> bool:

def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_commas: bool = True) -> float:
"""Compute the reward for a given answer compared to the oracle answer.
Tolerate sign, leading zeros and trailing decimals, optionally strip commas ("+01,000.00" == "1000")

Args:
answer: Answer provided by model
Expand All @@ -102,7 +108,7 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm
oracle_answer = oracle_answer.replace(",", "")

if Decimal(answer) == Decimal(oracle_answer):
reward = 1.0
return 1.0
except:
pass

Expand Down
8 changes: 1 addition & 7 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig
from reasoning_gym.dataset import ReseedingDataset
from reasoning_gym.utils import extract_answer


def test_reseeding_dataset_iteration():
Expand Down Expand Up @@ -41,12 +40,7 @@ def test_reseeding_dataset_iteration():
assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0


def test_extract_answer():
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"

# ignore single whitespae
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer") == "1234"

def test_basic_arithmetic_score_answer():
config = BasicArithmeticDatasetConfig(
min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10
)
Expand Down
53 changes: 53 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from reasoning_gym.utils import compute_decimal_reward, extract_answer, format_number


def test_extract_answer():
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"

# ignore whitespaces
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer", strip=True) == "1234"


def test_format_number():
# Test integers
assert format_number(42) == "42"
assert format_number(42.0) == "42"

# Test decimals
assert format_number(3.14) == "3.14"
assert format_number(3.10) == "3.1"
assert format_number(3.00) == "3"

# Test with max_decimals (rounding)
assert format_number(3.14159, max_decimals=4, round_if_needed=True) == "3.1416"

# Test with trailing zeros
assert format_number(5.5000) == "5.5"

# Test error cases
with pytest.raises(ValueError):
format_number(3.14159, max_decimals=2)


def test_compute_decimal_reward():
# Test exact matches
assert compute_decimal_reward("42", "42") == 1.0
assert compute_decimal_reward("3.14", "3.14") == 1.0

# Test with commas
assert compute_decimal_reward("1,000", "1000") == 1.0
assert compute_decimal_reward("1,000", "1000", strip_commas=False) < 1.0

# Test with sign, leading zeros, and trailing decimals
assert compute_decimal_reward("+0001,000.00", "1000") == 1.0

# Test partial matches
assert compute_decimal_reward("The answer is 42", "42") < 1.0
assert compute_decimal_reward("The answer is 42", "42") > 0.01

# Test invalid answers
assert compute_decimal_reward(None, "42") == 0.0
assert compute_decimal_reward("", "42") == 0.0
assert compute_decimal_reward("not a number", "42") == 0.01