diff --git a/reasoning_gym/algorithmic/__init__.py b/reasoning_gym/algorithmic/__init__.py index 1b509970..8224a019 100644 --- a/reasoning_gym/algorithmic/__init__.py +++ b/reasoning_gym/algorithmic/__init__.py @@ -12,6 +12,7 @@ from .letter_jumble import LetterJumbleConfig, LetterJumbleDataset from .number_filtering import NumberFilteringConfig, NumberFilteringDataset from .number_sorting import NumberSortingConfig, NumberSortingDataset +from .palindrome_generation import PalindromeConfig, PalindromeDataset from .sentence_reordering import SentenceReorderingConfig, SentenceReorderingDataset from .spell_backward import SpellBackwardConfig, SpellBackwardDataset from .word_ladder import WordLadderConfig, WordLadderDataset @@ -42,4 +43,6 @@ "TextTransformation", "WordLadderConfig", "WordLadderDataset", + "PalindromeConfig", + "PalindromeDataset", ] diff --git a/reasoning_gym/algorithmic/palindrome_generation.py b/reasoning_gym/algorithmic/palindrome_generation.py new file mode 100644 index 00000000..e7ab21e3 --- /dev/null +++ b/reasoning_gym/algorithmic/palindrome_generation.py @@ -0,0 +1,115 @@ +import random +import string +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from ..factory import ProceduralDataset, register_dataset + + +@dataclass +class PalindromeConfig: + """ + Configuration for the palindrome task. + + - min_length: Minimum length of the palindrome. + - max_length: Maximum length of the palindrome. + - seed: Optional seed for reproducibility. + - size: Number of palindrome samples in the virtual dataset. + """ + + min_length: int = 3 + max_length: int = 10 + seed: Optional[int] = None + size: int = 50 + + def validate(self) -> None: + """Validate configuration parameters.""" + assert self.min_length >= 1, "min_length must be >= 1" + assert self.max_length >= self.min_length, "max_length must be >= min_length" + + +class PalindromeDataset(ProceduralDataset): + """ + Generates a set of letters that can be assembled into a palindrome. + """ + + def __init__(self, config: PalindromeConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + + def __getitem__(self, idx: int) -> dict: + """ + Generate a single palindrome task. + + Returns: + dict with: + - "question": Set of letters to form a palindrome. + - "answer": A correct palindrome. + - "metadata": Includes letter set and generated palindrome. + """ + rng = random.Random(self.seed + idx) + length = rng.randint(self.config.min_length, self.config.max_length) + letters = self._generate_palindrome_letters(rng, length) + scrambled_letters = rng.sample(letters, len(letters)) # Scramble the order + palindrome = self._assemble_palindrome(letters) + + question_str = ( + "Rearrange these letters to form a palindrome. A palindrome is a word, phrase, or sequence that reads the same forward and backward.\n\n" + "For example, if the letters are: a, a, b — a valid palindrome is: aba.\n\n" + f"Your letters: {', '.join(scrambled_letters)}\n\n" + "What palindrome can you form from these letters?" + ) + + return { + "question": question_str, + "answer": palindrome, + "metadata": { + "letters": scrambled_letters, + "generated_palindrome": palindrome, + }, + } + + def _generate_palindrome_letters(self, rng: random.Random, length: int) -> list[str]: + """Generate a set of letters that can form a palindrome.""" + half_length = length // 2 + letters = rng.choices(string.ascii_lowercase, k=half_length) + if length % 2 == 1: + middle_letter = rng.choice(string.ascii_lowercase) + return letters + [middle_letter] + letters[::-1] + return letters + letters[::-1] + + def _assemble_palindrome(self, letters: list[str]) -> str: + """Return the palindrome string from the letter set.""" + return "".join(letters) + + def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: + """Determine if the solution provided is a valid palindrome. + The answer is expected to be a single string + + Expected behavior: + - Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0 + - An answer that is a palindrome, but not with the same letters as provided, gives 0.05 + - An answer that is a string, but not a palindrome gives 0.02 + - An empty string gives 0.01. + - None gives 0.0. + """ + if answer is None or not isinstance(answer, str): + return 0.0 # No answer given + + if answer == "": + return 0.01 + + answer = answer.strip().lower() + expected_letters = metadata["letters"] + + # Check if the answer is a palindrome + if answer != answer[::-1]: + return 0.02 + + # Check if answer contains the same letters as provided (ignoring order) + if sorted(answer) != sorted(expected_letters): + return 0.05 + + return 1.0 # Correct solution + + +register_dataset("palindrome", PalindromeDataset, PalindromeConfig) diff --git a/tests/test_palindrome.py b/tests/test_palindrome.py new file mode 100644 index 00000000..e2844267 --- /dev/null +++ b/tests/test_palindrome.py @@ -0,0 +1,92 @@ +import pytest + +from reasoning_gym.algorithmic.palindrome_generation import PalindromeConfig, PalindromeDataset + + +def test_palindrome_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = PalindromeConfig(min_length=0) # Too short + config.validate() + + with pytest.raises(AssertionError): + config = PalindromeConfig(min_length=5, max_length=3) # Invalid range + config.validate() + + +def test_palindrome_deterministic(): + """Test that dataset generates same items with same seed""" + config = PalindromeConfig(seed=42, size=10) + dataset1 = PalindromeDataset(config) + dataset2 = PalindromeDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + + +def test_palindrome_items(): + """Test basic properties of generated items""" + config = PalindromeConfig(min_length=3, max_length=7, size=10, seed=42) + dataset = PalindromeDataset(config) + + for item in dataset: + assert isinstance(item, dict) + assert "question" in item + assert "answer" in item + assert "metadata" in item + + # Check metadata contains required fields + assert "letters" in item["metadata"] + assert "generated_palindrome" in item["metadata"] + + # Verify answer is a palindrome + palindrome = item["answer"] + assert palindrome == palindrome[::-1], f"{palindrome} is not a palindrome" + + +def test_palindrome_randomization(): + """Test letter randomization in the question""" + config = PalindromeConfig(min_length=4, max_length=4, size=10, seed=42) + dataset = PalindromeDataset(config) + + for item in dataset: + letters = item["metadata"]["letters"] + palindrome = item["metadata"]["generated_palindrome"] + + # Ensure the same letters are present but in different order + assert sorted(letters) == sorted(palindrome) + + +def test_score_answer(): + """Test the scoring mechanism for palindrome answers. + + Expected behavior: + - Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0 + - An answer that is a palindrome, but not with the same letters as provided, gives 0.05 + - An answer that is a string, but not a palindrome gives 0.02 + - An empty string gives 0.01. + - None gives 0.0. + """ + config = PalindromeConfig(min_length=4, max_length=6, size=10, seed=42) + dataset = PalindromeDataset(config) + + for item in dataset: + correct_answer = item["answer"] + metadata = item["metadata"] + + # Correct answer should score 1.0 + assert dataset.score_answer(correct_answer, metadata) == 1.0 + + # Incorrect answer (palindrome, but not correct one) should score 0.05 + pal_letters = "racecar" if "racecar" != correct_answer else "aba" + assert dataset.score_answer(pal_letters, metadata) == 0.05 + + # Incorrect answer (not palindrome) should score 0.02 + wrong_letters = "abcd" if "abcd" != correct_answer else "efgh" + assert dataset.score_answer(wrong_letters, metadata) == 0.02 + + # Empty String input should score 0.01 + assert dataset.score_answer("", metadata) == 0.01 + + # Empty input should score 0.0 + assert dataset.score_answer(None, metadata) == 0.0