Skip to content

Commit

Permalink
fix sokoban dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Mar 4, 2025
1 parent 9d0343e commit 923acff
Show file tree
Hide file tree
Showing 15 changed files with 85 additions and 111 deletions.
10 changes: 0 additions & 10 deletions reasoning_gym/games/contrib/sokoban/levels/lvl0.dat

This file was deleted.

5 changes: 0 additions & 5 deletions reasoning_gym/games/contrib/sokoban/levels/lvl1.dat

This file was deleted.

6 changes: 0 additions & 6 deletions reasoning_gym/games/contrib/sokoban/levels/lvl2.dat

This file was deleted.

7 changes: 0 additions & 7 deletions reasoning_gym/games/contrib/sokoban/levels/lvl3.dat

This file was deleted.

7 changes: 0 additions & 7 deletions reasoning_gym/games/contrib/sokoban/levels/lvl4.dat

This file was deleted.

7 changes: 0 additions & 7 deletions reasoning_gym/games/contrib/sokoban/levels/lvl5.dat

This file was deleted.

9 changes: 0 additions & 9 deletions reasoning_gym/games/contrib/sokoban/levels/lvl6.dat

This file was deleted.

6 changes: 0 additions & 6 deletions reasoning_gym/games/contrib/sokoban/levels/lvl7.dat

This file was deleted.

11 changes: 7 additions & 4 deletions reasoning_gym/games/contrib/sokoban/src/astar.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)


def astar(matrix, player_pos, debug=False, heuristic="manhattan"):
def astar(matrix, player_pos, debug: bool = False, heuristic: str = "manhattan", max_depth: int = 100):
# print(f'A* - {heuristic.title()} Heuristic')
heur = "[A*]" if heuristic == "manhattan" else "[Dijkstra]"
shape = matrix.shape
Expand Down Expand Up @@ -67,15 +67,18 @@ def astar(matrix, player_pos, debug=False, heuristic="manhattan"):
return (path + direction[move], depth + 1)
if debug:
print(f"{heur} Solution Depth: {depth + 1}\n{path + direction[move]}", 20)
print(f"{heur} Solution not found!\n")

if depth > max_depth:
break

if debug:
print(f"{heur} Solution Not Found!\nDepth {depth + 1}", 20)

return (None, -1 if not heap else depth + 1)


def solve_astar(puzzle, visualizer=False, heuristic="manhattan"):
def solve_astar(puzzle, visualizer: bool = False, heuristic: str = "manhattan", max_depth: int = 100):
matrix = puzzle
where = np.where((matrix == "*") | (matrix == "%"))
player_pos = where[0][0], where[1][0]
return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic)
return astar(matrix, player_pos, debug=visualizer, heuristic=heuristic, max_depth=max_depth)
13 changes: 6 additions & 7 deletions reasoning_gym/games/contrib/sokoban/src/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def __str__(self) -> str:


class Game:
def __init__(self, width=19, height=10, level=None, path=None):
self.level = level
def __init__(self, width=19, height=10, path=None):
self.width = width
self.height = height
self.puzzle = np.empty((height, width), dtype=PuzzleElement)
Expand All @@ -39,7 +38,7 @@ def __init__(self, width=19, height=10, level=None, path=None):
self.puzzle_size = None
self.pad_x = 0
self.pad_y = 0
self.path = path or f"levels/lvl{level}.dat"
self.path = path

if path:
if type(self) == Game:
Expand Down Expand Up @@ -108,7 +107,7 @@ def _process_puzzle_data(self, data):

# Calculate puzzle size and padding
self.puzzle_size = (len(data), len(data[0]) if len(data) > 0 else 0)
pad_x = (self.width - self.puzzle_size[1] - 2) // 2 # -2 matches original file-based logic
pad_x = (self.width - self.puzzle_size[1]) // 2
pad_y = (self.height - self.puzzle_size[0]) // 2
self.pad_x, self.pad_y = pad_x, pad_y

Expand Down Expand Up @@ -140,15 +139,15 @@ def _process_puzzle_data(self, data):


class ReverseGame(Game):
def __init__(self, rng: Random, width=19, height=10, level=None):
super().__init__(width, height, level)
def __init__(self, rng: Random, width: int = 19, height: int = 10):
super().__init__(width, height)
self.rng = rng
self.pad_x = 0
self.pad_y = 0

def load_puzzle(self, puzzle):
self.puzzle_size = (len(puzzle), len(puzzle[0]) if len(puzzle) > 0 else 0)
pad_x = (self.width - len(puzzle[0]) - 2) // 2
pad_x = (self.width - len(puzzle[0])) // 2
pad_y = (self.height - len(puzzle)) // 2
self.pad_x, self.pad_y = pad_x, pad_y
for i, row in enumerate(puzzle):
Expand Down
38 changes: 23 additions & 15 deletions reasoning_gym/games/contrib/sokoban/src/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@


