Skip to content

Commit

Permalink
Add score_answer method to word_ladder
Browse files Browse the repository at this point in the history
  • Loading branch information
Adefioye committed Feb 9, 2025
1 parent 1f9d9d2 commit ce63878
Show file tree
Hide file tree
Showing 4 changed files with 370,170 additions and 3 deletions.
60 changes: 58 additions & 2 deletions reasoning_gym/algorithmic/word_ladder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from collections import deque
from dataclasses import dataclass
from random import Random
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Any
from pathlib import Path

from reasoning_gym.data import read_data_file
from reasoning_gym.data import read_data_file, read_json_file

from ..factory import ProceduralDataset, register_dataset

Expand All @@ -20,6 +21,7 @@ class WordLadderConfig:
max_chain_length: int = -1 # Set to -1 for shortest path or a max
seed: Optional[int] = None
size: int = 500
dictionary_file_path: str = "words_dictionary.json"

def validate(self) -> None:
"""Validate configuration parameters"""
Expand Down Expand Up @@ -64,6 +66,7 @@ def __init__(self, config: WordLadderConfig):
self.config = config
self.word_sets = {}
self.word_graphs = {}
self._word_dict = None # A large list of dictionary words to validate words against

# Load words from CSV
self.word_sets = self._load_words_from_csv(
Expand Down Expand Up @@ -116,6 +119,23 @@ def _load_words_from_csv(cls, min_length: int = 3, max_length: int = 5) -> Dict[
raise ValueError(f"No valid words found for length {length}")

return word_sets

def _load_word_dictionary(self, file_path: str) -> Dict[str, Any]:
    """Return the parsed contents of the JSON word dictionary.

    Args:
        file_path: Name of the JSON file, passed unchanged to ``read_json_file``.

    Returns:
        Whatever ``read_json_file`` produces for that file.
    """
    dictionary = read_json_file(file_path)
    return dictionary

@property
def word_dict(self) -> Dict[str, Any]:
    """Word dictionary, loaded on first access and cached on the instance.

    Returns:
        The cached dictionary; when the cache is still ``None`` it is first
        filled from ``self.config.dictionary_file_path``.
    """
    cached = self._word_dict
    if cached is not None:
        return cached

    # First access: load via the configured path and remember the result.
    self._word_dict = self._load_word_dictionary(self.config.dictionary_file_path)
    return self._word_dict

# Lazy loading of word dictionary
def load_word_dictionary(self, file_path: str) -> Dict[str, Any]:
    """Load the word dictionary from ``file_path`` unless one is already cached.

    Args:
        file_path: Name of the JSON dictionary file, forwarded to
            ``_load_word_dictionary``.

    Returns:
        The cached word dictionary (loaded by this call if none was cached).
    """
    # `hasattr` is not a usable "loaded?" test: __init__ creates the attribute
    # with the value None, so the attribute always exists. Treat a missing or
    # None cache as "not loaded yet".
    if getattr(self, "_word_dict", None) is None:
        self._word_dict = self._load_word_dictionary(file_path)
    return self._word_dict

def _get_neighbors(self, word: str, word_set: Set[str]) -> Set[str]:
"""Get neighbors from either precomputed graph or by computing on demand"""
Expand Down Expand Up @@ -219,6 +239,42 @@ def __getitem__(self, idx: int) -> dict:
"answer": ",".join(path),
"metadata": {"start_word": start, "end_word": end, "word_length": length, "chain_length": len(path)},
}

def score_answer(self, answer: Optional[str], entry: Dict[str, Any]) -> float:
    """Score a word-ladder answer given as a comma-separated string of words.

    The answer is upper-cased and stripped, then judged on two criteria:

    1. every word is a key of ``self.word_dict``;
    2. every consecutive pair of words has the same length and differs in
       exactly one position.

    Args:
        answer: The proposed ladder, e.g. ``"CAT,COT,COG"``, or ``None`` when
            no answer was produced.
        entry: The dataset entry being scored. It is not consulted; the
            answer is judged on its own validity rather than compared with
            the oracle answer.

    Returns:
        ``0.0`` for a ``None`` answer; ``1.0`` when both criteria hold;
        ``0.5`` when exactly one holds; otherwise the fraction of steps that
        are single-letter changes, or ``0.01`` when no step qualifies.
    """
    # A missing answer scores zero; bail out before any string handling.
    if answer is None:
        return 0.0

    words = answer.upper().strip().split(",")
    # NOTE(review): membership is tested on upper-cased words — confirm the
    # dictionary file uses upper-case keys.
    word_dict = self.word_dict
    is_all_words_valid = all(word in word_dict for word in words)

    # Number of comparisons is len(words) - 1.
    total_steps = len(words) - 1
    # Require equal lengths explicitly: zip() silently truncates to the shorter
    # word, which would let e.g. CAT -> CART pass as a one-letter change.
    single_letter_steps = sum(
        1
        for prev, curr in zip(words, words[1:])
        if len(prev) == len(curr) and sum(a != b for a, b in zip(prev, curr)) == 1
    )
    is_all_single_letter_change = single_letter_steps == total_steps

    if is_all_words_valid and is_all_single_letter_change:
        return 1.0
    if is_all_words_valid or is_all_single_letter_change:
        return 0.5
    if single_letter_steps > 0:
        # total_steps >= 1 here, because at least one step was counted.
        return single_letter_steps / total_steps
    return 0.01


# Hand the dataset class and its config class to the factory's register_dataset
# under the name "word_ladder".
register_dataset("word_ladder", WordLadderDataset, WordLadderConfig)
8 changes: 8 additions & 0 deletions reasoning_gym/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from importlib import resources
from pathlib import Path
from typing import Any, Dict
import json


def get_data_file_path(filename: str) -> Path:
Expand Down Expand Up @@ -35,5 +37,11 @@ def read_data_file(filename: str) -> str:
"""
return resources.files("reasoning_gym.data").joinpath(filename).read_text()

def read_json_file(file_name: str) -> Dict[str, Any]:
    """Read a JSON data file and return its parsed contents.

    Args:
        file_name: Name of the JSON file; it is resolved to a path with
            ``get_data_file_path``.

    Returns:
        The object produced by ``json.load`` for the file (annotated as a
        dictionary).
    """
    file_path = get_data_file_path(file_name)
    # Pin the encoding: without it, open() uses the platform locale, which can
    # mis-decode non-ASCII JSON on systems whose default is not UTF-8.
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


# Public API of the data package; read_json_file is included so that the
# export list matches the helpers defined in this module.
__all__ = ["get_data_file_path", "read_data_file", "read_json_file"]
Loading

0 comments on commit ce63878

Please sign in to comment.