Skip to content

Commit

Permalink
more dynamic scoring for jumble (#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
Miserlou authored Mar 1, 2025
1 parent 9c581f1 commit 39f151a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
35 changes: 28 additions & 7 deletions reasoning_gym/algorithmic/letter_jumble.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,25 @@ def __getitem__(self, idx: int) -> dict:
},
}

def partial(self, expected_answer, model_answer):
expected_words = expected_answer.split()
model_words = model_answer.split()

# Each word in the expected answer is worth an equal fraction of 1.0
total_words = len(expected_words)
score_per_word = 1.0 / total_words if total_words else 0

# Calculate scores word by word
scores = []
for i, word in enumerate(expected_words):
# Check if the corresponding word exists in model_answer and matches exactly
if i < len(model_words) and word == model_words[i]:
scores.append(score_per_word)
else:
scores.append(0.0)

return min(1.0, sum(scores))

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
"""Determine if the solution provided solves this task.
Expand All @@ -136,16 +155,18 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
float: The computed score between 0.0 and 1.0.
"""

oracle_answer = entry["answer"].strip()
if not answer:
return 0.0

oracle_answer = entry["answer"].strip().lower()
if answer:
answer = answer.strip()
answer = answer.strip().lower()
if answer == oracle_answer:
return 1.0
elif answer.lower() == oracle_answer.lower():
return 0.5
return 1.0 # Perfect score!
else:
return 0.01
return 0.0
partial_score = self.partial(oracle_answer, answer)
return partial_score
return 0.01


register_dataset("letter_jumble", LetterJumbleDataset, LetterJumbleConfig)
6 changes: 5 additions & 1 deletion tests/test_letter_jumble.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,12 @@ def test_letter_jumble_dataset_items():

# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer="gibberish", entry=item) == 0.01
assert dataset.score_answer(answer=None, entry=item) == 0.0
answera = item["answer"].split(" ")
answera[0] = "flippityfloop"
answera[1] = "doopadoopadoop"
answerf = " ".join(answera)
assert 0.01 <= dataset.score_answer(answer=answerf, entry=item) <= 1.0


def test_letter_jumble_iteration():
Expand Down

0 comments on commit 39f151a

Please sign in to comment.