From 8ccc4d7b0cf57c3aaa506b6546d79fd999e0f4fc Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 5 Mar 2025 22:42:46 +0100 Subject: [PATCH] feat(env): Course Schedule Curriculum (#266) * course schedule curriculum * update levels * update comments * lint --- reasoning_gym/graphs/__init__.py | 3 +- reasoning_gym/graphs/course_schedule.py | 74 ++++++++++++++++++++----- tests/test_course_schedule.py | 52 +++++++++++++---- 3 files changed, 102 insertions(+), 27 deletions(-) diff --git a/reasoning_gym/graphs/__init__.py b/reasoning_gym/graphs/__init__.py index fc4b9f19..a2918efa 100644 --- a/reasoning_gym/graphs/__init__.py +++ b/reasoning_gym/graphs/__init__.py @@ -1,4 +1,4 @@ -from .course_schedule import CourseScheduleConfig, CourseScheduleDataset +from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset from .largest_island import LargestIslandDataset from .quantum_lock import QuantumLockConfig, QuantumLockDataset @@ -12,6 +12,7 @@ "LargestIslandDataset", "CourseScheduleDataset", "CourseScheduleConfig", + "CourseScheduleCurriculum", "ShortestPathConfig", "ShortestPathDataset", ] diff --git a/reasoning_gym/graphs/course_schedule.py b/reasoning_gym/graphs/course_schedule.py index 15497cc9..ad3b6ee3 100644 --- a/reasoning_gym/graphs/course_schedule.py +++ b/reasoning_gym/graphs/course_schedule.py @@ -10,10 +10,9 @@ from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset -MAX_NUM_COURSES = 1_000 - QUESTION_TEMPLATE = """There are a total of {num_courses} courses you have to take, labeled from 0 to {last_index}. 
You are given the following list of prerequisites, where prerequisites[i] = (a_i, b_i) indicates that you must first take course b_i if you want to take course a_i: @@ -27,25 +26,29 @@ class CourseScheduleConfig: """Configuration for Course Schedule dataset generation""" - num_courses: int = 5 # Total number of courses (ranging from 0 to num_courses - 1) + min_num_courses: int = 5 # Minimum number of courses + max_num_courses: int = 10 # Maximum number of courses + min_num_prerequisites: int = 1 # Minimum number of prerequisites (per course) max_num_prerequisites: int = 2 # Maximum number of prerequisites (per course) - p_solvable: float = 0.5 # Probability that the course schedule is solvable min_cycle_length: int = 3 # Minimum length of a cycle in the prerequisites (if unsolvable) max_cycle_length: int = 5 # Maximum length of a cycle in the prerequisites (if unsolvable) + p_solvable: float = 0.5 # Probability that the course schedule is solvable size: int = 500 # Virtual dataset size seed: Optional[int] = None def validate(self): """Validate configuration parameters""" - assert 1 <= self.num_courses <= MAX_NUM_COURSES, f"num_courses must be between 1 and {MAX_NUM_COURSES}" assert ( - 1 <= self.max_num_prerequisites <= self.num_courses - ), "max_num_prerequisites must be between 1 and num_courses" - assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1" + 3 <= self.min_num_courses <= self.max_num_courses + ), "min_num_courses must be between 3 and max_num_courses" assert ( 3 <= self.min_cycle_length <= self.max_cycle_length ), "min_cycle_length must be between 3 and max_cycle_length" + assert ( + 1 <= self.min_num_prerequisites <= self.max_num_prerequisites + ), "min_num_prerequisites must be between 0 and max_num_prerequisites" + assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1" class CourseScheduleDataset(ProceduralDataset): @@ -88,14 +91,17 @@ def _create_prerequisites(self, rng: Random, courses: list[int], solvable: bool) 
for idx in range(len(courses) - 1, 0, -1): current_course = courses[idx] available_prereqs = courses[:idx] # Only earlier courses can be prerequisites - num_prerequisites = rng.randint(0, min(len(available_prereqs), self.config.max_num_prerequisites)) + num_prerequisites = min( + len(available_prereqs), + rng.randint(self.config.min_num_prerequisites, self.config.max_num_prerequisites), + ) if num_prerequisites > 0: chosen_prereqs = rng.sample(available_prereqs, num_prerequisites) prerequisites.extend([[current_course, p] for p in chosen_prereqs]) if not solvable: # If solution should be unsolvable, create a cycle - cycle_length = rng.randint(self.config.min_cycle_length, min(self.config.max_cycle_length, len(courses))) + cycle_length = min(len(courses), rng.randint(self.config.min_cycle_length, self.config.max_cycle_length)) cycle_courses = rng.sample(courses, cycle_length) for i in range(cycle_length): prerequisites.append([cycle_courses[i], cycle_courses[(i + 1) % cycle_length]]) @@ -109,18 +115,19 @@ def __getitem__(self, idx: int) -> dict: """Generate a single Course Schedule question""" rng = Random(self.seed + idx) - courses = list(range(self.config.num_courses)) + num_courses = rng.randint(self.config.min_num_courses, self.config.max_num_courses) + courses = list(range(num_courses)) rng.shuffle(courses) solvable = rng.random() < self.config.p_solvable prerequisites = self._create_prerequisites(rng, courses, solvable) - answer = self._can_finish(self.config.num_courses, prerequisites) + answer = self._can_finish(num_courses, prerequisites) return { "question": QUESTION_TEMPLATE.format( - num_courses=self.config.num_courses, - last_index=self.config.num_courses - 1, + num_courses=num_courses, + last_index=num_courses - 1, prerequisites=str(prerequisites), ), "answer": str(answer), @@ -128,4 +135,43 @@ def __getitem__(self, idx: int) -> dict: } +class CourseScheduleCurriculum(BaseCurriculum): + def __init__(self): + 
super().__init__(CourseScheduleCurriculum.__name__, CourseScheduleConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="num_courses", + levels=[10, 50, 100, 500], + default_level=0, # Start with 10 courses + description="Number of courses in the schedule", + attr_type=AttributeType.APPEND, + min_value=3, # Ensure at least 3 courses + lower_field_name="min_num_courses", + upper_field_name="max_num_courses", + ), + RangeAttributeDefinition( + name="num_prerequisites", + levels=[2, 3, 4, 5], + default_level=0, # Start with 2 prerequisites max + description="Number of prerequisites per course", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_num_prerequisites", + upper_field_name="max_num_prerequisites", + ), + RangeAttributeDefinition( + name="cycle_length", + levels=[3, 4, 5, 6], + default_level=0, # Start with a cycle length of 3 + description="Length of a cycle in the prerequisites", + attr_type=AttributeType.APPEND, + min_value=3, + lower_field_name="min_cycle_length", + upper_field_name="max_cycle_length", + ), + ) + + register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig) diff --git a/tests/test_course_schedule.py b/tests/test_course_schedule.py index d62c41cf..bfb49c98 100644 --- a/tests/test_course_schedule.py +++ b/tests/test_course_schedule.py @@ -2,29 +2,29 @@ import pytest -from reasoning_gym.graphs.course_schedule import CourseScheduleConfig, CourseScheduleDataset +from reasoning_gym.graphs.course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset def test_course_schedule_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = CourseScheduleConfig(num_courses=-1) # Negative not allowed + config = CourseScheduleConfig(min_num_courses=2) # must be >= 3 config.validate() with pytest.raises(AssertionError): - config = CourseScheduleConfig(num_courses=0) # Zero not allowed 
+ config = CourseScheduleConfig(min_num_courses=6, max_num_courses=5) # min > max config.validate() with pytest.raises(AssertionError): - config = CourseScheduleConfig(max_num_prerequisites=-1) # Negative not allowed + config = CourseScheduleConfig(min_num_prerequisites=-1) # neg not allowed config.validate() with pytest.raises(AssertionError): - config = CourseScheduleConfig(max_num_prerequisites=0) # Zero not allowed + config = CourseScheduleConfig(min_num_prerequisites=5, max_num_prerequisites=4) # min > max config.validate() with pytest.raises(AssertionError): - config = CourseScheduleConfig(num_courses=3, max_num_prerequisites=5) # max_num_prerequisites > num_courses + config = CourseScheduleConfig(max_num_prerequisites=0) # Zero not allowed config.validate() with pytest.raises(AssertionError): @@ -60,7 +60,7 @@ def test_course_schedule_dataset_deterministic(): def test_course_schedule_dataset_items(): """Test basic properties of generated items""" - config = CourseScheduleConfig(num_courses=15, size=10, seed=42) + config = CourseScheduleConfig(max_num_courses=15, size=10, seed=42) dataset = CourseScheduleDataset(config) for i in range(len(dataset)): @@ -83,13 +83,12 @@ def test_course_schedule_dataset_items(): solution = item["metadata"]["solution"] # Solution obtained from topological sort # Verify metadata - assert len(courses) == config.num_courses - assert max(courses) == config.num_courses - 1 - assert len(prerequisites) <= config.max_num_prerequisites * config.num_courses + assert len(courses) <= config.max_num_courses + assert len(prerequisites) <= config.max_num_prerequisites * len(courses) assert all(len(prereq) == 2 for prereq in prerequisites) for course, prereq in prerequisites: - assert course < config.num_courses - assert prereq < config.num_courses + assert course < len(courses) + assert prereq < len(courses) assert course != prereq assert solution == solvable @@ -125,3 +124,32 @@ def test_course_schedule_answer(): # Indirect cycle of length 3 
prerequisites = [[0, 1], [1, 2], [2, 0]] assert dataset._can_finish(num_courses=3, prerequisites=prerequisites) == False + + +def test_course_schedule_curriculum(): + curriculum = CourseScheduleCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: CourseScheduleConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_num_courses == 10 and base_cfg.max_num_courses == 10 + assert base_cfg.min_num_prerequisites == 2 and base_cfg.max_num_prerequisites == 2 + assert base_cfg.min_cycle_length == 3 and base_cfg.max_cycle_length == 3 + + # test incrementing attribute levels + curriculum.increment_attr_level("num_courses") + curriculum.increment_attr_level("num_prerequisites") + curriculum.increment_attr_level("cycle_length") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_num_courses == 10 and increased_cfg.max_num_courses == 50 + assert increased_cfg.min_num_prerequisites == 2 and increased_cfg.max_num_prerequisites == 3 + assert increased_cfg.min_cycle_length == 3 and increased_cfg.max_cycle_length == 4 + + # test decrementing attribute level for num_courses again + curriculum.decrement_attr_level("num_courses") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_num_courses == 10 and partially_decreased_cfg.max_num_courses == 10 + assert partially_decreased_cfg.min_num_prerequisites == 2 and partially_decreased_cfg.max_num_prerequisites == 3 + assert partially_decreased_cfg.min_cycle_length == 3 and partially_decreased_cfg.max_cycle_length == 4