Skip to content

Commit

Permalink
fix manipulate matrix (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
zafstojano authored Mar 1, 2025
1 parent 39f151a commit f549909
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 69 deletions.
134 changes: 87 additions & 47 deletions reasoning_gym/algorithmic/manipulate_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Optional
from typing import Any, Optional

import numpy as np

from ..factory import ProceduralDataset, register_dataset

Expand All @@ -28,21 +30,22 @@ def num_cols(matrix: list[list[int]]) -> int:
class ManipulateMatrixConfig:
"""Configuration for Manipulate Matrix dataset generation"""

min_rows: int = 1 # Minimum number of rows
min_cols: int = 1 # Minimum number of columns
min_rows: int = 2 # Minimum number of rows
min_cols: int = 2 # Minimum number of columns
max_rows: int = 10 # Maximum number of rows
max_cols: int = 10 # Maximum number of columns
max_transforms: int = 5 # Maximum number of transformations to apply
p_rotate: float = 0.2 # Probability of rotating the matrix
p_hmirror: float = 0.2 # Probability of horizontally mirroring the matrix
p_vmirror: float = 0.2 # Probability of vertically mirroring the matrix
p_dmirror: float = 0.2 # Probability of mirroring along the diagonal
p_cmirror: float = 0.2 # Probability of mirroring along the counterdiagonal
p_map: float = 0.2 # Probability of mapping a certain value to another
p_crop: float = 0.2 # Probability of cropping the matrix
p_remove_every_nth_row: float = 0.2 # Probability of removing every nth row
p_remove_every_nth_col: float = 0.2 # Probability of removing every nth column
p_zero_divisible: float = 0.2 # Probability of setting elements divisible by some number to zero
min_transforms: int = 1 # Minimum number of transformations to apply
max_transforms: int = 10 # Maximum number of transformations to apply
w_rotate: float = 1 # Weight of rotating the matrix
w_hmirror: float = 1 # Weight of horizontally mirroring the matrix
w_vmirror: float = 1 # Weight of vertically mirroring the matrix
w_dmirror: float = 1 # Weight of mirroring along the diagonal
w_cmirror: float = 1 # Weight of mirroring along the counterdiagonal
w_map: float = 1 # Weight of mapping a certain value to another
w_crop: float = 1 # Weight of cropping the matrix
w_remove_every_nth_row: float = 1 # Weight of removing every nth row
w_remove_every_nth_col: float = 1 # Weight of removing every nth column
w_zero_divisible: float = 1 # Weight of setting elements divisible by some number to zero

size: int = 500 # Virtual dataset size
seed: Optional[int] = None
Expand All @@ -53,17 +56,27 @@ def validate(self):
assert 1 <= self.min_cols, "min_cols must be at least 1"
assert self.min_rows <= self.max_rows, "max_rows must be at least min_rows"
assert self.min_cols <= self.max_cols, "max_cols must be at least min_cols"
assert 0 <= self.max_transforms, "max_transforms must be non-negative"
assert 0 <= self.p_rotate <= 1, "p_rotate must be between 0 and 1"
assert 0 <= self.p_hmirror <= 1, "p_hmirror must be between 0 and 1"
assert 0 <= self.p_vmirror <= 1, "p_vmirror must be between 0 and 1"
assert 0 <= self.p_dmirror <= 1, "p_dmirror must be between 0 and 1"
assert 0 <= self.p_cmirror <= 1, "p_cmirror must be between 0 and 1"
assert 0 <= self.p_map <= 1, "p_map must be between 0 and 1"
assert 0 <= self.p_crop <= 1, "p_crop must be between 0 and 1"
assert 0 <= self.p_remove_every_nth_row <= 1, "p_remove_every_nth_row must be between 0 and 1"
assert 0 <= self.p_remove_every_nth_col <= 1, "p_remove_nth_col must be between 0 and 1"
assert 0 <= self.p_zero_divisible <= 1, "p_zero_divisible must be between 0 and 1"
assert 1 <= self.min_transforms, "min_transforms must be at least 1"
assert self.min_transforms <= self.max_transforms, "max_transforms must be at least min_transforms"
assert (
np.sum(
np.exp(
[
self.w_rotate,
self.w_hmirror,
self.w_vmirror,
self.w_dmirror,
self.w_cmirror,
self.w_map,
self.w_crop,
self.w_remove_every_nth_row,
self.w_remove_every_nth_col,
self.w_zero_divisible,
]
)
)
> 0
), "At least one weight must be non-zero"


class ManipulateMatrixDataset(ProceduralDataset):
Expand All @@ -89,6 +102,21 @@ def __init__(self, config: ManipulateMatrixConfig):
"remove_every_nth_row",
"remove_every_nth_col",
]
weights = np.array(
[
config.w_rotate,
config.w_hmirror,
config.w_vmirror,
config.w_dmirror,
config.w_cmirror,
config.w_map,
config.w_crop,
config.w_remove_every_nth_row,
config.w_remove_every_nth_col,
config.w_zero_divisible,
]
)
self._weights = np.exp(weights) / np.sum(np.exp(weights))

