Skip to content

Commit

Permalink
Use oracle_answer to compute reward
Browse files Browse the repository at this point in the history
  • Loading branch information
Adefioye committed Feb 9, 2025
1 parent ce63878 commit 869624e
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions reasoning_gym/algorithmic/word_ladder.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,16 @@ def __getitem__(self, idx: int) -> dict:
}

def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
_ = entry["answer"].upper().strip()
oracle_answer = entry["answer"].upper().strip()
answer = answer.upper().strip() if answer is not None else None
is_answer_correct = set(answer.split(",")) == set(oracle_answer.split(","))
word_dict = self.word_dict

# NOTE: I am assuming that answer is a comma-separated string of words and that if it exactly matches the oracle answer
# it is correct and gets a reward of 1.0.

# Check for two conditions:
# 1. Ensure all words in the answer are valid
# 1. Ensure all words in the answer are valid (Assuming all generated words would be found in the word_dict)
# 2. Ensure every changed word is a single letter change from the previous word
is_all_words_valid = all(word in word_dict for word in answer.split(","))
words = answer.split(",")
Expand All @@ -258,12 +259,12 @@ def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
for i in range(1, len(words)):
if sum(1 for a, b in zip(words[i - 1], words[i]) if a != b) == 1:
single_letter_change_words += 1
# Number of compairsons should be total_words - 1
is_all_single_letter_change = single_letter_change_words == total_words - 1
# Number of comparisons should be total_words - 1
is_all_single_letter_change = single_letter_change_words == (total_words - 1)

reward = 0.0
if answer is not None:
if is_all_words_valid and is_all_single_letter_change:
if is_answer_correct:
reward = 1.0
elif is_all_words_valid:
reward = 0.5
Expand Down

0 comments on commit 869624e

Please sign in to comment.