Skip to content

Commit

Permalink
remove strip from ProceduralDataset::core score_answer(), strip in ex…
Browse files Browse the repository at this point in the history
…tract answer (optional, default=True)
  • Loading branch information
andreaskoepf committed Mar 2, 2025
1 parent e71d2a9 commit 7f26542
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
3 changes: 0 additions & 3 deletions reasoning_gym/arithmetic/chain_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ def _generate_task(self, rng: random.Random, num_terms: int, min_value: int, max
expression = " ".join(expression_parts)
return expression, result

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
return utils.compute_decimal_reward(answer, oracle_answer=entry["answer"])


class ChainSumCurriculum(BaseCurriculum):
def __init__(self):
Expand Down
5 changes: 2 additions & 3 deletions reasoning_gym/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ def __getitem__(self, idx: int) -> dict[str, Any]:

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Overwrite this method in derived classes if a single oracle answer is not available."""
oracle_answer = entry["answer"].strip()
oracle_answer = entry["answer"]
reward = 0.0
if answer is not None and len(answer) > 0:
answer = answer.strip()
if isinstance(answer, str) and len(answer) > 0:
if answer == oracle_answer:
reward = 1.0
elif oracle_answer in answer:
Expand Down
7 changes: 5 additions & 2 deletions reasoning_gym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
}


def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
def extract_answer(completion: str, tag_name: str = "answer", strip: bool = True) -> Optional[str]:
regex = f"<{tag_name}>\\s?(.*?)\\s?</{tag_name}>"
matches = list(
re.finditer(
Expand All @@ -30,7 +30,10 @@ def extract_answer(completion: str, tag_name: str = "answer") -> Optional[str]:
)
if not matches:
return None
return matches[-1].group(1)
answer = matches[-1].group(1)
if strip:
answer = answer.strip()
return answer


def format_number(num: Union[int, float], max_decimals: int = 2) -> str:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def test_reseeding_dataset_iteration():
def test_extract_answer():
assert extract_answer("This is a text. <final_answer>1234</final_answer>", tag_name="final_answer") == "1234"

# ignore single whitespae
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer") == "1234"
# ignore whitespaces
assert extract_answer("This is a text. <answer>\n1234 </answer>", tag_name="answer", strip=True) == "1234"

config = BasicArithmeticDatasetConfig(
min_terms=2, max_terms=3, min_digits=1, max_digits=2, operators=["+"], allow_parentheses=False, seed=42, size=10
Expand Down

0 comments on commit 7f26542

Please sign in to comment.