From 86525f6401c99fc241aa4484ce5448b1654c83e6 Mon Sep 17 00:00:00 2001 From: rishabhranawat Date: Sat, 1 Feb 2025 11:37:50 -0800 Subject: [PATCH 1/5] [aiw] basic version of alice-in-wonderland procedural dataset --- reasoning_gym/logic/aiw.py | 183 +++++++++++++++++++++++++++++++++++++ tests/test_aiw.py | 107 ++++++++++++++++++++++ 2 files changed, 290 insertions(+) create mode 100644 reasoning_gym/logic/aiw.py create mode 100644 tests/test_aiw.py diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py new file mode 100644 index 00000000..bee47a89 --- /dev/null +++ b/reasoning_gym/logic/aiw.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass, field +from typing import List, Optional +from enum import Enum +from random import Random +from string import Template + +from ..factory import ProceduralDataset, register_dataset + + +class TaskType(Enum): + """Defines the type of task for the Alice in Wonderland dataset.""" + SIBLINGS = "siblings" + FRIENDS = "friends" + + +class OutputFormat(Enum): + """Defines the output format for the generated questions.""" + PLAIN = "plain" + RESTRICTED = "restricted" + THINKING = "thinking" + + +@dataclass +class AliceInWonderlandConfig: + """Configuration options for the Alice in Wonderland dataset. + + Attributes: + male_names (List[str]): List of male names to use in questions. + female_names (List[str]): List of female names to use in questions. Must include 'Alice'. + task_types (List[TaskType]): List of task types to include in dataset. + output_formats (List[OutputFormat]): List of output formats to include in dataset. + seed (Optional[int]): Seed for random number generation. + size (int): Number of samples in the dataset. + max_entities (int): Max number of siblings/friends in questions. + """ + male_names: List[str] = field( + default_factory=lambda: [ + "James", "John", "Robert", "Michael", "William", "David", + "Richard", "Joseph", "Thomas", "Charles", "Bob" + ] + ) + female_names: List[str] = field( + default_factory=lambda: [ + "Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", + "Barbara", "Susan", "Jessica", "Sarah", "Margaret", "Alice" + ] + ) + task_types: List[TaskType] = field( + default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS] + ) + output_formats: List[OutputFormat] = field( + default_factory=lambda: [ + OutputFormat.PLAIN, + OutputFormat.RESTRICTED, + OutputFormat.THINKING, + ] + ) + seed: Optional[int] = None + size: int = 10 + max_entities: int = 6 # Added max_entities + + def validate(self) -> None: + """Validates the configuration parameters.""" + assert len(self.male_names) > 0, "must provide male names" + assert len(self.female_names) > 0, "must provide female names" + assert "Alice" in self.female_names, "'Alice' must be in female names" + assert len(self.task_types) > 0, "must provide at least one task type" + assert len( + self.output_formats) > 0, "must provide at least one output format" + assert self.max_entities > 0, "max_entities must be positive" + + +class AliceInWonderlandDataset(ProceduralDataset): + """ + A procedural dataset inspired by the "Alice in Wonderland" paper. + + The dataset is inspired by the following paper: + @inproceedings{nezhurina2024alice, + title={Alice in Wonderland: Simple Tasks Reveal Severe Generalization and + Basic Reasoning Deficits in State-Of-the-Art Large Language Models}, + author={Marianna Nezhurina and Lucia Cipolina-Kun and Mehdi Cherti and + Jenia Jitsev}, + booktitle={NeurIPS 2024 Workshop on Scientific Methods for Understanding + Deep Learning}, + year={2024}, + url={https://openreview.net/forum?id=Mkl7dzjYiW} + } + + """ + def __init__(self, config: AliceInWonderlandConfig): + super().__init__(config=config, seed=config.seed, size=config.size) + self.templates = { + TaskType.SIBLINGS: [ + Template( + "$female_name has $num_brothers brothers and she also has " + "$num_sisters sisters. How many sisters does " + "$female_name's brother have?" + ), + Template( + "$female_name has $num_sisters sisters and she also has " + "$num_brothers brothers. How many sisters does " + "$male_name's brother have?" + ), + ], + TaskType.FRIENDS: [ + Template( + "$female_name has $num_male male friends and she also has " + "$num_female female friends. They all are friends with each " + "other and have no other friends aside. How many female " + "friends does $male_name, a male friend of $female_name, " + "have?" + ) + ], + } + + self.format_templates = { + OutputFormat.PLAIN: Template("$question"), + OutputFormat.RESTRICTED: Template( + "$question To answer the question, DO NOT OUTPUT ANY TEXT EXCEPT " + 'following format that contains final answer: "### Answer:"' + ), + OutputFormat.THINKING: Template( + "$question Before providing answer to this problem, think " + "carefully step by step and double check the path to the " + 'correct solution for any mistakes. Provide then the final ' + 'answer in following form: "### Answer:"' + ), + } + + def _get_aiw(self, rng: Random) -> dict: + """Generates a single Alice in Wonderland question. + + Args: + rng (Random): Random number generator. + + Returns: + dict: A dictionary containing the generated question, the right answer + and a description of the example. + """ + task_type = rng.choice(self.config.task_types) + output_format = rng.choice(self.config.output_formats) + female_name = rng.choice(self.config.female_names) + male_name = rng.choice(self.config.male_names) + + if task_type == TaskType.SIBLINGS: + num_brothers = rng.randint(1, self.config.max_entities) + num_sisters = rng.randint(1, self.config.max_entities) + answer = num_sisters + 1 + template = rng.choice(self.templates[TaskType.SIBLINGS]) + question = template.substitute( + female_name=female_name, + male_name=male_name, + num_brothers=num_brothers, + num_sisters=num_sisters, + ) + elif task_type == TaskType.FRIENDS: + num_male = rng.randint(1, self.config.max_entities) + num_female = rng.randint(1, self.config.max_entities) + answer = num_female + 1 + template = rng.choice(self.templates[TaskType.FRIENDS]) + question = template.substitute( + female_name=female_name, + male_name=male_name, + num_male=num_male, + num_female=num_female, + ) + + formatted_question = self.format_templates[output_format].substitute( + question=question + ) + + return { + "prompt": formatted_question, + "right_answer": str(answer), + "description": f"{task_type.value} variation, {output_format.value} format", + } + + def __getitem__(self, idx: int) -> dict: + rng = Random(self.seed + idx) + return self._get_aiw(rng) + + +register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig) \ No newline at end of file diff --git a/tests/test_aiw.py b/tests/test_aiw.py new file mode 100644 index 00000000..5a0fbaf5 --- /dev/null +++ b/tests/test_aiw.py @@ -0,0 +1,107 @@ +import pytest + +from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType, OutputFormat + +def test_aiw_config_validation(): + """Test that invalid configs raise appropriate errors""" + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(male_names=[]) # Empty male names + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(female_names=[]) # Empty female names + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(task_types=[]) # No task types + config.validate() + + with pytest.raises(AssertionError): + config = AliceInWonderlandConfig(output_formats=[]) # No output formats + config.validate() + +def test_aiw_deterministic(): + """Test that dataset generates same items with same seed""" + config = AliceInWonderlandConfig(seed=42, size=10) + dataset1 = AliceInWonderlandDataset(config) + dataset2 = AliceInWonderlandDataset(config) + + for i in range(len(dataset1)): + assert dataset1[i] == dataset2[i] + +def test_aiw_items(): + """Test basic properties of generated items""" + config = AliceInWonderlandConfig(size=50, seed=42) + dataset = AliceInWonderlandDataset(config) + + for i in range(len(dataset)): + item = dataset[i] + assert isinstance(item, dict) + assert "prompt" in item + assert "right_answer" in item + assert "description" in item + + # Verify answer is numeric and positive + answer = int(item["right_answer"]) + assert answer > 0 + + # Verify question contains at least one female name + female_names = config.female_names + assert any(name in item["prompt"] for name in female_names) + + # Verify question format + if TaskType.SIBLINGS.value in item["description"]: + assert any(phrase in item["prompt"] for phrase in ["brothers", "sisters"]) + elif TaskType.FRIENDS.value in item["description"]: + assert "friends" in item["prompt"] + + # Verify output format + if OutputFormat.RESTRICTED.value in item["description"]: + assert "DO NOT OUTPUT ANY TEXT EXCEPT" in item["prompt"] + elif OutputFormat.THINKING.value in item["description"]: + assert "think carefully step by step" in item["prompt"] + +def test_aiw_iteration(): + """Test that iteration works correctly""" + config = AliceInWonderlandConfig(size=5, seed=42) + dataset = AliceInWonderlandDataset(config) + + # Test manual iteration + items = [] + for item in dataset: + items.append(item) + assert len(items) == config.size + + # Test list conversion + items = list(dataset) + assert len(items) == config.size + + # Test multiple iterations yield same results + first_items = list(dataset) + second_items = list(dataset) + assert first_items == second_items + +def test_aiw_random_ranges(): + """Test that generated numbers stay within expected ranges""" + config = AliceInWonderlandConfig(size=30, seed=42, max_entities=12) + dataset = AliceInWonderlandDataset(config) + + for item in dataset: + prompt = item["prompt"] + numbers = [int(n) for n in prompt.split() if n.isdigit()] + + # Check all numbers are in reasonable range (1-6 as per implementation) + assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}" + +def test_output_format_is_correct(): + """Test that the output format adheres to the user input""" + config = AliceInWonderlandConfig(size=30, seed=42, output_formats=[OutputFormat.THINKING]) + dataset = AliceInWonderlandDataset(config) + + for item in dataset: + prompt = item["prompt"] + assert "think carefully step by step" in item["prompt"] From 57a1b5c3538054bac7def5639c5fc07c496797f0 Mon Sep 17 00:00:00 2001 From: rishabhranawat Date: Sat, 1 Feb 2025 12:04:44 -0800 Subject: [PATCH 2/5] [aiw] add colleague variation --- reasoning_gym/logic/aiw.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index bee47a89..816b5ae3 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -11,6 +11,7 @@ class TaskType(Enum): """Defines the type of task for the Alice in Wonderland dataset.""" SIBLINGS = "siblings" FRIENDS = "friends" + COLLEAGUES = "colleagues" # Added colleagues task class OutputFormat(Enum): @@ -31,7 +32,7 @@ class AliceInWonderlandConfig: output_formats (List[OutputFormat]): List of output formats to include in dataset. seed (Optional[int]): Seed for random number generation. size (int): Number of samples in the dataset. - max_entities (int): Max number of siblings/friends in questions. + max_entities (int): Max number of siblings/friends/colleagues in questions. """ male_names: List[str] = field( default_factory=lambda: [ @@ -46,7 +47,7 @@ class AliceInWonderlandConfig: ] ) task_types: List[TaskType] = field( - default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS] + default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues ) output_formats: List[OutputFormat] = field( default_factory=lambda: [ @@ -111,6 +112,20 @@ def __init__(self, config: AliceInWonderlandConfig): "have?" ) ], + TaskType.COLLEAGUES: [ # New colleagues templates + Template( + "$female_name has $num_male_colleagues_alice_circle male colleagues and she also has " + "$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. " + "All these mentioned persons around $female_name are colleagues of each other. " + "$male_name has $num_male_colleagues_bob_circle male colleagues " + "and $num_female_colleagues_bob_circle female colleagues in total. " + "All these mentioned persons around $male_name are colleagues of each other. " + "The people in the circle around $male_name do not have " + "other colleagues aside - with the only exception of Matilda. " + "She is colleague of $male_name and she is also colleague of $female_name, " + "being part of $female_name's circle. How many female colleagues does Matilda have?" + ), + ], } self.format_templates = { @@ -145,6 +160,7 @@ def _get_aiw(self, rng: Random) -> dict: if task_type == TaskType.SIBLINGS: num_brothers = rng.randint(1, self.config.max_entities) num_sisters = rng.randint(1, self.config.max_entities) + answer = num_sisters + 1 template = rng.choice(self.templates[TaskType.SIBLINGS]) question = template.substitute( @@ -156,6 +172,7 @@ def _get_aiw(self, rng: Random) -> dict: elif task_type == TaskType.FRIENDS: num_male = rng.randint(1, self.config.max_entities) num_female = rng.randint(1, self.config.max_entities) + answer = num_female + 1 template = rng.choice(self.templates[TaskType.FRIENDS]) question = template.substitute( @@ -164,6 +181,22 @@ def _get_aiw(self, rng: Random) -> dict: num_male=num_male, num_female=num_female, ) + elif task_type == TaskType.COLLEAGUES: + num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities) + num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities) + num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities) + num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities) + + answer = num_female_colleagues_alice_circle + 1 + template = rng.choice(self.templates[TaskType.COLLEAGUES]) + question = template.substitute( + female_name=female_name, + male_name=male_name, + num_male_colleagues_alice_circle=num_male_colleagues_alice_circle, + num_female_colleagues_alice_circle=num_female_colleagues_alice_circle, + num_male_colleagues_bob_circle=num_male_colleagues_bob_circle, + num_female_colleagues_bob_circle=num_female_colleagues_bob_circle + ) formatted_question = self.format_templates[output_format].substitute( question=question From 3d42e84807cc525cc1b21ffb857a69981b6aabc0 Mon Sep 17 00:00:00 2001 From: rishabhranawat Date: Sat, 1 Feb 2025 16:30:05 -0800 Subject: [PATCH 3/5] [aiw] remove output_formats style and change return type to a standard format --- reasoning_gym/logic/aiw.py | 47 +++++++++++-------------- tests/test_aiw.py | 70 +++++++++++++++++--------------------- 2 files changed, 51 insertions(+), 66 deletions(-) diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index 816b5ae3..2ce3a13b 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -29,7 +29,6 @@ class AliceInWonderlandConfig: male_names (List[str]): List of male names to use in questions. female_names (List[str]): List of female names to use in questions. Must include 'Alice'. task_types (List[TaskType]): List of task types to include in dataset. - output_formats (List[OutputFormat]): List of output formats to include in dataset. seed (Optional[int]): Seed for random number generation. size (int): Number of samples in the dataset. max_entities (int): Max number of siblings/friends/colleagues in questions. @@ -47,14 +46,8 @@ class AliceInWonderlandConfig: ] ) task_types: List[TaskType] = field( - default_factory=lambda: [TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues - ) - output_formats: List[OutputFormat] = field( default_factory=lambda: [ - OutputFormat.PLAIN, - OutputFormat.RESTRICTED, - OutputFormat.THINKING, - ] + TaskType.SIBLINGS, TaskType.FRIENDS, TaskType.COLLEAGUES] # Added Colleagues ) seed: Optional[int] = None size: int = 10 @@ -66,8 +59,6 @@ def validate(self) -> None: assert len(self.female_names) > 0, "must provide female names" assert "Alice" in self.female_names, "'Alice' must be in female names" assert len(self.task_types) > 0, "must provide at least one task type" - assert len( - self.output_formats) > 0, "must provide at least one output format" assert self.max_entities > 0, "max_entities must be positive" @@ -88,6 +79,7 @@ class AliceInWonderlandDataset(ProceduralDataset): } """ + def __init__(self, config: AliceInWonderlandConfig): super().__init__(config=config, seed=config.seed, size=config.size) self.templates = { @@ -118,12 +110,12 @@ def __init__(self, config: AliceInWonderlandConfig): "$num_female_colleagues_alice_circle female colleagues. These are all colleagues that $female_name has. " "All these mentioned persons around $female_name are colleagues of each other. " "$male_name has $num_male_colleagues_bob_circle male colleagues " - "and $num_female_colleagues_bob_circle female colleagues in total. " + "and $num_female_colleagues_bob_circle female colleagues in total. " "All these mentioned persons around $male_name are colleagues of each other. " - "The people in the circle around $male_name do not have " - "other colleagues aside - with the only exception of Matilda. " + "The people in the circle around $male_name do not have " + "other colleagues aside - with the only exception of Matilda. " "She is colleague of $male_name and she is also colleague of $female_name, " - "being part of $female_name's circle. How many female colleagues does Matilda have?" + "being part of $female_name's circle. How many female colleagues does Matilda have?" ), ], } @@ -153,7 +145,6 @@ def _get_aiw(self, rng: Random) -> dict: and a description of the example. """ task_type = rng.choice(self.config.task_types) - output_format = rng.choice(self.config.output_formats) female_name = rng.choice(self.config.female_names) male_name = rng.choice(self.config.male_names) @@ -182,10 +173,14 @@ def _get_aiw(self, rng: Random) -> dict: num_female=num_female, ) elif task_type == TaskType.COLLEAGUES: - num_male_colleagues_alice_circle = rng.randint(1, self.config.max_entities) - num_female_colleagues_alice_circle = rng.randint(1, self.config.max_entities) - num_male_colleagues_bob_circle = rng.randint(1, self.config.max_entities) - num_female_colleagues_bob_circle = rng.randint(1, self.config.max_entities) + num_male_colleagues_alice_circle = rng.randint( + 1, self.config.max_entities) + num_female_colleagues_alice_circle = rng.randint( + 1, self.config.max_entities) + num_male_colleagues_bob_circle = rng.randint( + 1, self.config.max_entities) + num_female_colleagues_bob_circle = rng.randint( + 1, self.config.max_entities) answer = num_female_colleagues_alice_circle + 1 template = rng.choice(self.templates[TaskType.COLLEAGUES]) @@ -198,14 +193,12 @@ def _get_aiw(self, rng: Random) -> dict: num_female_colleagues_bob_circle=num_female_colleagues_bob_circle ) - formatted_question = self.format_templates[output_format].substitute( - question=question - ) - return { - "prompt": formatted_question, - "right_answer": str(answer), - "description": f"{task_type.value} variation, {output_format.value} format", + "question": question, + "answer": answer, + "metadata": { + "task_type": task_type.value + } } def __getitem__(self, idx: int) -> dict: @@ -213,4 +206,4 @@ def __getitem__(self, idx: int) -> dict: return self._get_aiw(rng) -register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig) \ No newline at end of file +register_dataset("aiw", AliceInWonderlandDataset, AliceInWonderlandConfig) diff --git a/tests/test_aiw.py b/tests/test_aiw.py index 5a0fbaf5..279fcc2c 100644 --- a/tests/test_aiw.py +++ b/tests/test_aiw.py @@ -2,6 +2,7 @@ from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType, OutputFormat + def test_aiw_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): @@ -11,18 +12,16 @@ def test_aiw_config_validation(): with pytest.raises(AssertionError): config = AliceInWonderlandConfig(female_names=[]) # Empty female names config.validate() - + with pytest.raises(AssertionError): - config = AliceInWonderlandConfig(female_names=["Mary", "Jane"]) # No Alice + config = AliceInWonderlandConfig( + female_names=["Mary", "Jane"]) # No Alice config.validate() - + with pytest.raises(AssertionError): config = AliceInWonderlandConfig(task_types=[]) # No task types config.validate() - - with pytest.raises(AssertionError): - config = AliceInWonderlandConfig(output_formats=[]) # No output formats - config.validate() + def test_aiw_deterministic(): """Test that dataset generates same items with same seed""" @@ -33,6 +32,7 @@ def test_aiw_deterministic(): for i in range(len(dataset1)): assert dataset1[i] == dataset2[i] + def test_aiw_items(): """Test basic properties of generated items""" config = AliceInWonderlandConfig(size=50, seed=42) @@ -41,29 +41,28 @@ def test_aiw_items(): for i in range(len(dataset)): item = dataset[i] assert isinstance(item, dict) - assert "prompt" in item - assert "right_answer" in item - assert "description" in item - + assert "question" in item + assert "answer" in item + assert "metadata" in item + # Verify answer is numeric and positive - answer = int(item["right_answer"]) + answer = int(item["answer"]) assert answer > 0 - + # Verify question contains at least one female name female_names = config.female_names - assert any(name in item["prompt"] for name in female_names) - - # Verify question format - if TaskType.SIBLINGS.value in item["description"]: - assert any(phrase in item["prompt"] for phrase in ["brothers", "sisters"]) - elif TaskType.FRIENDS.value in item["description"]: - assert "friends" in item["prompt"] - - # Verify output format - if OutputFormat.RESTRICTED.value in item["description"]: - assert "DO NOT OUTPUT ANY TEXT EXCEPT" in item["prompt"] - elif OutputFormat.THINKING.value in item["description"]: - assert "think carefully step by step" in item["prompt"] + assert any(name in item["question"] for name in female_names) + + # Verify question task type characteristics + task_type = item["metadata"]["task_type"] + if task_type == TaskType.SIBLINGS.value: + assert any(phrase in item["question"] + for phrase in ["brothers", "sisters"]) + elif task_type == TaskType.FRIENDS.value: + assert "friends" in item["question"] + elif task_type == TaskType.COLLEAGUES: + assert "colleagues" in item["question"] + def test_aiw_iteration(): """Test that iteration works correctly""" @@ -85,23 +84,16 @@ def test_aiw_iteration(): second_items = list(dataset) assert first_items == second_items + def test_aiw_random_ranges(): """Test that generated numbers stay within expected ranges""" config = AliceInWonderlandConfig(size=30, seed=42, max_entities=12) dataset = AliceInWonderlandDataset(config) for item in dataset: - prompt = item["prompt"] - numbers = [int(n) for n in prompt.split() if n.isdigit()] - - # Check all numbers are in reasonable range (1-6 as per implementation) - assert all(1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}" + question = item["question"] + numbers = [int(n) for n in question.split() if n.isdigit()] -def test_output_format_is_correct(): - """Test that the output format adheres to the user input""" - config = AliceInWonderlandConfig(size=30, seed=42, output_formats=[OutputFormat.THINKING]) - dataset = AliceInWonderlandDataset(config) - - for item in dataset: - prompt = item["prompt"] - assert "think carefully step by step" in item["prompt"] + # Check all numbers are in reasonable range (1-6 as per implementation) + assert all( + 1 <= n <= 12 for n in numbers), f"Numbers out of range: {numbers}" From f8696d6d22fcbfa9362cd0fcc97b22af52de1d25 Mon Sep 17 00:00:00 2001 From: rishabhranawat Date: Sat, 1 Feb 2025 16:31:45 -0800 Subject: [PATCH 4/5] [aiw] remove output format enum --- reasoning_gym/logic/aiw.py | 7 ------- tests/test_aiw.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index 2ce3a13b..ebeb50a7 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -14,13 +14,6 @@ class TaskType(Enum): COLLEAGUES = "colleagues" # Added colleagues task -class OutputFormat(Enum): - """Defines the output format for the generated questions.""" - PLAIN = "plain" - RESTRICTED = "restricted" - THINKING = "thinking" - - @dataclass class AliceInWonderlandConfig: """Configuration options for the Alice in Wonderland dataset. diff --git a/tests/test_aiw.py b/tests/test_aiw.py index 279fcc2c..bf556cbe 100644 --- a/tests/test_aiw.py +++ b/tests/test_aiw.py @@ -1,6 +1,6 @@ import pytest -from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType, OutputFormat +from reasoning_gym.logic.aiw import AliceInWonderlandConfig, AliceInWonderlandDataset, TaskType def test_aiw_config_validation(): From 5f6d6153696c3b213ef287e6f4e0a226d3439480 Mon Sep 17 00:00:00 2001 From: rishabhranawat Date: Sat, 1 Feb 2025 16:33:08 -0800 Subject: [PATCH 5/5] [aiw] remove output format template --- reasoning_gym/logic/aiw.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/reasoning_gym/logic/aiw.py b/reasoning_gym/logic/aiw.py index ebeb50a7..7a27b6d0 100644 --- a/reasoning_gym/logic/aiw.py +++ b/reasoning_gym/logic/aiw.py @@ -113,20 +113,6 @@ def __init__(self, config: AliceInWonderlandConfig): ], } - self.format_templates = { - OutputFormat.PLAIN: Template("$question"), - OutputFormat.RESTRICTED: Template( - "$question To answer the question, DO NOT OUTPUT ANY TEXT EXCEPT " - 'following format that contains final answer: "### Answer:"' - ), - OutputFormat.THINKING: Template( - "$question Before providing answer to this problem, think " - "carefully step by step and double check the path to the " - 'correct solution for any mistakes. Provide then the final ' - 'answer in following form: "### Answer:"' - ), - } - def _get_aiw(self, rng: Random) -> dict: """Generates a single Alice in Wonderland question.