Skip to content

Commit

Permalink
shortest path curriculum
Browse files Browse the repository at this point in the history
  • Loading branch information
zafstojano committed Mar 5, 2025
1 parent 5d7fbac commit 5b73af7
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 2 deletions.
3 changes: 2 additions & 1 deletion reasoning_gym/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
from .largest_island import LargestIslandDataset
from .quantum_lock import QuantumLockConfig, QuantumLockDataset
from .shortest_path import ShortestPathConfig, ShortestPathDataset
from .shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset

__all__ = [
"FamilyRelationshipsConfig",
Expand All @@ -14,4 +14,5 @@
"CourseScheduleConfig",
"ShortestPathConfig",
"ShortestPathDataset",
"ShortestPathCurriculum",
]
30 changes: 30 additions & 0 deletions reasoning_gym/graphs/shortest_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from random import Random
from typing import Any, Optional

from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset

QUESTION_TEMPLATE = """Your task is to find the shortest path from the start to the destination point in a grid.
Expand Down Expand Up @@ -163,4 +164,33 @@ def __getitem__(self, idx: int) -> dict:
}


class ShortestPathCurriculum(BaseCurriculum):
def __init__(self):
super().__init__(ShortestPathCurriculum.__name__, ShortestPathConfig)

# Define attributes
self._define_attributes(
RangeAttributeDefinition(
name="rows",
levels=[10, 25, 50, 100],
default_level=0,
description="Number of rows in the grid",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_rows",
upper_field_name="max_rows",
),
RangeAttributeDefinition(
name="cols",
levels=[10, 25, 50, 100],
default_level=0,
description="Number of columns in the grid",
attr_type=AttributeType.APPEND,
min_value=2,
lower_field_name="min_cols",
upper_field_name="max_cols",
),
)


register_dataset("shortest_path", ShortestPathDataset, ShortestPathConfig)
27 changes: 26 additions & 1 deletion tests/test_shortest_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from reasoning_gym.graphs.shortest_path import ShortestPathConfig, ShortestPathDataset
from reasoning_gym.graphs.shortest_path import ShortestPathConfig, ShortestPathCurriculum, ShortestPathDataset


def test_shortest_path_config_validation():
Expand Down Expand Up @@ -179,3 +179,28 @@ def test_shortest_path_answer():
},
}
assert dataset.score_answer(None, entry) == 0.0


def test_chain_sum_curriculum():
curriculum = ShortestPathCurriculum()

base_value = {"size": 150, "seed": 1}

base_cfg: ShortestPathConfig = curriculum.generate_configuration(base_value)
assert base_cfg.seed == 1
assert base_cfg.size == 150
assert base_cfg.min_rows == 10 and base_cfg.max_rows == 10
assert base_cfg.min_cols == 10 and base_cfg.max_cols == 10

# test incrementing attribute levels
curriculum.increment_attr_level("rows")
curriculum.increment_attr_level("cols")
increased_cfg = curriculum.generate_configuration(base_value)
assert increased_cfg.min_rows == 10 and increased_cfg.max_rows == 25
assert increased_cfg.min_cols == 10 and increased_cfg.max_cols == 25

# test decrementing attribute level for rows again
curriculum.decrement_attr_level("rows")
partially_decreased_cfg = curriculum.generate_configuration(base_value)
assert partially_decreased_cfg.min_rows == 10 and partially_decreased_cfg.max_rows == 10
assert partially_decreased_cfg.min_cols == 10 and partially_decreased_cfg.max_cols == 25

0 comments on commit 5b73af7

Please sign in to comment.