Skip to content

Commit

Permalink
Merge pull request #41 from rishabhranawat/aiw
Browse files Browse the repository at this point in the history
Add Alice In Wonderland Problem Procedural Dataset
  • Loading branch information
andreaskoepf authored Feb 2, 2025
2 parents 4e9fc4b + 5f6d615 commit 02cfa95
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 0 deletions.
188 changes: 188 additions & 0 deletions reasoning_gym/logic/aiw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from dataclasses import dataclass, field
from typing import List, Optional
from enum import Enum
from random import Random
from string import Template

from ..factory import ProceduralDataset, register_dataset


class TaskType(Enum):
    """Kinds of relational puzzle the Alice in Wonderland dataset can emit.

    The member values are the strings stored under ``metadata["task_type"]``
    in generated items.
    """

    # "How many sisters does a brother have?" style questions.
    SIBLINGS = "siblings"
    # Closed group of mutual friends.
    FRIENDS = "friends"
    # Two colleague circles sharing exactly one person.
    COLLEAGUES = "colleagues"


@dataclass
class AliceInWonderlandConfig:
    """Settings controlling Alice in Wonderland question generation.

    Attributes:
        male_names: Pool of male names a question may mention.
        female_names: Pool of female names a question may mention; the pool
            is required to contain 'Alice'.
        task_types: Task types that may be sampled when building an item.
        seed: Seed for random number generation (``None`` leaves it unset).
        size: Number of samples in the dataset.
        max_entities: Upper bound (inclusive) on the number of
            siblings/friends/colleagues drawn for a question.
    """

    male_names: List[str] = field(
        default_factory=lambda: [
            "James",
            "John",
            "Robert",
            "Michael",
            "William",
            "David",
            "Richard",
            "Joseph",
            "Thomas",
            "Charles",
            "Bob",
        ]
    )
    female_names: List[str] = field(
        default_factory=lambda: [
            "Mary",
            "Patricia",
            "Jennifer",
            "Linda",
            "Elizabeth",
            "Barbara",
            "Susan",
            "Jessica",
            "Sarah",
            "Margaret",
            "Alice",
        ]
    )
    task_types: List[TaskType] = field(
        default_factory=lambda: [
            TaskType.SIBLINGS,
            TaskType.FRIENDS,
            TaskType.COLLEAGUES,
        ]
    )
    seed: Optional[int] = None
    size: int = 10
    max_entities: int = 6

    def validate(self) -> None:
        """Check the configuration and fail on the first violated rule.

        Raises:
            AssertionError: If a name pool or the task type list is empty,
                'Alice' is missing from the female names, or
                ``max_entities`` is not positive.
        """
        males = self.male_names
        females = self.female_names

        assert len(males) > 0, "must provide male names"
        assert len(females) > 0, "must provide female names"
        assert "Alice" in females, "'Alice' must be in female names"
        assert len(self.task_types) > 0, "must provide at least one task type"
        assert self.max_entities > 0, "max_entities must be positive"


