feat: add palindrome_generation #39

Merged: 3 commits merged on Feb 3, 2025
3 changes: 3 additions & 0 deletions reasoning_gym/algorithmic/__init__.py
@@ -12,6 +12,7 @@
from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset
from .number_filtering import NumberFilteringConfig, NumberFilteringDataset
from .number_sorting import NumberSortingConfig, NumberSortingDataset
from .palindrome_generation import PalindromeConfig, PalindromeDataset
from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset
from .spell_backward import SpellBackwardConfig, SpellBackwardDataset
from .word_ladder import WordLadderConfig, WordLadderDataset
@@ -42,4 +43,6 @@
"TextTransformation",
"WordLadderConfig",
"WordLadderDataset",
"PalindromeConfig",
"PalindromeDataset",
]
115 changes: 115 additions & 0 deletions reasoning_gym/algorithmic/palindrome_generation.py
@@ -0,0 +1,115 @@
import random
import string
from dataclasses import dataclass
from typing import Any, Dict, Optional

from ..factory import ProceduralDataset, register_dataset


@dataclass
class PalindromeConfig:
    """
    Configuration for the palindrome task.

    - min_length: Minimum length of the palindrome.
    - max_length: Maximum length of the palindrome.
    - seed: Optional seed for reproducibility.
    - size: Number of palindrome samples in the virtual dataset.
    """

    min_length: int = 3
    max_length: int = 10
    seed: Optional[int] = None
    size: int = 50

    def validate(self) -> None:
        """Validate configuration parameters."""
        assert self.min_length >= 1, "min_length must be >= 1"
        assert self.max_length >= self.min_length, "max_length must be >= min_length"


class PalindromeDataset(ProceduralDataset):
    """
    Generates a set of letters that can be assembled into a palindrome.
    """

    def __init__(self, config: PalindromeConfig):
        super().__init__(config=config, seed=config.seed, size=config.size)

    def __getitem__(self, idx: int) -> dict:
        """
        Generate a single palindrome task.

        Returns:
            dict with:
            - "question": Set of letters to form a palindrome.
            - "answer": A correct palindrome.
            - "metadata": Includes letter set and generated palindrome.
        """
        rng = random.Random(self.seed + idx)
        length = rng.randint(self.config.min_length, self.config.max_length)
        letters = self._generate_palindrome_letters(rng, length)
        scrambled_letters = rng.sample(letters, len(letters))  # Scramble the order
        palindrome = self._assemble_palindrome(letters)

        question_str = (
            "Rearrange these letters to form a palindrome. A palindrome is a word, phrase, or sequence that reads the same forward and backward.\n\n"
            "For example, if the letters are: a, a, b — a valid palindrome is: aba.\n\n"
            f"Your letters: {', '.join(scrambled_letters)}\n\n"
            "What palindrome can you form from these letters?"
        )

        return {
            "question": question_str,
            "answer": palindrome,
            "metadata": {
                "letters": scrambled_letters,
                "generated_palindrome": palindrome,
            },
        }

    def _generate_palindrome_letters(self, rng: random.Random, length: int) -> list[str]:
        """Generate a set of letters that can form a palindrome."""
        half_length = length // 2
        letters = rng.choices(string.ascii_lowercase, k=half_length)
        if length % 2 == 1:
            middle_letter = rng.choice(string.ascii_lowercase)
            return letters + [middle_letter] + letters[::-1]
        return letters + letters[::-1]

    def _assemble_palindrome(self, letters: list[str]) -> str:
        """Return the palindrome string from the letter set."""
        return "".join(letters)

    def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float:
        """Determine if the solution provided is a valid palindrome.
        The answer is expected to be a single string.

        Expected behavior:
        - Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0
        - An answer that is a palindrome, but not with the same letters as provided, gives 0.05
        - An answer that is a string, but not a palindrome, gives 0.02
        - An empty string gives 0.01.
        - None gives 0.0.
        """
        if answer is None or not isinstance(answer, str):
            return 0.0  # No answer given

        if answer == "":
            return 0.01

        answer = answer.strip().lower()
        expected_letters = metadata["letters"]

        # Check if the answer is a palindrome
        if answer != answer[::-1]:
            return 0.02

        # Check if the answer contains the same letters as provided (ignoring order)
        if sorted(answer) != sorted(expected_letters):
            return 0.05

        return 1.0  # Correct solution


register_dataset("palindrome", PalindromeDataset, PalindromeConfig)
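
A minimal usage sketch (illustrative only, not part of this diff) showing how the new dataset can be exercised; it relies only on the classes and item keys defined above:

from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset

config = PalindromeConfig(min_length=3, max_length=7, seed=42, size=5)
dataset = PalindromeDataset(config)

item = dataset[0]
print(item["question"])  # scrambled letters to rearrange
print(item["answer"])  # one valid palindrome built from those letters
print(dataset.score_answer(item["answer"], item["metadata"]))  # 1.0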
92 changes: 92 additions & 0 deletions tests/test_palindrome.py
@@ -0,0 +1,92 @@
import pytest

from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset


def test_palindrome_config_validation():
    """Test that invalid configs raise appropriate errors"""
    with pytest.raises(AssertionError):
        config = PalindromeConfig(min_length=0)  # Too short
        config.validate()

    with pytest.raises(AssertionError):
        config = PalindromeConfig(min_length=5, max_length=3)  # Invalid range
        config.validate()


def test_palindrome_deterministic():
    """Test that the dataset generates the same items with the same seed"""
    config = PalindromeConfig(seed=42, size=10)
    dataset1 = PalindromeDataset(config)
    dataset2 = PalindromeDataset(config)

    for i in range(len(dataset1)):
        assert dataset1[i] == dataset2[i]


def test_palindrome_items():
    """Test basic properties of generated items"""
    config = PalindromeConfig(min_length=3, max_length=7, size=10, seed=42)
    dataset = PalindromeDataset(config)

    for item in dataset:
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Check metadata contains required fields
        assert "letters" in item["metadata"]
        assert "generated_palindrome" in item["metadata"]

        # Verify answer is a palindrome
        palindrome = item["answer"]
        assert palindrome == palindrome[::-1], f"{palindrome} is not a palindrome"


def test_palindrome_randomization():
    """Test letter randomization in the question"""
    config = PalindromeConfig(min_length=4, max_length=4, size=10, seed=42)
    dataset = PalindromeDataset(config)

    for item in dataset:
        letters = item["metadata"]["letters"]
        palindrome = item["metadata"]["generated_palindrome"]

        # Ensure the same letters are present, possibly in a different order
        assert sorted(letters) == sorted(palindrome)


def test_score_answer():
    """Test the scoring mechanism for palindrome answers.

    Expected behavior:
    - Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0
    - An answer that is a palindrome, but not with the same letters as provided, gives 0.05
    - An answer that is a string, but not a palindrome, gives 0.02
    - An empty string gives 0.01.
    - None gives 0.0.
    """
    config = PalindromeConfig(min_length=4, max_length=6, size=10, seed=42)
    dataset = PalindromeDataset(config)

    for item in dataset:
        correct_answer = item["answer"]
        metadata = item["metadata"]

        # Correct answer should score 1.0
        assert dataset.score_answer(correct_answer, metadata) == 1.0

        # Incorrect answer (a palindrome, but not built from the provided letters) should score 0.05
        pal_letters = "racecar" if "racecar" != correct_answer else "aba"
        assert dataset.score_answer(pal_letters, metadata) == 0.05

        # Incorrect answer (not a palindrome) should score 0.02
        wrong_letters = "abcd" if "abcd" != correct_answer else "efgh"
        assert dataset.score_answer(wrong_letters, metadata) == 0.02

        # An empty string should score 0.01
        assert dataset.score_answer("", metadata) == 0.01

        # None should score 0.0
        assert dataset.score_answer(None, metadata) == 0.0
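
For reference, a hand-built sketch of the scoring tiers exercised by test_score_answer (the metadata dict below is made up for illustration, but it follows the structure produced by __getitem__):

from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset

dataset = PalindromeDataset(PalindromeConfig(min_length=4, max_length=4, seed=0, size=1))
metadata = {"letters": ["b", "a", "a", "b"], "generated_palindrome": "abba"}

assert dataset.score_answer("abba", metadata) == 1.0  # palindrome, exact letters
assert dataset.score_answer("baab", metadata) == 1.0  # any valid rearrangement also scores 1.0
assert dataset.score_answer("racecar", metadata) == 0.05  # palindrome, wrong letters
assert dataset.score_answer("abcd", metadata) == 0.02  # not a palindrome
assert dataset.score_answer("", metadata) == 0.01  # empty string
assert dataset.score_answer(None, metadata) == 0.0  # no answer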