diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 2072983c..351b258f 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -111,7 +111,8 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max return expression, result def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"]) + # tolerate sign, leading zeros and trailing decimals, strip commas "+01,000.00" == "1000" + return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"], strip_commas=True) class ChainSumCurriculum(BaseCurriculum): diff --git a/reasoning_gym/arithmetic/products.py b/reasoning_gym/arithmetic/products.py index 8401be91..dbe89cd0 100644 --- a/reasoning_gym/arithmetic/products.py +++ b/reasoning_gym/arithmetic/products.py @@ -103,7 +103,8 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max return expression, result def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"]) + # tolerate sign, leading zeros and trailing decimals, strip commas "+01,000.00" == "1000" + return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"], strip_commas=True) class ProductsCurriculum(BaseCurriculum): diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index 17bb00e4..d727c863 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -62,10 +62,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Overwrite this method in derived classes if a single oracle answer is not available.""" - oracle_answer = entry["answer"].strip() + oracle_answer = entry["answer"] reward = 0.0 - if answer is not None and len(answer) > 0: - answer = answer.strip() + if isinstance(answer, str) and len(answer) > 0: if answer == oracle_answer: reward = 1.0 elif oracle_answer in answer: diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py index 70d4a343..fe30e75c 100644 --- a/reasoning_gym/utils.py +++ b/reasoning_gym/utils.py @@ -19,7 +19,7 @@ } -def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]: +def extract_answer(completion: str, tag_name: str = "answer", strip: bool = True) -> Optional[str]: regex = f"<{tag_name}>\\s?(.*?)\\s?" matches = list( re.finditer( @@ -30,21 +30,25 @@ def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]: ) if not matches: return None - return matches[-1].group(1) + answer = matches[-1].group(1) + if strip: + answer = answer.strip() + return answer -def format_number(num: Union[int, float], max_decimals: int = 2) -> str: +def format_number(num: Union[int, float], max_decimals: int = 2, round_if_needed: bool = False) -> str: """Convert a number to string representation with controlled decimal places. Args: num: Number to format max_decimals: Maximum allowed decimal places + round_if_needed: If True, round the number to max_decimals instead of raising an error Returns: String representation of the number Raises: - ValueError: If number requires more decimal places than allowed + ValueError: If number requires more decimal places than allowed and round_if_needed is False """ if isinstance(num, int) or num.is_integer(): return str(int(num)) @@ -57,19 +61,20 @@ def format_number(num: Union[int, float], max_decimals: int = 2) -> str: str_val = str_val.rstrip("0").rstrip(".") if "." in str_val: required_decimals = len(str_val.split(".")[1]) - if required_decimals > max_decimals: + if required_decimals > max_decimals and not round_if_needed: raise ValueError(f"Number {num} requires {required_decimals} decimals but only {max_decimals} allowed") - # Format with required decimals + # Format with required decimals (will round if needed) result = f"{num:.{max_decimals}f}".rstrip("0").rstrip(".") - # Verify result parses back to original value - try: - parsed = float(result) - if not math.isclose(parsed, num, rel_tol=1e-9): - raise ValueError(f"String representation {result} does not match original value {num}") - except (ValueError, InvalidOperation) as e: - raise ValueError(f"Failed to verify string representation: {e}") + # Verify result parses back to original value (skip verification if rounding was applied) + if not round_if_needed: + try: + parsed = float(result) + if not math.isclose(parsed, num, rel_tol=1e-9): + raise ValueError(f"String representation {result} does not match original value {num}") + except (ValueError, InvalidOperation) as e: + raise ValueError(f"Failed to verify string representation: {e}") return result @@ -84,6 +89,7 @@ def is_integer(obj: Any) -> bool: def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_commas: bool = True) -> float: """Compute the reward for a given answer compared to the oracle answer. + Tolerate sign, leading zeros and trailing decimals, optionally strip commas ("+01,000.00" == "1000") Args: answer: Answer provided by model @@ -102,7 +108,7 @@ def compute_decimal_reward(answer: Optional[str], oracle_answer: str, strip_comm oracle_answer = oracle_answer.replace(",", "") if Decimal(answer) == Decimal(oracle_answer): - reward = 1.0 + return 1.0 except: pass diff --git a/tests/test_dataset.py b/tests/test_dataset.py index f321ad2a..4373a204 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -2,7 +2,6 @@ from reasoning_gym.arithmetic.basic_arithmetic import BasicArithmeticDataset, BasicArithmeticDatasetConfig from reasoning_gym.dataset import ReseedingDataset -from reasoning_gym.utils import extract_answer def test_reseeding_dataset_iteration(): @@ -41,12 +40,7 @@ def test_reseeding_dataset_iteration(): assert infinite_dataset.score_answer(test_item["answer"], test_item) == 1.0 -def test_extract_answer(): - assert extract_answer("This is a text. 1234", tag_name="final_answer") == "1234" - - # ignore single whitespae - assert extract_answer("This is a text. \n1234 ", tag_name="answer") == "1234" - +def test_basic_arithmetic_score_answer(): config = BasicArithmeticDatasetConfig( min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10 ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..da3bafcd --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,53 @@ +import pytest + +from reasoning_gym.utils import compute_decimal_reward, extract_answer, format_number + + +def test_extract_answer(): + assert extract_answer("This is a text. 1234", tag_name="final_answer") == "1234" + + # ignore whitespaces + assert extract_answer("This is a text. \n1234 ", tag_name="answer", strip=True) == "1234" + + +def test_format_number(): + # Test integers + assert format_number(42) == "42" + assert format_number(42.0) == "42" + + # Test decimals + assert format_number(3.14) == "3.14" + assert format_number(3.10) == "3.1" + assert format_number(3.00) == "3" + + # Test with max_decimals (rounding) + assert format_number(3.14159, max_decimals=4, round_if_needed=True) == "3.1416" + + # Test with trailing zeros + assert format_number(5.5000) == "5.5" + + # Test error cases + with pytest.raises(ValueError): + format_number(3.14159, max_decimals=2) + + +def test_compute_decimal_reward(): + # Test exact matches + assert compute_decimal_reward("42", "42") == 1.0 + assert compute_decimal_reward("3.14", "3.14") == 1.0 + + # Test with commas + assert compute_decimal_reward("1,000", "1000") == 1.0 + assert compute_decimal_reward("1,000", "1000", strip_commas=False) < 1.0 + + # Test with sign, leading zeros, and trailing decimals + assert compute_decimal_reward("+0001,000.00", "1000") == 1.0 + + # Test partial matches + assert compute_decimal_reward("The answer is 42", "42") < 1.0 + assert compute_decimal_reward("The answer is 42", "42") > 0.01 + + # Test invalid answers + assert compute_decimal_reward(None, "42") == 0.0 + assert compute_decimal_reward("", "42") == 0.0 + assert compute_decimal_reward("not a number", "42") == 0.01