Skip to content

Commit

Permalink
largest island curriculum
Browse files Browse the repository at this point in the history
  • Loading branch information
zafstojano committed Mar 5, 2025
1 parent 5d7fbac commit d3f96aa
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 37 deletions.
4 changes: 3 additions & 1 deletion reasoning_gym/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .course_schedule import CourseScheduleConfig, CourseScheduleDataset
from .family_relationships import FamilyRelationshipsConfig, FamilyRelationshipsDataset
from .largest_island import LargestIslandDataset
from .largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset
from .quantum_lock import QuantumLockConfig, QuantumLockDataset
from .shortest_path import ShortestPathConfig, ShortestPathDataset

Expand All @@ -10,6 +10,8 @@
"QuantumLockConfig",
"QuantumLockDataset",
"LargestIslandDataset",
"LargestIslandConfig",
"LargestIslandCurriculum",
"CourseScheduleDataset",
"CourseScheduleConfig",
"ShortestPathConfig",
Expand Down
101 changes: 80 additions & 21 deletions reasoning_gym/graphs/largest_island.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from random import Random
from typing import Optional

from ..coaching import AttributeType, BaseCurriculum, RangeAttributeDefinition
from ..factory import ProceduralDataset, register_dataset

MIN_MAP_DIM = 1

