Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Mar 4, 2025
1 parent 9ce8ea6 commit bec36e6
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
4 changes: 2 additions & 2 deletions reasoning_gym/games/contrib/sokoban/src/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,14 @@ def generate(
if out_of_place_boxes >= boxes // 2:
# Optionally save the puzzle to a file:
if path:
np.savetxt(path, matrix, fmt='%s')
np.savetxt(path, matrix, fmt="%s")
puzzle_str = player.puzzle_to_string(matrix)

grid_list = [list(line) for line in puzzle_str.replace(" ", "").strip().split("\n")]
grid_array = np.array(grid_list)
solution, depth = solve_astar(grid_array, max_depth=max_depth)
if solution is None:
continue # retry generation
continue # retry generation

if debug:
print(f"solution={solution}")
Expand Down
10 changes: 4 additions & 6 deletions reasoning_gym/graphs/family_relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,16 +358,14 @@ def _generate_story(self, family: set[Person]) -> str:

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
reward = 0.0
if answer is not None:
if isinstance(answer, str):
try:
answer_formatted = answer.strip().lower()
solved = answer_formatted == entry["answer"].strip().lower()
if solved:
oracle_answer = entry["answer"].strip().lower()
if answer_formatted == oracle_answer:
reward = 1.0
else:
reward = 0.01
except:
reward = 0.01
pass
return reward


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import reasoning_gym
from reasoning_gym.factory import DATASETS


def test_score_answer_consistency():
for dataset_name in DATASETS.keys():
dataset = reasoning_gym.create_dataset(dataset_name, size=10, seed=1234)
Expand Down
10 changes: 6 additions & 4 deletions tests/test_sokoban.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def test_sokoban():
assert dataset.score_answer(answer=None, entry=item) == 0.0

# Hard
config = SokobanConfig(seed=42, min_h=15, max_h=20, min_w=15, max_w=20, min_boxes=10, max_boxes=15, size=3, max_depth=90)
config = SokobanConfig(
seed=42, min_h=15, max_h=20, min_w=15, max_w=20, min_boxes=10, max_boxes=15, size=3, max_depth=90
)
dataset = SokobanDataset(config)

for item in dataset:
Expand All @@ -39,9 +41,10 @@ def test_sokoban():
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0


# min == max ranges
config = SokobanConfig(seed=42, min_h=11, max_h=11, min_w=11, max_w=11, min_boxes=11, max_boxes=11, size=3, max_depth=60)
config = SokobanConfig(
seed=42, min_h=11, max_h=11, min_w=11, max_w=11, min_boxes=11, max_boxes=11, size=3, max_depth=60
)
dataset = SokobanDataset(config)

for item in dataset:
Expand All @@ -53,4 +56,3 @@ def test_sokoban():
# Test the scoring
assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
assert dataset.score_answer(answer=None, entry=item) == 0.0

0 comments on commit bec36e6

Please sign in to comment.