diff --git a/reasoning_gym/graphs/__init__.py b/reasoning_gym/graphs/__init__.py index fc4b9f19..adf57fd8 100644 --- a/reasoning_gym/graphs/__init__.py +++ b/reasoning_gym/graphs/__init__.py @@ -2,7 +2,7 @@ from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset from .largest_island import LargestIslandDataset from .quantum_lock import QuantumLockConfig, QuantumLockDataset -from .shortest_path import ShortestPathConfig, ShortestPathDataset +from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset __all__ = [ "FamilyRelationshipsConfig", @@ -14,4 +14,5 @@ "CourseScheduleConfig", "ShortestPathConfig", "ShortestPathDataset", + "ShortestPathCurriculum", ] diff --git a/reasoning_gym/graphs/shortest_path.py b/reasoning_gym/graphs/shortest_path.py index 43c45f54..f7924242 100644 --- a/reasoning_gym/graphs/shortest_path.py +++ b/reasoning_gym/graphs/shortest_path.py @@ -5,6 +5,7 @@ from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset QUESTION_TEMPLATE = """Your task is to find the shortest path from the start to the destination point in a grid. @@ -163,4 +164,33 @@ def __getitem__(self, idx: int) -> dict: } +class ShortestPathCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(ShortestPathCurriculum.__name__, ShortestPathConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="rows", + levels=[10, 25, 50, 100], + default_level=0, + description="Number of rows in the grid", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_rows", + upper_field_name="max_rows", + ), + RangeAttributeDefinition( + name="cols", + levels=[10, 25, 50, 100], + default_level=0, + description="Number of columns in the grid", + attr_type=AttributeType.APPEND, + min_value=2, + lower_field_name="min_cols", + upper_field_name="max_cols", + ), + ) + + register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig) diff --git a/tests/test_shortest_path.py b/tests/test_shortest_path.py index 57950a1e..e363f655 100644 --- a/tests/test_shortest_path.py +++ b/tests/test_shortest_path.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.graphs.shortest_path import ShortestPathConfig, ShortestPathDataset +from reasoning_gym.graphs.shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset def test_shortest_path_config_validation(): @@ -179,3 +179,28 @@ def test_shortest_path_answer(): }, } assert dataset.score_answer(None, entry) == 0.0 + + +def test_chain_sum_curriculum(): + curriculum = ShortestPathCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: ShortestPathConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_rows == 10 and base_cfg.max_rows == 10 + assert base_cfg.min_cols == 10 and base_cfg.max_cols == 10 + + # test incrementing attribute levels + curriculum.increment_attr_level("rows") + curriculum.increment_attr_level("cols") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_rows == 10 and increased_cfg.max_rows == 25 + assert increased_cfg.min_cols == 10 and increased_cfg.max_cols == 25 + + # test decrementing attribute level for rows again + curriculum.decrement_attr_level("rows") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_rows == 10 and partially_decreased_cfg.max_rows == 10 + assert partially_decreased_cfg.min_cols == 10 and partially_decreased_cfg.max_cols == 25