Skip to content

Commit

Permalink
post merge lint
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Feb 2, 2025
1 parent 02cfa95 commit f396d3d
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 64 deletions.
47 changes: 46 additions & 1 deletion GALLERY.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ This gallery shows examples from all available datasets using their default conf

## Available Datasets
- [advanced_geometry](#advanced_geometry)
- [aiw](#aiw)
- [base_conversion](#base_conversion)
- [basic_arithmetic](#basic_arithmetic)
- [bf](#bf)
Expand Down Expand Up @@ -73,6 +74,50 @@ Metadata: {'A': (6, 7), 'B': (-7, -5), 'C': (2, -3), 'incircle_radius_exact': 's
````

### aiw
A procedural dataset inspired by the "Alice in Wonderland" paper.

The dataset is inspired by the following paper:
@inproceedings{nezhurina2024alice,
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
Jenia Jitsev},
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
Deep Learning},
year={2024},
url={https://openreview.net/forum?id=Mkl7dzjYiW}
}

Default configuration:
```python
male_names = ['James', 'John', 'Robert', 'Michael', 'William', 'David', 'Richard', 'Joseph', 'Thomas', 'Charles', 'Bob']
female_names = ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah', 'Margaret', 'Alice']
task_types = [<TaskType.SIBLINGS: 'siblings'>, <TaskType.FRIENDS: 'friends'>, <TaskType.COLLEAGUES: 'colleagues'>]
seed = 42
size = 10
max_entities = 6
```

Example tasks:
````
Example 1:
Question: Patricia has 6 male colleagues and she also has 3 female colleagues. These are all colleagues that Patricia has. All these mentioned persons around Patricia are colleagues of each other. James has 2 male colleagues and 2 female colleagues in total. All these mentioned persons around James are colleagues of each other. The people in the circle around James do not have other colleagues aside - with the only exception of Matilda. She is colleague of James and she is also colleague of Patricia, being part of Patricia's circle. How many female colleagues does Matilda have?
Answer: 4
Metadata: {'task_type': 'colleagues'}
Example 2:
Question: Elizabeth has 4 brothers and she also has 3 sisters. How many sisters does Elizabeth's brother have?
Answer: 4
Metadata: {'task_type': 'siblings'}
Example 3:
Question: Sarah has 6 male friends and she also has 1 female friends. They all are friends with each other and have no other friends aside. How many female friends does Thomas, a male friend of Sarah, have?
Answer: 2
Metadata: {'task_type': 'friends'}
````

### base_conversion
Generates base conversion tasks

Expand Down Expand Up @@ -1548,7 +1593,7 @@ Metadata: {'task_type': 'datetime_tz', 'start_time': datetime.datetime(2964, 6,
Example 2:
Question: A video call started at 09:44 and ended at 12:22. How long was the call? Answer in HH:MM.
Answer: 02:38
Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 1, 9, 44), 'end_time': datetime.datetime(2025, 2, 1, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'}
Metadata: {'task_type': 'time', 'start_time': datetime.datetime(2025, 2, 2, 9, 44), 'end_time': datetime.datetime(2025, 2, 2, 12, 22), 'format': '%H:%M', 'expected_format': 'HH:MM'}
Example 3:
Question: Calculate the time difference between Sat Dec 22 2677 and Thu Mar 21 2678. Express the result in D days.
Expand Down
12 changes: 7 additions & 5 deletions reasoning_gym/geometry/advanced_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ class AdvancedGeometryConfig:

# Probability or list of tasks we want to generate
# For demonstration, we have three categories:
task_types: List[str] = field(default_factory=lambda: [
"orthocenter",
"incircle_radius",
"angle_measure",
])
task_types: List[str] = field(
default_factory=lambda: [
"orthocenter",
"incircle_radius",
"angle_measure",
]
)

def validate(self):
assert self.min_coord < self.max_coord, "min_coord must be < max_coord."
Expand Down
110 changes: 94 additions & 16 deletions reasoning_gym/graphs/family_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,100 @@ class FamilyRelationshipsConfig:

min_family_size: int = 4
max_family_size: int = 8
male_names: List[str] = field(default_factory=lambda: [
"James", "John", "Robert", "Michael", "William", "David", "Richard", "Joseph",
"Thomas", "Charles", "Peter", "Daniel", "Matthew", "Christopher", "Andrew",
"George", "Edward", "Benjamin", "Henry", "Samuel", "Alexander", "Oliver",
"Jack", "Harry", "Jacob", "Noah", "Ethan", "Lucas", "Mason", "Logan",
"Sebastian", "Theodore", "Owen", "Liam", "Aiden", "Kai", "Jayden", "Zion",
"Phoenix", "Atlas", "Axel", "Ryder", "Finn"
])
female_names: List[str] = field(default_factory=lambda: [
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan",
"Jessica", "Sarah", "Karen", "Emma", "Lisa", "Anna", "Margaret", "Victoria",
"Charlotte", "Sophia", "Isabella", "Olivia", "Ava", "Mia", "Emily",
"Abigail", "Amelia", "Eleanor", "Grace", "Alice", "Lucy", "Chloe",
"Sophie", "Lily", "Hannah", "Zoe", "Luna", "Nova", "Aria", "Willow",
"Aurora", "Sage", "River", "Winter", "Sky", "Rain"
])
male_names: List[str] = field(
default_factory=lambda: [
"James",
"John",
"Robert",
"Michael",
"William",
"David",
"Richard",
"Joseph",
"Thomas",
"Charles",
"Peter",
"Daniel",
"Matthew",
"Christopher",
"Andrew",
"George",
"Edward",
"Benjamin",
"Henry",
"Samuel",
"Alexander",
"Oliver",
"Jack",
"Harry",
"Jacob",
"Noah",
"Ethan",
"Lucas",
"Mason",
"Logan",
"Sebastian",
"Theodore",
"Owen",
"Liam",
"Aiden",
"Kai",
"Jayden",
"Zion",
"Phoenix",
"Atlas",
"Axel",
"Ryder",
"Finn",
]
)
female_names: List[str] = field(
default_factory=lambda: [
"Mary",
"Patricia",
"Jennifer",
"Linda",
"Elizabeth",
"Barbara",
"Susan",
"Jessica",
"Sarah",
"Karen",
"Emma",
"Lisa",
"Anna",
"Margaret",
"Victoria",
"Charlotte",
"Sophia",
"Isabella",
"Olivia",
"Ava",
"Mia",
"Emily",
"Abigail",
"Amelia",
"Eleanor",
"Grace",
"Alice",
"Lucy",
"Chloe",
"Sophie",
"Lily",
"Hannah",
"Zoe",
"Luna",
"Nova",
"Aria",
"Willow",
"Aurora",
"Sage",
"River",
"Winter",
"Sky",
"Rain",
]
)
seed: Optional[int] = None
size: int = 500

Expand Down
3 changes: 3 additions & 0 deletions reasoning_gym/logic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
- Syllogisms
"""

from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
from .syllogisms import SyllogismConfig, SyllogismDataset, Term

__all__ = [
"AliceInWonderlandConfig",
"AliceInWonderlandDataset",
"PropositionalLogicConfig",
"PropositionalLogicDataset",
"SyllogismConfig",
Expand Down
81 changes: 45 additions & 36 deletions reasoning_gym/logic/aiw.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from dataclasses import dataclass, field
from typing import List, Optional
from enum import Enum
from random import Random
from string import Template
from typing import List, Optional

from ..factory import ProceduralDataset, register_dataset


class TaskType(Enum):
"""Defines the type of task for the Alice in Wonderland dataset."""

SIBLINGS = "siblings"
FRIENDS = "friends"
COLLEAGUES = "colleagues" # Added colleagues task
Expand All @@ -26,21 +27,39 @@ class AliceInWonderlandConfig:
size (int): Number of samples in the dataset.
max_entities (int): Max number of siblings/friends/colleagues in questions.
"""

male_names: List[str] = field(
default_factory=lambda: [
"James", "John", "Robert", "Michael", "William", "David",
"Richard", "Joseph", "Thomas", "Charles", "Bob"
"James",
"John",
"Robert",
"Michael",
"William",
"David",
"Richard",
"Joseph",
"Thomas",
"Charles",
"Bob",
]
)
female_names: List[str] = field(
default_factory=lambda: [
"Mary", "Patricia", "Jennifer", "Linda", "Elizabeth",
"Barbara", "Susan", "Jessica", "Sarah", "Margaret", "Alice"
"Mary",
"Patricia",
"Jennifer",
"Linda",
"Elizabeth",
"Barbara",
"Susan",
"Jessica",
"Sarah",
"Margaret",
"Alice",
]
)
task_types: List[TaskType] = field(
default_factory=lambda: [
TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues
)
seed: Optional[int] = None
size: int = 10
Expand All @@ -57,19 +76,19 @@ def validate(self) -> None:

class AliceInWonderlandDataset(ProceduralDataset):
"""
A procedural dataset inspired by the "Alice in Wonderland" paper.
The dataset is inspired by the following paper:
@inproceedings{nezhurina2024alice,
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
Jenia Jitsev},
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
Deep Learning},
year={2024},
url={https://openreview.net/forum?id=Mkl7dzjYiW}
}
A procedural dataset inspired by the "Alice in Wonderland" paper.
The dataset is inspired by the following paper:
@inproceedings{nezhurina2024alice,
title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and
Basic Reasoning Deficits in State-Of-the-Art Large Language Models},
author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and
Jenia Jitsev},
booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding
Deep Learning},
year={2024},
url={https://openreview.net/forum?id=Mkl7dzjYiW}
}
"""

Expand Down Expand Up @@ -152,14 +171,10 @@ def _get_aiw(self, rng: Random) -> dict:
num_female=num_female,
)
elif task_type == TaskType.COLLEAGUES:
num_male_colleagues_alice_circle = rng.randint(
1, self.config.max_entities)
num_female_colleagues_alice_circle = rng.randint(
1, self.config.max_entities)
num_male_colleagues_bob_circle = rng.randint(
1, self.config.max_entities)
num_female_colleagues_bob_circle = rng.randint(
1, self.config.max_entities)
num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities)
num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities)
num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities)

answer = num_female_colleagues_alice_circle + 1
template = rng.choice(self.templates[TaskType.COLLEAGUES])
Expand All @@ -169,16 +184,10 @@ def _get_aiw(self, rng: Random) -> dict:
num_male_colleagues_alice_circle=num_male_colleagues_alice_circle,
num_female_colleagues_alice_circle=num_female_colleagues_alice_circle,
num_male_colleagues_bob_circle=num_male_colleagues_bob_circle,
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle
num_female_colleagues_bob_circle=num_female_colleagues_bob_circle,
)

return {
"question": question,
"answer": answer,
"metadata": {
"task_type": task_type.value
}
}
return {"question": question, "answer": answer, "metadata": {"task_type": task_type.value}}

def __getitem__(self, idx: int) -> dict:
rng = Random(self.seed + idx)
Expand Down
9 changes: 3 additions & 6 deletions tests/test_aiw.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_aiw_config_validation():
config.validate()

with pytest.raises(AssertionError):
config = AliceInWonderlandConfig(
female_names=["Mary", "Jane"]) # No Alice
config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice
config.validate()

with pytest.raises(AssertionError):
Expand Down Expand Up @@ -56,8 +55,7 @@ def test_aiw_items():
# Verify question task type characteristics
task_type = item["metadata"]["task_type"]
if task_type == TaskType.SIBLINGS.value:
assert any(phrase in item["question"]
for phrase in ["brothers", "sisters"])
assert any(phrase in item["question"] for phrase in ["brothers", "sisters"])
elif task_type == TaskType.FRIENDS.value:
assert "friends" in item["question"]
elif task_type == TaskType.COLLEAGUES:
Expand Down Expand Up @@ -95,5 +93,4 @@ def test_aiw_random_ranges():
numbers = [int(n) for n in question.split() if n.isdigit()]

# Check all numbers are in reasonable range (1-6 as per implementation)
assert all(
1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"
assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}"

0 comments on commit f396d3d

Please sign in to comment.