def _get_matrix(self, rng: Random) -> list[list[int]]:
"""Generate a random matrix"""
Expand All @@ -102,6 +130,25 @@ def _matrix_to_str(self, matrix: list[list[int]]) -> str:
"""Get a string representation of the matrix"""
return "\n".join(" ".join(str(x) for x in row) for row in matrix)

def score_answer(self, answer: Optional[str], entry: dict[str, Any]) -> float:
    """Score a model answer against the oracle matrix stored in ``entry``.

    Args:
        answer: The model's answer, expected to be a matrix rendered as one
            row per line with elements separated by whitespace. May be None.
        entry: Dataset entry; only ``entry["answer"]`` (the oracle string)
            is read.

    Returns:
        1.0 for a match (exact, or exact after whitespace normalisation),
        ``len(oracle) / len(normalised answer)`` when the oracle appears
        inside a longer answer, 0.01 for any other non-empty answer, and
        0.0 when the answer is None or the empty string.
    """
    oracle_answer = entry["answer"].strip()
    if not answer:
        return 0.0

    answer = answer.strip()
    if answer == oracle_answer:
        return 1.0

    # The model may emit redundant whitespace: trailing spaces after the last
    # element, repeated separators, "\r\n" line endings or blank lines between
    # rows. Rebuild the text as single-space separated elements, one row per
    # line, and drop rows that end up empty.
    rows = (" ".join(line.split()) for line in answer.split("\n"))
    normalized = "\n".join(row for row in rows if row)
    if normalized == oracle_answer:
        return 1.0

    # Partial credit when the oracle is embedded in extra text; the score
    # shrinks as the surrounding text grows.
    if oracle_answer in normalized:
        return len(oracle_answer) / len(normalized)
    return 0.01

def _identity(self, matrix: list[list[int]]) -> list[list[int]]:
"""Identity transformation: return the given matrix object as-is (no copy is made)."""
return matrix
Expand Down Expand Up @@ -163,15 +210,16 @@ def __getitem__(self, idx: int) -> dict:
matrix = self._get_matrix(rng)
matrix_str = self._matrix_to_str(matrix)

num_transforms = rng.randint(0, self.config.max_transforms)
transforms = rng.sample(self._all_transforms, num_transforms)
num_transforms = rng.randint(self.config.min_transforms, self.config.max_transforms)
operations = []

answer = deepcopy(matrix)

for transform in transforms:
while len(operations) < num_transforms:
# Choose a transform randomly, weighted by the probability of each transform
transform = rng.choices(self._all_transforms, weights=self._weights, k=1)[0]

