Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
joenorton committed Feb 1, 2025
1 parent f75d9a2 commit d0d84ae
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
5 changes: 4 additions & 1 deletion reasoning_gym/algorithmic/palindrome_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ..factory import ProceduralDataset, register_dataset


@dataclass
class PalindromeConfig:
"""
Expand All @@ -15,6 +16,7 @@ class PalindromeConfig:
- seed: Optional seed for reproducibility.
- size: Number of palindrome samples in the virtual dataset.
"""

min_length: int = 3
max_length: int = 10
seed: Optional[int] = None
Expand All @@ -30,6 +32,7 @@ class PalindromeDataset(ProceduralDataset):
"""
Generates a set of letters that can be assembled into a palindrome.
"""

def __init__(self, config: PalindromeConfig):
super().__init__(config=config, seed=config.seed, size=config.size)

Expand All @@ -50,7 +53,7 @@ def __getitem__(self, idx: int) -> dict:
palindrome = self._assemble_palindrome(letters)

question_str = (
f"Rearrange these letters to form a palindrome (a word, phrase, or sequence that remains the same in reverse): {', '.join(scrambled_letters)}\n\n"
f"Rearrange these letters to form a palindrome (a word, phrase, or sequence that remains the same in reverse): {', '.join(scrambled_letters)}\n\n"
"Example format:\n"
"racecar\n\n"
"What is your palindrome?"
Expand Down
10 changes: 7 additions & 3 deletions tests/test_palindrome.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@

from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset


def test_palindrome_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = PalindromeConfig(min_length=0) # Too short
config.validate()

with pytest.raises(AssertionError):
config = PalindromeConfig(min_length=5, max_length=3) # Invalid range
config.validate()


def test_palindrome_deterministic():
"""Test that dataset generates same items with same seed"""
config = PalindromeConfig(seed=42, size=10)
Expand All @@ -21,6 +23,7 @@ def test_palindrome_deterministic():
for i in range(len(dataset1)):
assert dataset1[i] == dataset2[i]


def test_palindrome_items():
"""Test basic properties of generated items"""
config = PalindromeConfig(min_length=3, max_length=7, size=10, seed=42)
Expand All @@ -40,14 +43,15 @@ def test_palindrome_items():
palindrome = item["answer"]
assert palindrome == palindrome[::-1], f"{palindrome} is not a palindrome"


def test_palindrome_randomization():
"""Test letter randomization in the question"""
config = PalindromeConfig(min_length=4, max_length=4, size=10, seed=42)
dataset = PalindromeDataset(config)

for item in dataset:
letters = item["metadata"]["letters"]
palindrome = item["metadata"]["generated_palindrome"]

# Ensure the same letters are present but in different order
assert sorted(letters) == sorted(palindrome)

0 comments on commit d0d84ae

Please sign in to comment.