From 869624eb82887b52ba441630f669ba5b3987482c Mon Sep 17 00:00:00 2001 From: abdulhakeem Date: Sun, 9 Feb 2025 00:11:30 -0600 Subject: [PATCH] Use oracle_answer to compute reward --- reasoning_gym/algorithmic/word_ladder.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/reasoning_gym/algorithmic/word_ladder.py b/reasoning_gym/algorithmic/word_ladder.py index e103184c..b7f371b6 100644 --- a/reasoning_gym/algorithmic/word_ladder.py +++ b/reasoning_gym/algorithmic/word_ladder.py @@ -241,15 +241,16 @@ def __getitem__(self, idx: int) -> dict: } def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: - _ = entry["answer"].upper().strip() + oracle_answer = entry["answer"].upper().strip() answer = answer.upper().strip() if answer is not None else None + is_answer_correct = set(answer.split(",")) == set(oracle_answer.split(",")) word_dict = self.word_dict # NOTE: I am assuming that answer is a comma-separated string of words and that if it exactly matches the oracle answer # it is correct and gets a reward of 1.0. # Check for two conditions: - # 1. Ensure all words in the answer are valid + # 1. Ensure all words in the answer are valid (Assuming all generated words would be found in the word_dict) # 2. Ensure every changed word is a single letter change from the previous word is_all_words_valid = all(word in word_dict for word in answer.split(",")) words = answer.split(",") @@ -258,12 +259,12 @@ def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float: for i in range(1, len(words)): if sum(1 for a, b in zip(words[i - 1], words[i]) if a != b) == 1: single_letter_change_words += 1 - # Number of compairsons should be total_words - 1 - is_all_single_letter_change = single_letter_change_words == total_words - 1 + # Number of comparisons should be total_words - 1 + is_all_single_letter_change = single_letter_change_words == (total_words - 1) reward = 0.0 if answer is not None: - if is_all_words_valid and is_all_single_letter_change: + if is_answer_correct: reward = 1.0 elif is_all_words_valid: reward = 0.5