# Rotate
if transform == "rotate" and rng.random() < self.config.p_rotate:
if transform == "rotate":
rotation = rng.choice(list(self._rotations.keys()))
answer = self._rotations[rotation](answer)
operations.append(
Expand All @@ -182,39 +230,39 @@ def __getitem__(self, idx: int) -> dict:
}
)
# Horizontal mirror
if transform == "hmirror" and rng.random() < self.config.p_hmirror:
if transform == "hmirror":
answer = self._hmirror(answer)
operations.append({"transform": transform, "instruction": "- Horizontally mirror the matrix"})
# Vertical mirror
if transform == "vmirror" and rng.random() < self.config.p_vmirror:
if transform == "vmirror":
answer = self._vmirror(answer)
operations.append({"transform": transform, "instruction": "- Vertically mirror the matrix"})
# Diagonal mirror
if transform == "dmirror" and rng.random() < self.config.p_dmirror:
if transform == "dmirror":
answer = self._dmirror(answer)
operations.append({"transform": transform, "instruction": "- Mirror the matrix along the diagonal"})
# Counterdiagonal mirror
if transform == "cmirror" and rng.random() < self.config.p_cmirror:
if transform == "cmirror":
answer = self._cmirror(answer)
operations.append(
{"transform": transform, "instruction": "- Mirror the matrix along the counterdiagonal"}
)
# Map a value to another
if transform == "map" and rng.random() < self.config.p_map:
if transform == "map":
a, b = rng.sample(range(10), 2)
answer = self._map(answer, a, b)
operations.append(
{"transform": transform, "from": a, "to": b, "instruction": f"- Map each occurrence of {a} to {b}"}
)
# Set elements divisible by k to zero
if transform == "zero_divisible" and rng.random() < self.config.p_zero_divisible:
if transform == "zero_divisible":
k = rng.randint(1, 9)
answer = self._zero_divisible(answer, k)
operations.append(
{"transform": transform, "k": k, "instruction": f"- Set all elements divisible by {k} to zero"}
)
# Crop the matrix
if transform == "crop" and rng.random() < self.config.p_crop:
if transform == "crop":
row_start = rng.randint(1, num_rows(answer))
row_end = rng.randint(row_start, num_rows(answer))
col_start = rng.randint(1, num_cols(answer))
Expand All @@ -231,23 +279,15 @@ def __getitem__(self, idx: int) -> dict:
}
)
# Remove every nth row
if (
transform == "remove_every_nth_row"
and rng.random() < self.config.p_remove_every_nth_row
and num_rows(answer) > 1
):
if transform == "remove_every_nth_row" and num_rows(answer) > 1:
n = rng.randint(2, num_rows(answer))
answer = self._remove_every_nth_row(answer, n)
formatting = "st" if n == 1 else "nd" if n == 2 else "th"
operations.append(
{"transform": transform, "n": n, "instruction": f"- Remove every {n}-{formatting} row (1-indexed)"}
)
# Remove every nth column
if (
transform == "remove_every_nth_col"
and rng.random() < self.config.p_remove_every_nth_col
and num_cols(answer) > 1
):
if transform == "remove_every_nth_col" and num_cols(answer) > 1:
n = rng.randint(2, num_cols(answer))
answer = self._remove_every_nth_col(answer, n)
formatting = "st" if n == 1 else "nd" if n == 2 else "th"
Expand Down
51 changes: 29 additions & 22 deletions tests/test_manipulate_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
def test_manipulate_matrix_config_validation():
"""Test that invalid configs raise appropriate errors"""

with pytest.raises(AssertionError):
config = ManipulateMatrixConfig(max_transforms=-1) # max_transforms should be non-negative
config.validate()
for field in ["min_transforms", "max_transforms"]:
for num_transforms in [-1, 0]: # Number of transforms should be positive
with pytest.raises(AssertionError):
config = ManipulateMatrixConfig(**{field: num_transforms})
config.validate()

invalid_dims = [-1, 0] # Dimensions should be positive integers
dim_fields = ["min_rows", "min_cols", "max_rows", "max_cols"]
Expand All @@ -21,25 +23,6 @@ def test_manipulate_matrix_config_validation():
config = ManipulateMatrixConfig(**{field: dim})
config.validate()

invalid_probabilities = [-0.01, 1.01] # Probabilities should be between 0 and 1 inclusive
probability_fields = [
"p_hmirror",
"p_vmirror",
"p_dmirror",
"p_cmirror",
"p_map",
"p_crop",
"p_remove_every_nth_row",
"p_remove_every_nth_col",
"p_zero_divisible",
]

for field in probability_fields:
for prob in invalid_probabilities:
with pytest.raises(AssertionError):
config = ManipulateMatrixConfig(**{field: prob})
config.validate()


def test_manipulate_matrix_dataset_deterministic():
"""Test that dataset generates same items with same seed"""
Expand Down Expand Up @@ -212,3 +195,27 @@ def test_manipulate_matrix_transforms():
[16, 18, 20],
[21, 23, 25],
]


def test_manipulate_matrix_score_answer():
    """Test every scoring branch of the score_answer method."""
    config = ManipulateMatrixConfig(seed=42)
    dataset = ManipulateMatrixDataset(config)

    oracle = dataset._matrix_to_str([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    entry = {"answer": oracle}

    # perfect match
    answer = "1 2 3\n4 5 6\n7 8 9"
    assert dataset.score_answer(answer, entry) == 1.0

    # model answer contains unnecessary empty spaces
    answer = "1 2 3\n4 5 6 \n7 8 9 "
    assert dataset.score_answer(answer, entry) == 1.0

    # oracle embedded in a longer answer earns partial credit,
    # proportional to the share of the answer the oracle makes up
    answer = "Result:\n" + oracle
    assert dataset.score_answer(answer, entry) == len(oracle) / len(answer)

    # incorrect answer
    answer = "1 2 3\n4 5 6\n7 8 8"
    assert dataset.score_answer(answer, entry) == 0.01

    # answer is an empty string
    answer = ""
    assert dataset.score_answer(answer, entry) == 0.0

    # answer is none
    answer = None
    assert dataset.score_answer(answer, entry) == 0.0

0 comments on commit f549909

Please sign in to comment.