diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py index 2072983c..8464068f 100644 --- a/reasoning_gym/arithmetic/chain_sum.py +++ b/reasoning_gym/arithmetic/chain_sum.py @@ -110,9 +110,6 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max expression = " ".join(expression_parts) return expression, result - def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: - return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"]) - class ChainSumCurriculum(BaseCurriculum): def __init__(self): diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py index 17bb00e4..d727c863 100644 --- a/reasoning_gym/dataset.py +++ b/reasoning_gym/dataset.py @@ -62,10 +62,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]: def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: """Overwrite this method in derived classes if a single oracle answer is not available.""" - oracle_answer = entry["answer"].strip() + oracle_answer = entry["answer"] reward = 0.0 - if answer is not None and len(answer) > 0: - answer = answer.strip() + if isinstance(answer, str) and len(answer) > 0: if answer == oracle_answer: reward = 1.0 elif oracle_answer in answer: diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py index 70d4a343..33b4fe35 100644 --- a/reasoning_gym/utils.py +++ b/reasoning_gym/utils.py @@ -19,7 +19,7 @@ } -def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]: +def extract_answer(completion: str, tag_name: str = "answer", strip: bool = True) -> Optional[str]: regex = f"<{tag_name}>\\s?(.*?)\\s?" matches = list( re.finditer( @@ -30,7 +30,10 @@ def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]: ) if not matches: return None - return matches[-1].group(1) + answer = matches[-1].group(1) + if strip: + answer = answer.strip() + return answer def format_number(num: Union[int, float], max_decimals: int = 2) -> str: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index f321ad2a..56c89593 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -44,8 +44,8 @@ def test_reseeding_dataset_iteration(): def test_extract_answer(): assert extract_answer("This is a text. 1234", tag_name="final_answer") == "1234" - # ignore single whitespae - assert extract_answer("This is a text. \n1234 ", tag_name="answer") == "1234" + # ignore whitespaces + assert extract_answer("This is a text. \n1234 ", tag_name="answer", strip=True) == "1234" config = BasicArithmeticDatasetConfig( min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10