Skip to content

Commit

Permalink
feat(env): Course Schedule Curriculum (#266)
Browse files Browse the repository at this point in the history
* course schedule curriculum

* update levels

* update comments

* lint
  • Loading branch information
zafstojano authored Mar 5, 2025
1 parent f3ee9a9 commit 8ccc4d7
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 27 deletions.
3 changes: 2 additions & 1 deletion reasoning_gym/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .course_schedule import CourseScheduleConfig, CourseScheduleDataset
from .course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
from .largest_island import LargestIslandDataset
from .quantum_lock import QuantumLockConfig, QuantumLockDataset
Expand All @@ -12,6 +12,7 @@
"LargestIslandDataset",
"CourseScheduleDataset",
"CourseScheduleConfig",
"CourseScheduleCurriculum",
"ShortestPathConfig",
"ShortestPathDataset",
]
74 changes: 60 additions & 14 deletions reasoning_gym/graphs/course_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
from random import Random
from typing import Optional

from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset

MAX_NUM_COURSES = 1_000

QUESTION_TEMPLATE = """There are a total of {num_courses} courses you have to take, labeled from 0 to {last_index}.
You are given the following list of prerequisites, where prerequisites[i] = (a_i, b_i) indicates that you must first take course b_i if you want to take course a_i:
Expand All @@ -27,25 +26,29 @@
class CourseScheduleConfig:
"""Configuration for Course Schedule dataset generation"""

num_courses: int = 5 # Total number of courses (ranging from 0 to num_courses - 1)
min_num_courses: int = 5 # Minimum number of courses
max_num_courses: int = 10 # Maximum number of courses
min_num_prerequisites: int = 1 # Minimum number of prerequisites (per course)
max_num_prerequisites: int = 2 # Maximum number of prerequisites (per course)
p_solvable: float = 0.5 # Probability that the course schedule is solvable
min_cycle_length: int = 3 # Minimum length of a cycle in the prerequisites (if unsolvable)
max_cycle_length: int = 5 # Maximum length of a cycle in the prerequisites (if unsolvable)
p_solvable: float = 0.5 # Probability that the course schedule is solvable

size: int = 500 # Virtual dataset size
seed: Optional[int] = None

def validate(self):
"""Validate configuration parameters"""
assert 1 <= self.num_courses <= MAX_NUM_COURSES, f"num_courses must be between 1 and {MAX_NUM_COURSES}"
assert (
1 <= self.max_num_prerequisites <= self.num_courses
), "max_num_prerequisites must be between 1 and num_courses"
assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1"
3 <= self.min_num_courses <= self.max_num_courses
), "min_num_courses must be between 3 and max_num_courses"
assert (
3 <= self.min_cycle_length <= self.max_cycle_length
), "min_cycle_length must be between 3 and max_cycle_length"
assert (
1 <= self.min_num_prerequisites <= self.max_num_prerequisites
), "min_num_prerequisites must be between 0 and max_num_prerequisites"
assert 0 <= self.p_solvable <= 1, "p_solvable must be between 0 and 1"


class CourseScheduleDataset(ProceduralDataset):
Expand Down Expand Up @@ -88,14 +91,17 @@ def _create_prerequisites(self, rng: Random, courses: list[int], solvable: bool)
for idx in range(len(courses) - 1, 0, -1):
current_course = courses[idx]
available_prereqs = courses[:idx] # Only earlier courses can be prerequisites
num_prerequisites = rng.randint(0, min(len(available_prereqs), self.config.max_num_prerequisites))
num_prerequisites = min(
len(available_prereqs),
rng.randint(self.config.min_num_prerequisites, self.config.max_num_prerequisites),
)
if num_prerequisites > 0:
chosen_prereqs = rng.sample(available_prereqs, num_prerequisites)
prerequisites.extend([[current_course, p] for p in chosen_prereqs])

if not solvable:
# If solution should be unsolvable, create a cycle
cycle_length = rng.randint(self.config.min_cycle_length, min(self.config.max_cycle_length, len(courses)))
cycle_length = min(len(courses), rng.randint(self.config.min_cycle_length, self.config.max_cycle_length))
cycle_courses = rng.sample(courses, cycle_length)
for i in range(cycle_length):
prerequisites.append([cycle_courses[i], cycle_courses[(i + 1) % cycle_length]])
Expand All @@ -109,23 +115,63 @@ def __getitem__(self, idx: int) -> dict:
"""Generate a single Course Schedule question"""
rng = Random(self.seed + idx)

courses = list(range(self.config.num_courses))
num_courses = rng.randint(self.config.min_num_courses, self.config.max_num_courses)
courses = list(range(num_courses))
rng.shuffle(courses)

solvable = rng.random() < self.config.p_solvable

prerequisites = self._create_prerequisites(rng, courses, solvable)
answer = self._can_finish(self.config.num_courses, prerequisites)
answer = self._can_finish(num_courses, prerequisites)

return {
"question": QUESTION_TEMPLATE.format(
num_courses=self.config.num_courses,
last_index=self.config.num_courses - 1,
num_courses=num_courses,
last_index=num_courses - 1,
prerequisites=str(prerequisites),
),
"answer": str(answer),
"metadata": {"courses": courses, "prerequisites": prerequisites, "solution": answer, "solvable": solvable},
}


class CourseScheduleCurriculum(BaseCurriculum):
    """Curriculum describing how Course Schedule problems get harder.

    Three range attributes are declared. Each one is mapped onto a
    ``min_*`` / ``max_*`` field pair of ``CourseScheduleConfig`` through
    ``lower_field_name`` / ``upper_field_name``.
    """

    def __init__(self):
        super().__init__(CourseScheduleCurriculum.__name__, CourseScheduleConfig)

        # Number of courses in a generated schedule.
        course_count = RangeAttributeDefinition(
            name="num_courses",
            levels=[10, 50, 100, 500],
            default_level=0,  # level 0 is the first entry of `levels`, i.e. 10 courses
            description="Number of courses in the schedule",
            attr_type=AttributeType.APPEND,
            min_value=3,  # same lower bound that CourseScheduleConfig.validate() puts on min_num_courses
            lower_field_name="min_num_courses",
            upper_field_name="max_num_courses",
        )

        # Number of prerequisites drawn per course.
        # NOTE(review): min_value=0 is looser than CourseScheduleConfig.validate(),
        # which asserts min_num_prerequisites >= 1 -- confirm which bound is intended.
        prerequisite_count = RangeAttributeDefinition(
            name="num_prerequisites",
            levels=[2, 3, 4, 5],
            default_level=0,  # level 0 -> 2 prerequisites
            description="Number of prerequisites per course",
            attr_type=AttributeType.APPEND,
            min_value=0,
            lower_field_name="min_num_prerequisites",
            upper_field_name="max_num_prerequisites",
        )

        # Length of the cycle inserted into the prerequisites.
        cycle_size = RangeAttributeDefinition(
            name="cycle_length",
            levels=[3, 4, 5, 6],
            default_level=0,  # level 0 -> cycle length 3
            description="Length of a cycle in the prerequisites",
            attr_type=AttributeType.APPEND,
            min_value=3,  # same lower bound that CourseScheduleConfig.validate() puts on min_cycle_length
            lower_field_name="min_cycle_length",
            upper_field_name="max_cycle_length",
        )

        self._define_attributes(course_count, prerequisite_count, cycle_size)


# Register the dataset class together with its config class under the name "course_schedule".
register_dataset("course_schedule", CourseScheduleDataset, CourseScheduleConfig)
52 changes: 40 additions & 12 deletions tests/test_course_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@

import pytest

from reasoning_gym.graphs.course_schedule import CourseScheduleConfig, CourseScheduleDataset
from reasoning_gym.graphs.course_schedule import CourseScheduleConfig, CourseScheduleCurriculum, CourseScheduleDataset


def test_course_schedule_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = CourseScheduleConfig(num_courses=-1) # Negative not allowed
config = CourseScheduleConfig(min_num_courses=2) # must be >= 3
config.validate()

with pytest.raises(AssertionError):
config = CourseScheduleConfig(num_courses=0) # Zero not allowed
config = CourseScheduleConfig(min_num_courses=6, max_num_courses=5) # min > max
config.validate()

with pytest.raises(AssertionError):
config = CourseScheduleConfig(max_num_prerequisites=-1) # Negative not allowed
config = CourseScheduleConfig(min_num_prerequisites=-1) # neg not allowed
config.validate()

with pytest.raises(AssertionError):
config = CourseScheduleConfig(max_num_prerequisites=0) # Zero not allowed
config = CourseScheduleConfig(min_num_prerequisites=5, max_num_prerequisites=4) # min > max
config.validate()

with pytest.raises(AssertionError):
config = CourseScheduleConfig(num_courses=3, max_num_prerequisites=5) # max_num_prerequisites > num_courses
config = CourseScheduleConfig(max_num_prerequisites=0) # Zero not allowed
config.validate()

with pytest.raises(AssertionError):
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_course_schedule_dataset_deterministic():

def test_course_schedule_dataset_items():
"""Test basic properties of generated items"""
config = CourseScheduleConfig(num_courses=15, size=10, seed=42)
config = CourseScheduleConfig(max_num_courses=15, size=10, seed=42)
dataset = CourseScheduleDataset(config)

for i in range(len(dataset)):
Expand All @@ -83,13 +83,12 @@ def test_course_schedule_dataset_items():
solution = item["metadata"]["solution"] # Solution obtained from topological sort

# Verify metadata
assert len(courses) == config.num_courses
assert max(courses) == config.num_courses - 1
assert len(prerequisites) <= config.max_num_prerequisites * config.num_courses
assert len(courses) <= config.max_num_courses
assert len(prerequisites) <= config.max_num_prerequisites * len(courses)
assert all(len(prereq) == 2 for prereq in prerequisites)
for course, prereq in prerequisites:
assert course < config.num_courses
assert prereq < config.num_courses
assert course < len(courses)
assert prereq < len(courses)
assert course != prereq
assert solution == solvable

Expand Down Expand Up @@ -125,3 +124,32 @@ def test_course_schedule_answer():
# Indirect cycle of length 3
prerequisites = [[0, 1], [1, 2], [2, 0]]
assert dataset._can_finish(num_courses=3, prerequisites=prerequisites) == False


def test_course_schedule_curriculum():
    """Check the configs generated by CourseScheduleCurriculum as attribute levels go up and down."""
    curriculum = CourseScheduleCurriculum()
    base_value = {"size": 150, "seed": 1}

    def ranges(cfg):
        """Return the (min, max) pairs for courses, prerequisites and cycle length."""
        return (
            (cfg.min_num_courses, cfg.max_num_courses),
            (cfg.min_num_prerequisites, cfg.max_num_prerequisites),
            (cfg.min_cycle_length, cfg.max_cycle_length),
        )

    # Default levels: base values are passed through and every range is a single point.
    base_cfg: CourseScheduleConfig = curriculum.generate_configuration(base_value)
    assert (base_cfg.seed, base_cfg.size) == (1, 150)
    assert ranges(base_cfg) == ((10, 10), (2, 2), (3, 3))

    # Raise every attribute by one level: lower bounds stay, upper bounds move to the next level.
    for attr_name in ("num_courses", "num_prerequisites", "cycle_length"):
        curriculum.increment_attr_level(attr_name)
    increased_cfg = curriculum.generate_configuration(base_value)
    assert ranges(increased_cfg) == ((10, 50), (2, 3), (3, 4))

    # Lower only num_courses again: its range collapses, the other two keep their raised level.
    curriculum.decrement_attr_level("num_courses")
    partially_decreased_cfg = curriculum.generate_configuration(base_value)
    assert ranges(partially_decreased_cfg) == ((10, 10), (2, 3), (3, 4))

0 comments on commit 8ccc4d7

Please sign in to comment.