class AliceInWonderlandDataset(ProceduralDataset):
    """
    A procedural dataset inspired by the "Alice in Wonderland" paper.

    Every item is a short relational word problem (siblings, friends or
    colleagues) whose answer is a small positive integer that follows from
    the counts stated in the question.

    The dataset is inspired by the following paper:
    @inproceedings{nezhurina2024alice,
    title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
    Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
    author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
    Jenia Jitsev},
    booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
    Deep Learning},
    year={2024},
    url={https://openreview.net/forum?id=Mkl7dzjYiW}
    }
    """

    def __init__(self, config: AliceInWonderlandConfig):
        """Initialise the base dataset and build the per-task templates.

        Args:
            config: Generation settings; its ``seed`` and ``size`` are
                forwarded to the base class.
        """
        super().__init__(config=config, seed=config.seed, size=config.size)
        self.templates = {
            TaskType.SIBLINGS: [
                Template(
                    "$female_name has $num_brothers brothers and she also has "
                    "$num_sisters sisters. How many sisters does "
                    "$female_name's brother have?"
                ),
                # The question must tie $male_name to the family; asking about
                # "$male_name's brother" without introducing $male_name left
                # the puzzle unanswerable.
                Template(
                    "$female_name has $num_sisters sisters and she also has "
                    "$num_brothers brothers. How many sisters does "
                    "$male_name, a brother of $female_name, have?"
                ),
            ],
            TaskType.FRIENDS: [
                Template(
                    "$female_name has $num_male male friends and she also has "
                    "$num_female female friends. They all are friends with each "
                    "other and have no other friends aside. How many female "
                    "friends does $male_name, a male friend of $female_name, "
                    "have?"
                )
            ],
            TaskType.COLLEAGUES: [
                Template(
                    "$female_name has $num_male_colleagues_alice_circle male colleagues and she also has "
                    "$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. "
                    "All these mentioned persons around $female_name are colleagues of each other. "
                    "$male_name has $num_male_colleagues_bob_circle male colleagues "
                    "and $num_female_colleagues_bob_circle female colleagues in total. "
                    "All these mentioned persons around $male_name are colleagues of each other. "
                    "The people in the circle around $male_name do not have "
                    "other colleagues aside - with the only exception of Matilda. "
                    "She is colleague of $male_name and she is also colleague of $female_name, "
                    "being part of $female_name's circle. How many female colleagues does Matilda have?"
                ),
            ],
        }

    def _get_aiw(self, rng: Random) -> dict:
        """Generates a single Alice in Wonderland question.

        Args:
            rng (Random): Random number generator; all randomness for the
                item is drawn from it.

        Returns:
            dict: ``"question"`` (str), ``"answer"`` (int) and ``"metadata"``
            (a dict holding the task type's string value).

        Raises:
            ValueError: If the sampled task type has no generation rule.
        """
        # Draw order is kept stable (task, names, counts, template) so a
        # given seed keeps selecting the same task, names and counts.
        task_type = rng.choice(self.config.task_types)
        female_name = rng.choice(self.config.female_names)
        male_name = rng.choice(self.config.male_names)
        max_entities = self.config.max_entities

        if task_type == TaskType.SIBLINGS:
            num_brothers = rng.randint(1, max_entities)
            num_sisters = rng.randint(1, max_entities)

            # A brother's sisters are the girl's sisters plus the girl herself.
            answer = num_sisters + 1
            template = rng.choice(self.templates[TaskType.SIBLINGS])
            question = template.substitute(
                female_name=female_name,
                male_name=male_name,
                num_brothers=num_brothers,
                num_sisters=num_sisters,
            )
        elif task_type == TaskType.FRIENDS:
            num_male = rng.randint(1, max_entities)
            num_female = rng.randint(1, max_entities)

            # The male friend is friends with every female friend and with
            # the girl herself.
            answer = num_female + 1
            template = rng.choice(self.templates[TaskType.FRIENDS])
            question = template.substitute(
                female_name=female_name,
                male_name=male_name,
                num_male=num_male,
                num_female=num_female,
            )
        elif task_type == TaskType.COLLEAGUES:
            num_male_colleagues_alice_circle = rng.randint(1, max_entities)
            num_female_colleagues_alice_circle = rng.randint(1, max_entities)
            num_male_colleagues_bob_circle = rng.randint(1, max_entities)
            num_female_colleagues_bob_circle = rng.randint(1, max_entities)

            # Matilda is counted among the female colleagues of both circles.
            # In the first circle she works with the other (F_a - 1) women
            # plus the woman at its centre, i.e. F_a women. In the second
            # circle she works with the other (F_b - 1) women; its centre is
            # male. The circles share only Matilda, so nothing is counted twice.
            answer = (
                num_female_colleagues_alice_circle
                + num_female_colleagues_bob_circle
                - 1
            )
            template = rng.choice(self.templates[TaskType.COLLEAGUES])
            question = template.substitute(
                female_name=female_name,
                male_name=male_name,
                num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
                num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
                num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
                num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
            )
        else:
            # Previously this fell through to an UnboundLocalError on
            # `question`; fail with an explicit message instead.
            raise ValueError(f"unsupported task type: {task_type!r}")

        return {
            "question": question,
            "answer": answer,
            "metadata": {
                "task_type": task_type.value,
            },
        }

    def __getitem__(self, idx: int) -> dict:
        """Return the item at ``idx``.

        A fresh ``Random`` seeded with ``self.seed + idx`` is used, so the
        same index yields the same item for a given dataset seed.
        """
        rng = Random(self.seed + idx)
        return self._get_aiw(rng)