def num_boxes(puzzle_area, min_boxes, max_boxes, min_w, min_h, max_w, max_h):
    """Choose a box count for a puzzle by linear interpolation over its area.

    The count grows linearly from ``min_boxes`` at the smallest allowed area
    (``min_w * min_h``) to ``max_boxes`` at the largest allowed area
    (``max_w * max_h``) and is truncated to an int.

    Parameters:
        puzzle_area: Area of the puzzle, in the same units as ``min_w * min_h``
        min_boxes: Box count at the smallest area
        max_boxes: Box count at the largest area
        min_w: Minimum width of the puzzle
        min_h: Minimum height of the puzzle
        max_w: Maximum width of the puzzle
        max_h: Maximum height of the puzzle
    Returns:
        The interpolated number of boxes, or ``max_boxes`` when the range is
        degenerate (no area span, or ``min_boxes == max_boxes``).
    """
    # Only a zero area span makes the slope undefined. A single fixed
    # dimension (e.g. min_w == max_w) still leaves a valid, non-zero span,
    # so the area must keep influencing the result in that case.
    area_span = max_w * max_h - min_w * min_h
    if area_span == 0 or min_boxes == max_boxes:
        return max_boxes

    slope = (max_boxes - min_boxes) / area_span
    intercept = min_boxes - slope * min_w * min_h
    return int(slope * puzzle_area + intercept)
Expand All @@ -19,31 +22,33 @@ def random_valid(rng: Random, width: int = 10, height: int = 10):
def generate(
rng: Random,
debug: bool = False,
path: str = None,
min_w: int = 6,
min_h: int = 6,
max_w: int = 15,
max_h: int = 10,
min_boxes: int = 4,
max_boxes: int = 10,
max_depth: int = 100,
path: str = None,
) -> tuple[str, str, dict]:
"""
Generates a level with the given configuration parameters.
Parameters:
rng: Random number generator for reproducibility.
visualizer: Whether to visualize the generation process.
path: Path to save the level file (default 'levels/lvl0.dat').
min_w: Minimum width of the puzzle.
min_h: Minimum height of the puzzle.
max_w: Maximum width of the puzzle.
max_h: Maximum height of the puzzle.
min_boxes: Minimum number of boxes.
max_boxes: Maximum number of boxes.
rng: Random number generator
debug: Whether to print debug output (e.g. the solution that was found)
min_w: Minimum width of the puzzle
min_h: Minimum height of the puzzle
max_w: Maximum width of the puzzle
max_h: Maximum height of the puzzle
min_boxes: Minimum number of boxes
max_boxes: Maximum number of boxes
max_depth: Maximum search depth
path: Path to save the level file (optional)
Returns:
puzzle_string, solution, and a dict (third element per the return annotation)
"""
path = path or "levels/lvl0.dat"

