diff --git a/reasoning_gym/arithmetic/chain_sum.py b/reasoning_gym/arithmetic/chain_sum.py
index 2072983c..8464068f 100644
--- a/reasoning_gym/arithmetic/chain_sum.py
+++ b/reasoning_gym/arithmetic/chain_sum.py
@@ -110,9 +110,6 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max
expression = " ".join(expression_parts)
return expression, result
- def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
- return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])
-
class ChainSumCurriculum(BaseCurriculum):
def __init__(self):
diff --git a/reasoning_gym/dataset.py b/reasoning_gym/dataset.py
index 17bb00e4..d727c863 100644
--- a/reasoning_gym/dataset.py
+++ b/reasoning_gym/dataset.py
@@ -62,10 +62,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
- oracle_answer = entry["answer"].strip()
+ oracle_answer = entry["answer"]
reward = 0.0
- if answer is not None and len(answer) > 0:
- answer = answer.strip()
+ if isinstance(answer, str) and len(answer) > 0:
if answer == oracle_answer:
reward = 1.0
elif oracle_answer in answer:
diff --git a/reasoning_gym/utils.py b/reasoning_gym/utils.py
index 70d4a343..33b4fe35 100644
--- a/reasoning_gym/utils.py
+++ b/reasoning_gym/utils.py
@@ -19,7 +19,7 @@
}
-def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
+def extract_answer(completion: str, tag_name: str = "answer", strip: bool = True) -> Optional[str]:
regex = f"<{tag_name}>\\s?(.*?)\\s?</{tag_name}>"
matches = list(
re.finditer(
@@ -30,7 +30,10 @@ def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
)
if not matches:
return None
- return matches[-1].group(1)
+ answer = matches[-1].group(1)
+ if strip:
+ answer = answer.strip()
+ return answer
def format_number(num: Union[int, float], max_decimals: int = 2) -> str:
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index f321ad2a..56c89593 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -44,8 +44,8 @@ def test_reseeding_dataset_iteration():
def test_extract_answer():
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"
- # ignore single whitespae
- assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer") == "1234"
+ # surrounding whitespace is stripped from the extracted answer
+ assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer", strip=True) == "1234"
config = BasicArithmeticDatasetConfig(
min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10