diff --git a/reasoning_gym/algorithmic/palindrome_generation.py b/reasoning_gym/algorithmic/palindrome_generation.py index 0ae0cb11..e7ab21e3 100644 --- a/reasoning_gym/algorithmic/palindrome_generation.py +++ b/reasoning_gym/algorithmic/palindrome_generation.py @@ -1,7 +1,7 @@ import random import string from dataclasses import dataclass -from typing import Optional +from typing import Any, Dict, Optional from ..factory import ProceduralDataset, register_dataset @@ -53,10 +53,10 @@ def __getitem__(self, idx: int) -> dict: palindrome = self._assemble_palindrome(letters) question_str = ( - f"Rearrange these letters to form a palindrome (a word, phrase, or sequence that remains the same in reverse): {', '.join(scrambled_letters)}\n\n" - "Example format:\n" - "racecar\n\n" - "What is your palindrome?" + "Rearrange these letters to form a palindrome. A palindrome is a word, phrase, or sequence that reads the same forward and backward.\n\n" + "For example, if the letters are: a, a, b — a valid palindrome is: aba.\n\n" + f"Your letters: {', '.join(scrambled_letters)}\n\n" + "What palindrome can you form from these letters?" ) return { @@ -81,5 +81,35 @@ def _assemble_palindrome(self, letters: list[str]) -> str: """Return the palindrome string from the letter set.""" return "".join(letters) + def score_answer(self, answer: Optional[str], metadata: Dict[str, Any]) -> float: + """Determine if the solution provided is a valid palindrome. + The answer is expected to be a single string + + Expected behavior: + - Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0 + - An answer that is a palindrome, but not with the same letters as provided, gives 0.05 + - An answer that is a string, but not a palindrome gives 0.02 + - An empty string gives 0.01. + - None gives 0.0. + """ + if answer is None or not isinstance(answer, str): + return 0.0 # No answer given + + if answer == "": + return 0.01 + + answer = answer.strip().lower() + expected_letters = metadata["letters"] + + # Check if the answer is a palindrome + if answer != answer[::-1]: + return 0.02 + + # Check if answer contains the same letters as provided (ignoring order) + if sorted(answer) != sorted(expected_letters): + return 0.05 + + return 1.0 # Correct solution + register_dataset("palindrome", PalindromeDataset, PalindromeConfig) diff --git a/tests/test_palindrome.py b/tests/test_palindrome.py index b5bb18f2..e2844267 100644 --- a/tests/test_palindrome.py +++ b/tests/test_palindrome.py @@ -55,3 +55,38 @@ def test_palindrome_randomization(): # Ensure the same letters are present but in different order assert sorted(letters) == sorted(palindrome) + + +def test_score_answer(): + """Test the scoring mechanism for palindrome answers. + + Expected behavior: + - Correct answer (palindrome with only correct letters in the correct quantities) gives 1.0 + - An answer that is a palindrome, but not with the same letters as provided, gives 0.05 + - An answer that is a string, but not a palindrome gives 0.02 + - An empty string gives 0.01. + - None gives 0.0. + """ + config = PalindromeConfig(min_length=4, max_length=6, size=10, seed=42) + dataset = PalindromeDataset(config) + + for item in dataset: + correct_answer = item["answer"] + metadata = item["metadata"] + + # Correct answer should score 1.0 + assert dataset.score_answer(correct_answer, metadata) == 1.0 + + # Incorrect answer (palindrome, but not correct one) should score 0.05 + pal_letters = "racecar" if "racecar" != correct_answer else "aba" + assert dataset.score_answer(pal_letters, metadata) == 0.05 + + # Incorrect answer (not palindrome) should score 0.02 + wrong_letters = "abcd" if "abcd" != correct_answer else "efgh" + assert dataset.score_answer(wrong_letters, metadata) == 0.02 + + # Empty String input should score 0.01 + assert dataset.score_answer("", metadata) == 0.01 + + # Empty input should score 0.0 + assert dataset.score_answer(None, metadata) == 0.0