QUESTION_TEMPLATE = """You are given the following {rows} x {cols} binary matrix grid:
{grid}
Expand All @@ -29,11 +28,15 @@
class LargestIslandConfig:
"""Configuration for Largest Island dataset generation"""

rows: int = 10 # Number of rows in the grid
cols: int = 10 # Number of columns in the grid
min_rows: int = 5 # Minimum number of rows in the grid
max_rows: int = 10 # Maximum number of rows in the grid
min_cols: int = 5 # Minimum number of columns in the grid
max_cols: int = 10 # Maximum number of columns in the grid
min_num_islands: int = 0
max_num_islands: int = (
5 # Maximum number of islands (actual max might be smaller due to merging of islands during random walk)
)
min_island_size: int = 0
max_island_size: int = (
10 # Maximum size of an island (actual max might be larger due to merging of islands during random walk)
)
Expand All @@ -43,10 +46,10 @@ class LargestIslandConfig:

def validate(self):
"""Validate configuration parameters"""
assert MIN_MAP_DIM <= self.rows, f"rows must be between larger than {MIN_MAP_DIM}"
assert MIN_MAP_DIM <= self.cols, f"cols must be between larger than {MIN_MAP_DIM}"
assert 0 <= self.max_num_islands, "max_num_islands must be non-negative"
assert 0 <= self.max_island_size, "max_island_size must be non-negative"
assert 1 <= self.min_rows <= self.max_rows, "Invalid rows range"
assert 1 <= self.min_cols <= self.max_cols, "Invalid cols range"
assert 0 <= self.min_num_islands <= self.max_num_islands, "Invalid num_islands range"
assert 0 <= self.min_island_size <= self.max_island_size, "Invalid island_size range"


class LargestIslandDataset(ProceduralDataset):
Expand All @@ -55,34 +58,35 @@ class LargestIslandDataset(ProceduralDataset):
def __init__(self, config: LargestIslandConfig):
super().__init__(config=config, seed=config.seed, size=config.size)

def _is_valid_cell(self, r: int, c: int) -> bool:
return 0 <= r < self.config.rows and 0 <= c < self.config.cols
def _is_valid_cell(self, r: int, c: int, rows: int, cols: int) -> bool:
return 0 <= r < rows and 0 <= c < cols

def _create_grid(self, rng: Random) -> list[list[int]]:
def _create_grid(self, rng: Random, rows: int, cols: int) -> list[list[int]]:
"""Create a random grid of islands using a random walk algorithm"""
grid = [[0] * self.config.cols for _ in range(self.config.rows)]
grid = [[0] * cols for _ in range(rows)]
directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right

def create_island():
r, c = rng.randint(0, self.config.rows - 1), rng.randint(0, self.config.cols - 1)
capped_size = min(rng.randint(0, self.config.max_island_size), self.config.rows * self.config.cols)
r, c = rng.randint(0, rows - 1), rng.randint(0, cols - 1)
capped_size = min(rng.randint(self.config.min_island_size, self.config.max_island_size), rows * cols)
for _ in range(capped_size):
grid[r][c] = 1
rng.shuffle(directions)
for dr, dc in directions:
new_r, new_c = r + dr, c + dc
if self._is_valid_cell(new_r, new_c) and grid[new_r][new_c] == 0:
if self._is_valid_cell(new_r, new_c, rows, cols) and grid[new_r][new_c] == 0:
r, c = new_r, new_c
break

num_islands = rng.randint(0, self.config.max_num_islands)
num_islands = rng.randint(self.config.min_num_islands, self.config.max_num_islands)
for _ in range(num_islands):
create_island()

return grid

def _get_largest_island(self, grid: list[list[int]]) -> int:
"""Find the largest island in the grid"""
rows, cols = len(grid), len(grid[0])
directions = [(-1, 0), (1, 0), (0, -1), (0, 1)] # Up, Down, Left, Right
visited = set()

Expand All @@ -94,15 +98,19 @@ def bfs(r, c):
r, c = queue.popleft()
for dr, dc in directions:
new_r, new_c = r + dr, c + dc
if self._is_valid_cell(new_r, new_c) and (new_r, new_c) not in visited and grid[new_r][new_c] == 1:
if (
self._is_valid_cell(new_r, new_c, rows, cols)
and (new_r, new_c) not in visited
and grid[new_r][new_c] == 1
):
area += 1
visited.add((new_r, new_c))
queue.append((new_r, new_c))
return area

max_area = 0
for r in range(self.config.rows):
for c in range(self.config.cols):
for r in range(rows):
for c in range(cols):
if grid[r][c] == 1 and (r, c) not in visited:
max_area = max(max_area, bfs(r, c))

Expand All @@ -120,16 +128,67 @@ def __getitem__(self, idx: int) -> dict:
"""Generate a single Largest Island question"""
rng = Random(self.seed + idx)

grid = self._create_grid(rng)
rows = rng.randint(self.config.min_rows, self.config.max_rows)
cols = rng.randint(self.config.min_cols, self.config.max_cols)
grid = self._create_grid(rng, rows, cols)
grid_str = self._grid_to_string(grid)

answer = self._get_largest_island(grid)

return {
"question": QUESTION_TEMPLATE.format(rows=self.config.rows, cols=self.config.cols, grid=grid_str),
"question": QUESTION_TEMPLATE.format(rows=rows, cols=cols, grid=grid_str),
"answer": str(answer),
"metadata": {"grid": grid, "solution": answer},
}


class LargestIslandCurriculum(BaseCurriculum):
    """Curriculum describing the tunable difficulty ranges of the Largest Island task.

    Four range attributes are registered: ``rows``, ``cols``, ``num_islands``
    and ``island_size``. Each one is bound to a ``min_<name>`` / ``max_<name>``
    pair of fields, matching the field names used by ``LargestIslandConfig``.
    """

    def __init__(self):
        super().__init__(LargestIslandCurriculum.__name__, LargestIslandConfig)

        # One row per attribute: (name, levels, description, min_value).
        # Grid dimensions must be at least 1, while island count and island
        # size are allowed to go down to 0.
        specs = (
            ("rows", [5, 10, 50, 100], "Number of rows in the grid", 1),
            ("cols", [5, 10, 50, 100], "Number of columns in the grid", 1),
            ("num_islands", [2, 5, 10, 20], "Number of islands in the grid", 0),
            ("island_size", [5, 10, 20, 30], "Size of the islands in the grid", 0),
        )

        # Every attribute starts at level 0 and uses the APPEND attribute type;
        # the lower/upper field names are derived from the attribute name.
        definitions = [
            RangeAttributeDefinition(
                name=attr,
                levels=levels,
                default_level=0,
                description=description,
                attr_type=AttributeType.APPEND,
                min_value=floor,
                lower_field_name=f"min_{attr}",
                upper_field_name=f"max_{attr}",
            )
            for attr, levels, description, floor in specs
        ]

        # Registration order is kept: rows, cols, num_islands, island_size.
        self._define_attributes(*definitions)


register_dataset("largest_island", LargestIslandDataset, LargestIslandConfig)
75 changes: 60 additions & 15 deletions tests/test_largest_island.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,41 @@

import pytest

from reasoning_gym.graphs.largest_island import LargestIslandConfig, LargestIslandDataset
from reasoning_gym.graphs.largest_island import LargestIslandConfig, LargestIslandCurriculum, LargestIslandDataset


def test_largest_island_config_validation():
"""Test that invalid configs raise appropriate errors"""
with pytest.raises(AssertionError):
config = LargestIslandConfig(rows=-1) # Negative not allowed
config = LargestIslandConfig(min_rows=0) # 0 not allowed
config.validate()

with pytest.raises(AssertionError):
config = LargestIslandConfig(rows=0) # Zero not allowed
config = LargestIslandConfig(min_cols=0) # 0 not allowed
config.validate()

with pytest.raises(AssertionError):
config = LargestIslandConfig(cols=-1) # Negative not allowed
config = LargestIslandConfig(min_rows=10, max_rows=5) # min > max
config.validate()

with pytest.raises(AssertionError):
config = LargestIslandConfig(cols=0) # Zero not allowed
config = LargestIslandConfig(min_cols=10, max_cols=5) # min > max
config.validate()

with pytest.raises(AssertionError):
config = LargestIslandConfig(max_num_islands=-1) # Negative not allowed
config = LargestIslandConfig(min_num_islands=-1) # neg not allowed
config.validate()

with pytest.raises(AssertionError):
config = LargestIslandConfig(max_island_size=-1) # Negative not allowed
config = LargestIslandConfig(min_island_size=-1) # neg not allowed
config.validate()

with pytest.raises(AssertionError):
config = LargestIslandConfig(min_num_islands=5, max_num_islands=3) # min > max
config.validate()

with pytest.raises(AssertionError):
config = LargestIslandConfig(min_island_size=5, max_island_size=3) # min > max
config.validate()


Expand All @@ -44,7 +52,14 @@ def test_largest_island_dataset_deterministic():

def test_largest_island_dataset_items():
"""Test basic properties of generated items"""
config = LargestIslandConfig(rows=8, cols=8, max_island_size=5, size=10, seed=42)
config = LargestIslandConfig(
min_rows=5,
max_rows=10,
min_cols=5,
max_cols=10,
size=10,
seed=42,
)
dataset = LargestIslandDataset(config)

for i in range(len(dataset)):
Expand All @@ -60,12 +75,10 @@ def test_largest_island_dataset_items():
assert "solution" in item["metadata"]

grid = item["metadata"]["grid"]
solution = item["metadata"]["solution"]

# Verify grid dimensions
assert len(grid) == 8
assert all(len(row) == 8 for row in grid)
assert 0 <= solution <= 5
assert 5 <= len(grid) <= 10
assert all(0 <= len(row) <= 10 for row in grid)


def test_largest_island_dataset_iteration():
Expand All @@ -82,19 +95,18 @@ def test_largest_island_dataset_iteration():

def test_largest_island_grid_generation():
"""Test that generated grids are valid"""
config = LargestIslandConfig(rows=10, cols=10, max_island_size=3, size=5, seed=42)
config = LargestIslandConfig(size=5, seed=42)
dataset = LargestIslandDataset(config)

for i in range(len(dataset)):
item = dataset[i]
assert item["metadata"]["solution"] <= 3
for row in item["metadata"]["grid"]:
assert all(cell in {0, 1} for cell in row)


def test_largest_island_answer():
"""Test the _get_largest_island method"""
config = LargestIslandConfig(rows=5, cols=5, seed=42)
config = LargestIslandConfig(seed=42)
dataset = LargestIslandDataset(config)

grid = [
Expand Down Expand Up @@ -125,3 +137,36 @@ def test_largest_island_answer():
[0, 0, 0, 1, 1],
]
assert dataset._get_largest_island(grid) == 9


def test_largest_island_curriculum():
    """Check the configs produced by LargestIslandCurriculum at several attribute levels.

    Covers the default configuration, the configuration after every attribute
    has been incremented once, and the configuration after ``num_islands`` has
    been decremented back again.
    """
    curriculum = LargestIslandCurriculum()

    base_value = {"size": 150, "seed": 1}

    def check_ranges(cfg, rows, cols, num_islands, island_size):
        # Each expected value is a (min, max) pair for the matching config fields.
        assert (cfg.min_rows, cfg.max_rows) == rows
        assert (cfg.min_cols, cfg.max_cols) == cols
        assert (cfg.min_num_islands, cfg.max_num_islands) == num_islands
        assert (cfg.min_island_size, cfg.max_island_size) == island_size

    # Default levels: the base values are passed through and every range
    # collapses to a single value.
    base_cfg: LargestIslandConfig = curriculum.generate_configuration(base_value)
    assert base_cfg.seed == 1
    assert base_cfg.size == 150
    check_ranges(base_cfg, rows=(5, 5), cols=(5, 5), num_islands=(2, 2), island_size=(5, 5))

    # test incrementing attribute levels
    for attr_name in ("rows", "cols", "num_islands", "island_size"):
        curriculum.increment_attr_level(attr_name)
    increased_cfg = curriculum.generate_configuration(base_value)
    check_ranges(increased_cfg, rows=(5, 10), cols=(5, 10), num_islands=(2, 5), island_size=(5, 10))

    # test decrementing attribute level for num_islands again
    curriculum.decrement_attr_level("num_islands")
    partially_decreased_cfg = curriculum.generate_configuration(base_value)
    check_ranges(
        partially_decreased_cfg,
        rows=(5, 10),
        cols=(5, 10),
        num_islands=(2, 2),
        island_size=(5, 10),
    )

0 comments on commit d3f96aa

Please sign in to comment.