From 8ecc72360709dae8aec41eccc86c8f6369dce219 Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 5 Mar 2025 15:05:17 +0100 Subject: [PATCH] feat(env): NQueens Curriculum (#262) * curriculum & tests --- reasoning_gym/coaching/__init__.py | 3 ++- reasoning_gym/games/__init__.py | 4 +++- reasoning_gym/games/n_queens.py | 28 ++++++++++++++++++++++++++++ tests/test_n_queens.py | 27 ++++++++++++++++++++++++++- 4 files changed, 59 insertions(+), 3 deletions(-) diff --git a/reasoning_gym/coaching/__init__.py b/reasoning_gym/coaching/__init__.py index 50683d66..4d96bca4 100644 --- a/reasoning_gym/coaching/__init__.py +++ b/reasoning_gym/coaching/__init__.py @@ -1,10 +1,11 @@ -from .attributes import AttributeDefinition, AttributeType, RangeAttributeDefinition +from .attributes import AttributeDefinition, AttributeType, RangeAttributeDefinition, ScalarAttributeDefinition from .base_curriculum import BaseCurriculum from .coach import Coach, GroupedScores, ScoreBoard, ScoreStats __all__ = [ "AttributeType", "AttributeDefinition", + "ScalarAttributeDefinition", "RangeAttributeDefinition", "BaseCurriculum", "Coach", diff --git a/reasoning_gym/games/__init__.py b/reasoning_gym/games/__init__.py index 4b450134..f4d718b7 100644 --- a/reasoning_gym/games/__init__.py +++ b/reasoning_gym/games/__init__.py @@ -13,7 +13,7 @@ from .mahjong import MahjongPuzzleConfig, MahjongPuzzleDataset from .maze import MazeConfig, MazeDataset from .mini_sudoku import MiniSudokuConfig, MiniSudokuDataset -from .n_queens import NQueensDataset +from .n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset from .rush_hour import RushHourConfig, RushHourDataset from .sokoban import SokobanConfig, SokobanDataset from .sudoku import SudokuConfig, SudokuDataset @@ -40,6 +40,8 @@ "HanoiConfig", "HanoiDataset", "NQueensDataset", + "NQueensConfig", + "NQueensCurriculum", "TsumegoConfig", "TsumegoDataset", "KnightSwapConfig", diff --git a/reasoning_gym/games/n_queens.py b/reasoning_gym/games/n_queens.py index 3a07bb10..723a4ef8 100644 --- a/reasoning_gym/games/n_queens.py +++ b/reasoning_gym/games/n_queens.py @@ -9,6 +9,7 @@ from random import Random from typing import Any, Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition, ScalarAttributeDefinition from ..factory import ProceduralDataset, register_dataset MIN_BOARD_SIZE = 4 @@ -151,4 +152,31 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float: return 0.0 +class NQueensCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(NQueensCurriculum.__name__, NQueensConfig) + + self._define_attributes( + ScalarAttributeDefinition( + name="n", + field_name="n", + levels=[4, 6, 8, 12], + default_level=0, + description="Board size", + attr_type=AttributeType.STATIC, + min_value=4, + ), + RangeAttributeDefinition( + name="num_removed", + levels=[2, 4, 6, 10], + default_level=0, + description="Number of queens to remove", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_remove", + upper_field_name="max_remove", + ), + ) + + register_dataset("n_queens", NQueensDataset, NQueensConfig) diff --git a/tests/test_n_queens.py b/tests/test_n_queens.py index 6182540b..0a9cc37d 100644 --- a/tests/test_n_queens.py +++ b/tests/test_n_queens.py @@ -2,7 +2,7 @@ import pytest -from reasoning_gym.games.n_queens import NQueensConfig, NQueensDataset +from reasoning_gym.games.n_queens import NQueensConfig, NQueensCurriculum, NQueensDataset def test_nqueens_config_validation(): @@ -146,3 +146,28 @@ def is_valid_solution(board: list[list[str]]) -> bool: off_diags.add(r - c) return num_queens == n + + +def test_n_queens_curriculum(): + curriculum = NQueensCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: NQueensConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.n == 4 + assert base_cfg.min_remove == 2 and base_cfg.max_remove == 2 + + # test incrementing attribute levels for n & num_removed attributes + curriculum.increment_attr_level("n") + curriculum.increment_attr_level("num_removed") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.n == 6 + assert increased_cfg.min_remove == 2 and increased_cfg.max_remove == 4 + + # test decrementing attribute level for n again + curriculum.decrement_attr_level("n") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.n == 4 + assert partially_decreased_cfg.min_remove == 2 and partially_decreased_cfg.max_remove == 4