From d3f96aa717a3c420630a0d2a991e42bbee19c62d Mon Sep 17 00:00:00 2001 From: Zafir Stojanovski Date: Wed, 5 Mar 2025 13:45:15 +0100 Subject: [PATCH] largest island curriculum --- reasoning_gym/graphs/__init__.py | 4 +- reasoning_gym/graphs/largest_island.py | 101 ++++++++++++++++++++----- tests/test_largest_island.py | 75 ++++++++++++++---- 3 files changed, 143 insertions(+), 37 deletions(-) diff --git a/reasoning_gym/graphs/__init__.py b/reasoning_gym/graphs/__init__.py index fc4b9f19..cd648071 100644 --- a/reasoning_gym/graphs/__init__.py +++ b/reasoning_gym/graphs/__init__.py @@ -1,6 +1,6 @@ from .course_schedule import CourseScheduleConfig, CourseScheduleDataset from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset -from .largest_island import LargestIslandDataset +from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset from .quantum_lock import QuantumLockConfig, QuantumLockDataset from .shortest_path import ShortestPathConfig, ShortestPathDataset @@ -10,6 +10,8 @@ "QuantumLockConfig", "QuantumLockDataset", "LargestIslandDataset", + "LargestIslandConfig", + "LargestIslandCurriculum", "CourseScheduleDataset", "CourseScheduleConfig", "ShortestPathConfig", diff --git a/reasoning_gym/graphs/largest_island.py b/reasoning_gym/graphs/largest_island.py index 528a7713..42b08095 100644 --- a/reasoning_gym/graphs/largest_island.py +++ b/reasoning_gym/graphs/largest_island.py @@ -9,10 +9,9 @@ from random import Random from typing import Optional +from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition from ..factory import ProceduralDataset, register_dataset -MIN_MAP_DIM = 1 - QUESTION_TEMPLATE = """You are given the following {rows} x {cols} binary matrix grid: {grid} @@ -29,11 +28,15 @@ class LargestIslandConfig: """Configuration for Largest Island dataset generation""" - rows: int = 10 # Number of rows in the grid - cols: int = 10 # Number of columns in the grid + min_rows: int = 5 # Minimum number of rows in the grid + max_rows: int = 10 # Maximum number of rows in the grid + min_cols: int = 5 # Minimum number of columns in the grid + max_cols: int = 10 # Maximum number of columns in the grid + min_num_islands: int = 0 max_num_islands: int = ( 5 # Maximum number of islands (actual max might be smaller due to merging of islands during random walk) ) + min_island_size: int = 0 max_island_size: int = ( 10 # Maximum size of an island (actual max might be larger due to merging of islands during random walk) ) @@ -43,10 +46,10 @@ class LargestIslandConfig: def validate(self): """Validate configuration parameters""" - assert MIN_MAP_DIM <= self.rows, f"rows must be between larger than {MIN_MAP_DIM}" - assert MIN_MAP_DIM <= self.cols, f"cols must be between larger than {MIN_MAP_DIM}" - assert 0 <= self.max_num_islands, "max_num_islands must be non-negative" - assert 0 <= self.max_island_size, "max_island_size must be non-negative" + assert 1 <= self.min_rows <= self.max_rows, "Invalid rows range" + assert 1 <= self.min_cols <= self.max_cols, "Invalid cols range" + assert 0 <= self.min_num_islands <= self.max_num_islands, "Invalid num_islands range" + assert 0 <= self.min_island_size <= self.max_island_size, "Invalid island_size range" class LargestIslandDataset(ProceduralDataset): @@ -55,27 +58,27 @@ class LargestIslandDataset(ProceduralDataset): def __init__(self, config: LargestIslandConfig): super().__init__(config=config, seed=config.seed, size=config.size) - def _is_valid_cell(self, r: int, c: int) -> bool: - return 0 <= r < self.config.rows and 0 <= c < self.config.cols + def _is_valid_cell(self, r: int, c: int, rows: int, cols: int) -> bool: + return 0 <= r < rows and 0 <= c < cols - def _create_grid(self, rng: Random) -> list[list[int]]: + def _create_grid(self, rng: Random, rows: int, cols: int) -> list[list[int]]: """Create a random grid of islands using a random walk algorithm""" - grid = [[0] * self.config.cols for _ in range(self.config.rows)] + grid = [[0] * cols for _ in range(rows)] directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right def create_island(): - r, c = rng.randint(0, self.config.rows - 1), rng.randint(0, self.config.cols - 1) - capped_size = min(rng.randint(0, self.config.max_island_size), self.config.rows * self.config.cols) + r, c = rng.randint(0, rows - 1), rng.randint(0, cols - 1) + capped_size = min(rng.randint(self.config.min_island_size, self.config.max_island_size), rows * cols) for _ in range(capped_size): grid[r][c] = 1 rng.shuffle(directions) for dr, dc in directions: new_r, new_c = r + dr, c + dc - if self._is_valid_cell(new_r, new_c) and grid[new_r][new_c] == 0: + if self._is_valid_cell(new_r, new_c, rows, cols) and grid[new_r][new_c] == 0: r, c = new_r, new_c break - num_islands = rng.randint(0, self.config.max_num_islands) + num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands) for _ in range(num_islands): create_island() @@ -83,6 +86,7 @@ def create_island(): def _get_largest_island(self, grid: list[list[int]]) -> int: """Find the largest island in the grid""" + rows, cols = len(grid), len(grid[0]) directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right visited = set() @@ -94,15 +98,19 @@ def bfs(r, c): r, c = queue.popleft() for dr, dc in directions: new_r, new_c = r + dr, c + dc - if self._is_valid_cell(new_r, new_c) and (new_r, new_c) not in visited and grid[new_r][new_c] == 1: + if ( + self._is_valid_cell(new_r, new_c, rows, cols) + and (new_r, new_c) not in visited + and grid[new_r][new_c] == 1 + ): area += 1 visited.add((new_r, new_c)) queue.append((new_r, new_c)) return area max_area = 0 - for r in range(self.config.rows): - for c in range(self.config.cols): + for r in range(rows): + for c in range(cols): if grid[r][c] == 1 and (r, c) not in visited: max_area = max(max_area, bfs(r, c)) @@ -120,16 +128,67 @@ def __getitem__(self, idx: int) -> dict: """Generate a single Largest Island question""" rng = Random(self.seed + idx) - grid = self._create_grid(rng) + rows = rng.randint(self.config.min_rows, self.config.max_rows) + cols = rng.randint(self.config.min_cols, self.config.max_cols) + grid = self._create_grid(rng, rows, cols) grid_str = self._grid_to_string(grid) answer = self._get_largest_island(grid) return { - "question": QUESTION_TEMPLATE.format(rows=self.config.rows, cols=self.config.cols, grid=grid_str), + "question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str), "answer": str(answer), "metadata": {"grid": grid, "solution": answer}, } +class LargestIslandCurriculum(BaseCurriculum): + def __init__(self): + super().__init__(LargestIslandCurriculum.__name__, LargestIslandConfig) + + # Define attributes + self._define_attributes( + RangeAttributeDefinition( + name="rows", + levels=[5, 10, 50, 100], + default_level=0, + description="Number of rows in the grid", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_rows", + upper_field_name="max_rows", + ), + RangeAttributeDefinition( + name="cols", + levels=[5, 10, 50, 100], + default_level=0, + description="Number of columns in the grid", + attr_type=AttributeType.APPEND, + min_value=1, + lower_field_name="min_cols", + upper_field_name="max_cols", + ), + RangeAttributeDefinition( + name="num_islands", + levels=[2, 5, 10, 20], + default_level=0, + description="Number of islands in the grid", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_num_islands", + upper_field_name="max_num_islands", + ), + RangeAttributeDefinition( + name="island_size", + levels=[5, 10, 20, 30], + default_level=0, + description="Size of the islands in the grid", + attr_type=AttributeType.APPEND, + min_value=0, + lower_field_name="min_island_size", + upper_field_name="max_island_size", + ), + ) + + register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig) diff --git a/tests/test_largest_island.py b/tests/test_largest_island.py index 70475dd6..6495eb9f 100644 --- a/tests/test_largest_island.py +++ b/tests/test_largest_island.py @@ -2,33 +2,41 @@ import pytest -from reasoning_gym.graphs.largest_island import LargestIslandConfig, LargestIslandDataset +from reasoning_gym.graphs.largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset def test_largest_island_config_validation(): """Test that invalid configs raise appropriate errors""" with pytest.raises(AssertionError): - config = LargestIslandConfig(rows=-1) # Negative not allowed + config = LargestIslandConfig(min_rows=0) # 0 not allowed config.validate() with pytest.raises(AssertionError): - config = LargestIslandConfig(rows=0) # Zero not allowed + config = LargestIslandConfig(min_cols=0) # 0 not allowed config.validate() with pytest.raises(AssertionError): - config = LargestIslandConfig(cols=-1) # Negative not allowed + config = LargestIslandConfig(min_rows=10, max_rows=5) # min > max config.validate() with pytest.raises(AssertionError): - config = LargestIslandConfig(cols=0) # Zero not allowed + config = LargestIslandConfig(min_cols=10, max_cols=5) # min > max config.validate() with pytest.raises(AssertionError): - config = LargestIslandConfig(max_num_islands=-1) # Negative not allowed + config = LargestIslandConfig(min_num_islands=-1) # neg not allowed config.validate() with pytest.raises(AssertionError): - config = LargestIslandConfig(max_island_size=-1) # Negative not allowed + config = LargestIslandConfig(min_island_size=-1) # neg not allowed + config.validate() + + with pytest.raises(AssertionError): + config = LargestIslandConfig(min_num_islands=5, max_num_islands=3) # min > max + config.validate() + + with pytest.raises(AssertionError): + config = LargestIslandConfig(min_island_size=5, max_island_size=3) # min > max config.validate() @@ -44,7 +52,14 @@ def test_largest_island_dataset_deterministic(): def test_largest_island_dataset_items(): """Test basic properties of generated items""" - config = LargestIslandConfig(rows=8, cols=8, max_island_size=5, size=10, seed=42) + config = LargestIslandConfig( + min_rows=5, + max_rows=10, + min_cols=5, + max_cols=10, + size=10, + seed=42, + ) dataset = LargestIslandDataset(config) for i in range(len(dataset)): @@ -60,12 +75,10 @@ def test_largest_island_dataset_items(): assert "solution" in item["metadata"] grid = item["metadata"]["grid"] - solution = item["metadata"]["solution"] # Verify grid dimensions - assert len(grid) == 8 - assert all(len(row) == 8 for row in grid) - assert 0 <= solution <= 5 + assert 5 <= len(grid) <= 10 + assert all(0 <= len(row) <= 10 for row in grid) def test_largest_island_dataset_iteration(): @@ -82,19 +95,18 @@ def test_largest_island_dataset_iteration(): def test_largest_island_grid_generation(): """Test that generated grids are valid""" - config = LargestIslandConfig(rows=10, cols=10, max_island_size=3, size=5, seed=42) + config = LargestIslandConfig(size=5, seed=42) dataset = LargestIslandDataset(config) for i in range(len(dataset)): item = dataset[i] - assert item["metadata"]["solution"] <= 3 for row in item["metadata"]["grid"]: assert all(cell in {0, 1} for cell in row) def test_largest_island_answer(): """Test the _get_largest_island method""" - config = LargestIslandConfig(rows=5, cols=5, seed=42) + config = LargestIslandConfig(seed=42) dataset = LargestIslandDataset(config) grid = [ @@ -125,3 +137,36 @@ def test_largest_island_answer(): [0, 0, 0, 1, 1], ] assert dataset._get_largest_island(grid) == 9 + + +def test_largest_island_curriculum(): + curriculum = LargestIslandCurriculum() + + base_value = {"size": 150, "seed": 1} + + base_cfg: LargestIslandConfig = curriculum.generate_configuration(base_value) + assert base_cfg.seed == 1 + assert base_cfg.size == 150 + assert base_cfg.min_rows == 5 and base_cfg.max_rows == 5 + assert base_cfg.min_cols == 5 and base_cfg.max_cols == 5 + assert base_cfg.min_num_islands == 2 and base_cfg.max_num_islands == 2 + assert base_cfg.min_island_size == 5 and base_cfg.max_island_size == 5 + + # test incrementing attribute levels + curriculum.increment_attr_level("rows") + curriculum.increment_attr_level("cols") + curriculum.increment_attr_level("num_islands") + curriculum.increment_attr_level("island_size") + increased_cfg = curriculum.generate_configuration(base_value) + assert increased_cfg.min_rows == 5 and increased_cfg.max_rows == 10 + assert increased_cfg.min_cols == 5 and increased_cfg.max_cols == 10 + assert increased_cfg.min_num_islands == 2 and increased_cfg.max_num_islands == 5 + assert increased_cfg.min_island_size == 5 and increased_cfg.max_island_size == 10 + + # test decrementing attribute level for num_islands again + curriculum.decrement_attr_level("num_islands") + partially_decreased_cfg = curriculum.generate_configuration(base_value) + assert partially_decreased_cfg.min_rows == 5 and partially_decreased_cfg.max_rows == 10 + assert partially_decreased_cfg.min_cols == 5 and partially_decreased_cfg.max_cols == 10 + assert partially_decreased_cfg.min_num_islands == 2 and partially_decreased_cfg.max_num_islands == 2 + assert partially_decreased_cfg.min_island_size == 5 and partially_decreased_cfg.max_island_size == 10