while True:
width = rng.randint(min_w, max_w)
height = rng.randint(min_h, max_h)
Expand All @@ -60,7 +65,7 @@ def generate(
puzzle[box_pos] = "$"
boxes_created += 1
boxes_seen.add(box_pos)
reverse_game = ReverseGame(rng=rng, level=0)
reverse_game = ReverseGame(rng=rng, width=width, height=height)
reverse_game.load_puzzle(puzzle)
player = reverse_game.player
counter = round(height * width * rng.uniform(1.8, 3.6))
Expand All @@ -79,16 +84,19 @@ def generate(
out_of_place_boxes = np.sum([str(x) == "@" for x in matrix.flatten()])
if out_of_place_boxes >= boxes // 2:
# Optionally save the puzzle to a file:
# np.savetxt(path, matrix, fmt='%s')
if path:
np.savetxt(path, matrix, fmt='%s')
puzzle_str = player.puzzle_to_string(matrix)

grid_list = [list(line) for line in puzzle_str.replace(" ", "").strip().split("\n")]
grid_array = np.array(grid_list)
solution, _ = solve_astar(grid_array)
solution, depth = solve_astar(grid_array, max_depth=max_depth)
if solution is None:
continue # retry generation

if debug:
print(f"solution={solution}")
game = Game()
game = Game(width=width, height=height)
game.load_puzzle_matrix(grid_array)

for step, move in enumerate(solution):
Expand Down
34 changes: 25 additions & 9 deletions reasoning_gym/games/sokoban.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,26 @@
class SokobanConfig:
"""Configuration for sokoban puzzle generation"""

min_w: int = 6 # Minimum width of the puzzle.
min_h: int = 6 # Minimum height of the puzzle.
max_w: int = 10 # Maximum width of the puzzle.
max_h: int = 10 # Maximum height of the puzzle.
min_boxes: int = 6 # Minimum number of boxes.
max_boxes: int = 10 # Maximum number of boxes.
min_w: int = 6 # Minimum width of the puzzle
min_h: int = 6 # Minimum height of the puzzle
max_w: int = 10 # Maximum width of the puzzle
max_h: int = 10 # Maximum height of the puzzle
min_boxes: int = 4 # Minimum number of boxes
max_boxes: int = 10 # Maximum number of boxes
max_depth: int = 80 # Maximum search depth
seed: Optional[int] = None
size: int = 500

def validate(self):
"""Validate configuration parameters"""
assert 0 < self.max_w <= 20
assert 0 < self.max_h <= 20
assert self.min_h > 0
assert self.min_w > 0
assert self.min_w <= self.max_w, "min_w must be lte max_w"
assert self.min_h <= self.max_h, "min_h must be lte max_h"
assert self.min_boxes <= self.max_boxes, "min_boxes must be lte max_boxes"
assert self.max_depth > 1


class SokobanDataset(ProceduralDataset):
Expand Down Expand Up @@ -58,7 +64,16 @@ def __getitem__(self, idx: int) -> dict:

# Make the Sokoban!
rng = Random(self.seed + idx)
gamestr, solution, difficulty = self._generate(rng=rng)
gamestr, solution, difficulty = self._generate(
rng=rng,
min_w=self.config.min_w,
min_h=self.config.min_h,
max_w=self.config.max_w,
max_h=self.config.max_h,
min_boxes=self.config.min_boxes,
max_boxes=self.config.max_boxes,
max_depth=self.config.max_depth,
)

return {
"question": """You are going to solve a 'sokoban' puzzle.
Expand Down Expand Up @@ -100,15 +115,16 @@ def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
grid_list = [list(line) for line in entry["metadata"]["gamestr"].replace(" ", "").strip().split("\n")]
matrix = np.array(grid_list)

game = self._Game()
h, w = matrix.shape
game = self._Game(height=h, width=w)
game.load_puzzle_matrix(matrix)

for move in answer:
game.player.update(key=move)

if self._is_solved(game.get_curr_state()):
return 1.0
except Exception as e:
except:
pass

return 0.0
Expand Down
15 changes: 0 additions & 15 deletions tests/test_dataset_basics.py

This file was deleted.

14 changes: 10 additions & 4 deletions tests/test_sokoban.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
def test_sokoban():
"""Test basic properties and solution of generated items"""

dataset = SokobanDataset(SokobanConfig(size=10, seed=1234))
for i, item in enumerate(dataset):
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0

# Easy
config = SokobanConfig(seed=42, size=20)
dataset = SokobanDataset(config)
Expand All @@ -21,8 +25,8 @@ def test_sokoban():
assert dataset.score_answer(answer="RU", entry=item) == 0.0
assert dataset.score_answer(answer=None, entry=item) == 0.0

# Medium
config = SokobanConfig(seed=42, min_h=40, max_h=50, min_w=40, max_w=50, min_boxes=20, max_boxes=30, size=3)
# Hard
config = SokobanConfig(seed=42, min_h=15, max_h=20, min_w=15, max_w=20, min_boxes=10, max_boxes=15, size=3, max_depth=90)
dataset = SokobanDataset(config)

for item in dataset:
Expand All @@ -35,8 +39,9 @@ def test_sokoban():
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0

# Hard
config = SokobanConfig(seed=42, min_h=400, max_h=500, min_w=400, max_w=500, min_boxes=50, max_boxes=50, size=1)

# min == max ranges
config = SokobanConfig(seed=42, min_h=11, max_h=11, min_w=11, max_w=11, min_boxes=11, max_boxes=11, size=3, max_depth=60)
dataset = SokobanDataset(config)

for item in dataset:
Expand All @@ -48,3 +53,4 @@ def test_sokoban():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0

14 changes: 14 additions & 0 deletions tests_test_dataset_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import reasoning_gym
from reasoning_gym.factory import DATASETS

def test_score_answer_consistency():
    """Check that each registered dataset scores its own reference answer as 1.0.

    For every name in ``DATASETS`` a 10-item dataset is created with a fixed
    seed. Entries whose answer is ``None`` are skipped; every other answer
    must be a ``str`` and must receive a score of exactly 1.0 from the
    dataset's own ``score_answer``.
    """
    for dataset_name in DATASETS:
        dataset = reasoning_gym.create_dataset(dataset_name, size=10, seed=1234)
        for entry in dataset:
            reference = entry["answer"]
            if reference is None:
                # Nothing to type-check or score for this entry.
                continue

            assert isinstance(reference, str), f"{dataset_name} answer must be str, is {type(entry['answer'])}"

            score = dataset.score_answer(answer=reference, entry=entry)
            assert score == 1.0, f"inconsistent score_answer {dataset_name}"

0 comments on commit 923acff

Please sign in to comment.