# Register the dataset class and its config class under the name "aiw"
# (presumably so the factory can look them up by name; confirm in ..factory).
register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig)
99 changes: 99 additions & 0 deletions tests/test_aiw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest

from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType


def test_aiw_config_validation():
    """Test that invalid configs raise appropriate errors"""
    invalid_overrides = [
        {"male_names": []},  # Empty male names
        {"female_names": []},  # Empty female names
        {"female_names": ["Mary", "Jane"]},  # No Alice
        {"task_types": []},  # No task types
        {"max_entities": 0},  # Non-positive entity bound
    ]

    for overrides in invalid_overrides:
        # Build the config outside the raises block so the test only passes
        # when validate() itself rejects the configuration.
        config = AliceInWonderlandConfig(**overrides)
        with pytest.raises(AssertionError):
            config.validate()


def test_aiw_deterministic():
    """Two datasets built from one seeded config must agree item by item."""
    config = AliceInWonderlandConfig(seed=42, size=10)
    first = AliceInWonderlandDataset(config)
    second = AliceInWonderlandDataset(config)

    total = len(first)
    for index in range(total):
        left = first[index]
        right = second[index]
        assert left == right


def test_aiw_items():
    """Test basic properties of generated items"""
    config = AliceInWonderlandConfig(size=50, seed=42)
    dataset = AliceInWonderlandDataset(config)
    known_task_types = {task.value for task in TaskType}

    for i in range(len(dataset)):
        item = dataset[i]
        assert isinstance(item, dict)
        assert "question" in item
        assert "answer" in item
        assert "metadata" in item

        # Verify answer is numeric and positive
        answer = int(item["answer"])
        assert answer > 0

        # Verify question contains at least one female name
        female_names = config.female_names
        assert any(name in item["question"] for name in female_names)

        # Verify question task type characteristics. The metadata stores the
        # enum's string value, so every branch compares against `.value`
        # (comparing against the enum member itself never matches).
        task_type = item["metadata"]["task_type"]
        assert task_type in known_task_types
        if task_type == TaskType.SIBLINGS.value:
            assert any(phrase in item["question"]
                       for phrase in ["brothers", "sisters"])
        elif task_type == TaskType.FRIENDS.value:
            assert "friends" in item["question"]
        elif task_type == TaskType.COLLEAGUES.value:
            assert "colleagues" in item["question"]


def test_aiw_iteration():
    """Iterating the dataset yields `size` items, repeatably."""
    config = AliceInWonderlandConfig(size=5, seed=42)
    dataset = AliceInWonderlandDataset(config)
    expected_count = config.size

    # Explicit for-loop iteration
    collected = []
    for entry in dataset:
        collected.append(entry)
    assert len(collected) == expected_count

    # Conversion through list()
    collected = list(dataset)
    assert len(collected) == expected_count

    # Two further passes over the same dataset must agree
    pass_one = list(dataset)
    pass_two = list(dataset)
    assert pass_one == pass_two


def test_aiw_random_ranges():
    """Test that generated numbers stay within expected ranges"""
    max_entities = 12
    config = AliceInWonderlandConfig(
        size=30, seed=42, max_entities=max_entities)
    dataset = AliceInWonderlandDataset(config)

    for item in dataset:
        question = item["question"]
        numbers = [int(n) for n in question.split() if n.isdigit()]

        # Every template states at least two counts; an empty list would let
        # the range check below pass vacuously.
        assert numbers, f"No numbers found in question: {question}"

        # Counts are drawn from 1..max_entities (inclusive), so the bound is
        # tied to the configured value rather than a hard-coded literal.
        assert all(
            1 <= n <= max_entities for n in numbers), f"Numbers out of range: {numbers}"

0 comments on commit 02cfa95

Please sign in to comment.