diff --git a/README.md b/README.md
index a1aa3791..ec19b022 100644
--- a/README.md
+++ b/README.md
@@ -115,6 +115,8 @@ See the [Dataset Gallery](GALLERY.md) for a complete list of available datasets
- `PropositionalLogicDataset`: Generate propositional logic reasoning problems
- `SyllogismDataset`: Generates a [syllogism](https://en.wikipedia.org/wiki/Syllogism) reasoning dataset
- `AliceInWonderlandDataset`: Generates [AIW](https://openreview.net/forum?id=Mkl7dzjYiW) (Alice In Wonderland) problems with a few variations
+- `ZebraDataset`: Generates [Zebra Puzzles](https://en.wikipedia.org/wiki/Zebra_Puzzle) of varying difficulty.
+
### Graph Tasks
- `FamilyRelationshipsDataset`: Generate family relationship reasoning tasks with family trees
diff --git a/pyproject.toml b/pyproject.toml
index 964c0b4c..e873922f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,8 +16,10 @@ dependencies = [
"cellpylib==2.4.0",
"sympy>=1.13.1",
"magiccube==0.3.0",
+ "pycosat==0.6.6",
"pyfiglet==1.0.2",
- "pytz>=2024.1"
+ "pytz>=2024.1",
+ "tabulate==0.9.0",
]
classifiers = [
"Programming Language :: Python :: 3",
diff --git a/reasoning_gym/__init__.py b/reasoning_gym/__init__.py
index 054cbd95..019873ff 100644
--- a/reasoning_gym/__init__.py
+++ b/reasoning_gym/__init__.py
@@ -2,7 +2,7 @@
Reasoning Gym - A library of procedural dataset generators for training reasoning models
"""
-from . import algebra, algorithmic, arithmetic, cognition, data, games, geometry, graphs, logic
+from . import algebra, algorithmic, arithmetic, code, cognition, data, games, geometry, graphs, logic
from .factory import create_dataset, register_dataset
__version__ = "0.1.3"
@@ -10,6 +10,7 @@
"algebra",
"algorithmic",
"arithmetic",
+ "code",
"cognition",
"data",
"games",
diff --git a/reasoning_gym/logic/__init__.py b/reasoning_gym/logic/__init__.py
index 38307647..96b24277 100644
--- a/reasoning_gym/logic/__init__.py
+++ b/reasoning_gym/logic/__init__.py
@@ -9,6 +9,7 @@
from .aiw import AliceInWonderlandConfig, AliceInWonderlandDataset
from .propositional_logic import PropositionalLogicConfig, PropositionalLogicDataset
from .syllogisms import SyllogismConfig, SyllogismDataset, Term
+from .zebra_puzzles import ZebraConfig, ZebraDataset
__all__ = [
"AliceInWonderlandConfig",
@@ -19,4 +20,6 @@
"SyllogismDataset",
"syllogism_dataset",
"Term",
+ "ZebraConfig",
+ "ZebraDataset",
]
diff --git a/reasoning_gym/logic/contrib/__init__.py b/reasoning_gym/logic/contrib/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/README.md b/reasoning_gym/logic/contrib/logic_puzzle/README.md
new file mode 100644
index 00000000..d7378bea
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/README.md
@@ -0,0 +1,8 @@
+via https://github.com/nouhadziri/faith-and-fate
+
+@article{dziri2023faith,
+title={Faith and Fate: Limits of Transformers on Compositionality},
+author={Dziri, Nouha and Lu, Ximing and Sclar, Melanie and Li, Xiang Lorraine and Jian, Liwei and Lin, Bill Yuchen and West, Peter and Bhagavatula, Chandra and Bras, Ronan Le and Hwang, Jena D and others},
+journal={arXiv preprint arXiv:2305.18654},
+year={2023}
+}
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/__init__.py b/reasoning_gym/logic/contrib/logic_puzzle/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/clues.py b/reasoning_gym/logic/contrib/logic_puzzle/clues.py
new file mode 100644
index 00000000..b5b01e37
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/clues.py
@@ -0,0 +1,266 @@
+"""
+clues.py
+
+These are all the clue types that a puzzle can have. Things like "the tea drinker lives in the
+green house" and "the cat owner lives left of the person who likes grilled cheese."
+
+There's a Clue ABC that requires you implement an `as_cnf` method, to convert the clue to an
+and-of-ors (probably using things defined in `sat_utils`), and a human-readable __repr__ that
+can be used in a puzzle description.
+
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from functools import wraps
+from itertools import product
+from typing import Iterable, List, Tuple
+
+from reasoning_gym.logic.contrib.logic_puzzle.literals import Literal
+from reasoning_gym.logic.contrib.logic_puzzle.sat_utils import from_dnf, neg
+
+
+def _capitalize_first(repr_func):
+ """
+ Decorator for a __repr__ function that capitalizes the first letter without chagning the rest
+
+ (in contrast to str.capitalize(), which capitalizes the first letter and makes the rest lower)
+ """
+
+ @wraps(repr_func)
+ def wrapper(*args, **kwargs):
+ output = repr_func(*args, **kwargs)
+ return output[0].upper() + output[1:]
+
+ return wrapper
+
+
+class Clue(ABC):
+ """Base class for the types of clues that we allow."""
+
+ @abstractmethod
+ def as_cnf(self) -> Iterable[Tuple[str]]: ...
+
+ @abstractmethod
+ def __repr__(self) -> str: ...
+
+
+def comb(value: Literal, house: int) -> str:
+ """Format how a value is shown at a given house"""
+
+ return f"{value} {house}"
+
+
+@dataclass(eq=True, frozen=True)
+class found_at(Clue):
+ """
+ A literal is known to be at a specific house
+
+ Examples:
+ - the tea drinker lives in the middle house
+ - the fourth house is red
+ """
+
+ value: Literal
+ house: int
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return [(comb(self.value, self.house),)]
+
+ @_capitalize_first
+ def __repr__(self) -> str:
+ houses = [None, "first", "second", "third", "fourth", "fifth", "sixth", "seventh"]
+ return f"{self.value.value} is in the {houses[self.house]} house."
+
+
+@dataclass(eq=True, frozen=True)
+class not_at(Clue):
+ """
+ Two values are known *not* to be at the same house
+
+ Examples:
+ - the musician does not drink tea
+ - the red house does not contain a cat
+ """
+
+ value: Literal
+ house: int
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return [(neg(comb(self.value, self.house)),)]
+
+ @_capitalize_first
+ def __repr__(self) -> str:
+ houses = [None, "first", "second", "third", "fourth", "fifth", "sixth", "seventh"]
+ return f"{self.value.value} is not in the {houses[self.house]} house."
+
+
+@dataclass(eq=True, frozen=True)
+class same_house(Clue):
+ """
+ Two values are known to be at the same house
+
+ Examples:
+ - the musician drinks tea
+ - the red house contains a cat
+ """
+
+ value1: Literal
+ value2: Literal
+ houses: Tuple[int, ...] = field(default_factory=lambda: (1, 2, 3, 4, 5))
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return from_dnf((comb(self.value1, i), comb(self.value2, i)) for i in self.houses)
+
+ @_capitalize_first
+ def __repr__(self) -> str:
+ return f"{self.value1.value} is {self.value2.value}."
+
+
+@dataclass(eq=True, frozen=True)
+class consecutive(Clue):
+ """
+ The first value is directly to the left of the second value
+
+ Examples:
+ - the green house is directly to the left of the white house
+ (green in 1, white in 2 OR green in 2, white in 3 OR etc.)
+ - the house with the kittens is directly to the right of the tea drinker's home
+ (kittens in 2, tea in 1 OR kittens in 3, tea in 2 OR etc.)
+ """
+
+ value1: Literal
+ value2: Literal
+ houses: Tuple[int, ...] = field(default_factory=lambda: (1, 2, 3, 4, 5))
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return from_dnf((comb(self.value1, i), comb(self.value2, j)) for i, j in zip(self.houses, self.houses[1:]))
+
+ @_capitalize_first
+ def __repr__(self) -> str:
+ return f"{self.value1.value} is directly left of {self.value2.value}."
+
+
+@dataclass(eq=True, frozen=True)
+class beside(Clue):
+ """
+ The two values occur side-by-side (either left or right)
+
+ Examples:
+ - the coffee drinker is (left or right) of the tea drinker
+ - the cat owner is (left or right) of the green house
+ """
+
+ value1: Literal
+ value2: Literal
+ houses: Tuple[int, ...] = field(default_factory=lambda: (1, 2, 3, 4, 5))
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return from_dnf(
+ [(comb(self.value1, i), comb(self.value2, j)) for i, j in zip(self.houses, self.houses[1:])]
+ + [(comb(self.value2, i), comb(self.value1, j)) for i, j in zip(self.houses, self.houses[1:])]
+ )
+
+ @_capitalize_first
+ def __repr__(self) -> str:
+ return f"{self.value1.value} and {self.value2.value} are next to each other."
+
+
+@dataclass(eq=True, frozen=True)
+class left_of(Clue):
+ """
+ The first value is somewhere to the left of the second value
+
+ Examples:
+ - the tea drinker is in house 1 and the musician in 2, 3, 4, or 5;
+ OR the tea drinker in 2, and musician in 3, 4, or 5;
+ OR the tea drinker in 3, musician in 4, 5; OR tea 4, musician 5.
+ """
+
+ value1: Literal
+ value2: Literal
+ houses: Tuple[int, ...] = field(default_factory=lambda: (1, 2, 3, 4, 5))
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return from_dnf(
+ (comb(self.value1, i), comb(self.value2, j)) for i, j in product(self.houses, self.houses) if i < j
+ )
+
+ @_capitalize_first
+ def __repr__(self) -> str:
+ return f"{self.value1.value} is somewhere to the left of {self.value2.value}."
+
+
+@dataclass(eq=True, frozen=True)
+class right_of(Clue):
+ """
+ The first value is somewhere to the right of the second value.
+
+ Examples:
+ - the coffee drinker is in house 5 and the artist in 1, 2, 3, 4;
+ OR the coffee drinker in 4, and artist in 1, 2, or 3;
+ OR the coffee drinker in 3, artist in 1, 2; OR coffee 2, artist 1.
+ """
+
+ value1: Literal
+ value2: Literal
+ houses: Tuple[int, ...] = field(default_factory=lambda: (1, 2, 3, 4, 5))
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return sat_utils.from_dnf(
+ (comb(self.value1, i), comb(self.value2, j)) for i, j in product(self.houses, self.houses) if i > j
+ )
+
+ @_capitalize_first
+ def __repr__(self) -> str:
+ return f"{self.value1.value} is somewhere to the right of {self.value2.value}."
+
+
+@dataclass(eq=True, frozen=True)
+class one_between(Clue):
+ """
+ The values are separated by one house
+
+ Examples (if 5 houses):
+ - the cat is in house 1 and tea drinker in house 3; OR cat 2, tea 4;
+ OR cat 4 house 5
+ - the green house is #1 and the musician in house 3; or green house 2, musician 4;
+ OR green house 3, musician 5.
+ """
+
+ value1: Literal
+ value2: Literal
+ houses: Tuple[int, ...] = field(default_factory=lambda: (1, 2, 3, 4, 5))
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return from_dnf(
+ [(comb(self.value1, i), comb(self.value2, j)) for i, j in zip(self.houses, self.houses[2:])]
+ + [(comb(self.value2, i), comb(self.value1, j)) for i, j in zip(self.houses, self.houses[2:])]
+ )
+
+ def __repr__(self) -> str:
+ return f"There is one house between {self.value1.value} and {self.value2.value}."
+
+
+@dataclass(eq=True, frozen=True)
+class two_between(Clue):
+ """
+ The values are separated by two houses
+
+ Examples (if 5 houses):
+ - the cat is in house 1 and artist in house 4; or cat 2, artist 5
+ - the dog is in house 1 and red house is #4; or dog 2, red house 5
+ """
+
+ value1: Literal
+ value2: Literal
+ houses: Tuple[int, ...] = field(default_factory=lambda: (1, 2, 3, 4, 5))
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ return from_dnf(
+ [(comb(self.value1, i), comb(self.value2, j)) for i, j in zip(self.houses, self.houses[3:])]
+ + [(comb(self.value2, i), comb(self.value1, j)) for i, j in zip(self.houses, self.houses[3:])]
+ )
+
+ def __repr__(self) -> str:
+ return f"There are two houses between {self.value1.value} and {self.value2.value}."
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/generate.py b/reasoning_gym/logic/contrib/logic_puzzle/generate.py
new file mode 100644
index 00000000..28f679cd
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/generate.py
@@ -0,0 +1,450 @@
+"""
+puzzle_generator.py
+
+This is a driver script that can be used to generate new zebra puzzles.
+"""
+
+import json
+import pickle
+import sys
+
+# from tqdm import tqdm
+from itertools import product
+from random import choices, randint, sample, seed, shuffle
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Type
+
+from tabulate import tabulate
+
+from reasoning_gym.logic.contrib.logic_puzzle.literals import *
+from reasoning_gym.logic.contrib.logic_puzzle.puzzle import Puzzle
+from reasoning_gym.logic.contrib.logic_puzzle.sat_utils import itersolve
+
+from .clues import Clue, beside, consecutive, found_at, left_of, not_at, one_between, right_of, same_house, two_between
+
+
+def generate_found_at(puzzle: Puzzle, solution: Dict[Literal, int]) -> Set[Clue]:
+ """Generate the `found_at` / `not_at` Clue instances"""
+ clues: Set[Clue] = set()
+ for element, loc in solution.items():
+ clues.add(found_at(element, loc))
+
+ return clues
+
+
+def generate_not_found_at(puzzle: Puzzle, solution: Dict[Literal, int]) -> Set[Clue]:
+ """Generate the `found_at` / `not_at` Clue instances"""
+ clues: Set[Clue] = set()
+ for element, loc in solution.items():
+ for house in puzzle.houses:
+ if house != loc:
+ clues.add(not_at(element, house))
+
+ return clues
+
+
+def generate_same_house(puzzle: Puzzle, solution: Dict[Literal, int]) -> Set[Clue]:
+ """Generate the `same_house` Clue instances"""
+
+ clues: Set[Clue] = set()
+ for house in puzzle.houses:
+ items_at_house = {item: loc for item, loc in solution.items() if loc == house}
+ pairs: Set[Tuple[Literal, Literal]] = {
+ (item1, item2) for item1, item2 in product(items_at_house, repeat=2) if item1 != item2
+ }
+ for pair in pairs:
+ clues.add(same_house(pair[0], pair[1], puzzle.houses))
+
+ return clues
+
+
+def generate_consecutive_beside(puzzle: Puzzle, solution: Dict[Literal, int]) -> Set[Clue]:
+ """Generate the `consecutive` / `beside` Clue instances
+
+ (Note that consecutive is just a more informative version of beside. Since they have the same
+ structure, for every possible combination we'll just keep one.
+ """
+
+ clues: Set[Clue] = set()
+ for left, right in zip(puzzle.houses, puzzle.houses[1:]):
+ items_left = {item: loc for item, loc in solution.items() if loc == left}
+ items_right = {item: loc for item, loc in solution.items() if loc == right}
+ pairs: Set[Tuple[Literal, Literal]] = {(item1, item2) for item1, item2 in product(items_left, items_right)}
+ for pair in pairs:
+ # consecutive is just a more informative version of beside, but they have same structure
+ # because of this, don't include both
+ if randint(0, 1) == 0:
+ clues.add(consecutive(pair[0], pair[1], puzzle.houses))
+ else:
+ clues.add(beside(pair[0], pair[1], puzzle.houses))
+
+ return clues
+
+
+def generate_left_right_of(puzzle: Puzzle, solution: Dict[Literal, int]) -> Set[Clue]:
+ """Generate the `left_of` / `right_of` Clue instances
+
+ Note that since (x left-of y) is guaranteed to be redundant with (b right-of a), we only add
+ one of these clues to the final set.
+ """
+
+ clues: Set[Clue] = set()
+ for left, right in product(puzzle.houses, puzzle.houses):
+ if left >= right:
+ continue
+
+ items_left = {item: loc for item, loc in solution.items() if loc == left}
+ items_right = {item: loc for item, loc in solution.items() if loc == right}
+ pairs: Set[Tuple[Literal, Literal]] = {(item1, item2) for item1, item2 in product(items_left, items_right)}
+ for pair in pairs:
+ if randint(0, 1) == 0:
+ clues.add(left_of(pair[0], pair[1], puzzle.houses))
+ else:
+ clues.add(right_of(pair[1], pair[0], puzzle.houses))
+
+ return clues
+
+
+def generate_one_between(puzzle: Puzzle, solution: Dict[Literal, int]) -> Set[Clue]:
+ """Generate the `one_between` Clue instances"""
+
+ clues: Set[Clue] = set()
+ for left, right in zip(puzzle.houses, puzzle.houses[2:]):
+ items_left = {item: loc for item, loc in solution.items() if loc == left}
+ items_right = {item: loc for item, loc in solution.items() if loc == right}
+ pairs: Set[Tuple[Literal, Literal]] = {(item1, item2) for item1, item2 in product(items_left, items_right)}
+ for pair in pairs:
+ clues.add(one_between(pair[0], pair[1], puzzle.houses))
+
+ return clues
+
+
+def generate_two_between(puzzle: Puzzle, solution: Dict[Literal, int]) -> Set[Clue]:
+ """Generate the `two_between` Clue instances"""
+
+ clues: Set[Clue] = set()
+ for left, right in zip(puzzle.houses, puzzle.houses[3:]):
+ items_left = {item: loc for item, loc in solution.items() if loc == left}
+ items_right = {item: loc for item, loc in solution.items() if loc == right}
+ pairs: Set[Tuple[Literal, Literal]] = {(item1, item2) for item1, item2 in product(items_left, items_right)}
+ for pair in pairs:
+ clues.add(two_between(pair[0], pair[1], puzzle.houses))
+
+ return clues
+
+
+def has_unique_solution(puzzle: Puzzle, clues: Iterable[Clue], remove_after=False) -> bool:
+ """Test if a puzzle has a unique solution under a given set of clues."""
+
+ with puzzle.with_clues(clues, remove_after=remove_after):
+ # print(f"Testing puzzle with {len(puzzle.clues)} clues")
+ solutions = itersolve(puzzle.as_cnf())
+ _first_solution = next(solutions)
+
+ # test if second solution exists or not; if it doesn't, uniquely solvable
+ if next(solutions, None) is None:
+ return True
+ else:
+ return False
+
+
+def try_to_remove(puzzle: Puzzle, clues: Set[Clue], n: int, must_have=set()) -> Set[Clue]:
+ """
+ Attempt to remove n clues from a set of candidate clues; if we are able to, return the new,
+ smaller set of clues. If not, return the original set.
+ """
+
+ def weight(clue: Clue) -> float:
+ # relative probabilities of each type of clue being selected for removal
+ weights: Dict[Type[Clue], float] = {
+ not_at: 0.75,
+ found_at: 0.75,
+ same_house: 0.75,
+ beside: 1.2,
+ left_of: 1.2,
+ right_of: 1.2,
+ one_between: 1.5,
+ two_between: 1.5,
+ }
+
+ return weights.get(type(clue), 1)
+
+ weights = [weight(clue) for clue in clues]
+ candidates: Set[Clue] = set(choices(list(clues), weights, k=n))
+ candidates = candidates - must_have
+ clues = clues.difference(candidates)
+ if has_unique_solution(puzzle, clues):
+ # print(f"Removed {len(candidates)} clues.")
+ return clues
+
+ # we needed at least one of those, add them all back
+ clues = clues | candidates
+ return clues
+
+
+def reduce_individually(
+ puzzle: Puzzle, clues: Set[Clue], removed: Set[Clue], must_have=set()
+) -> Tuple[Set[Clue], Set[Clue]]:
+ """
+ Attempt to remove each candidate clue one by one.
+
+ The sets `clues` and `removed` are modified in-place. Unnecessary clues get removed from `clues`
+ and added to `removed`. If no clues can be removed, we return the original two sets.
+ """
+
+ candidates = set(sample(list(clues), len(clues)))
+ for clue in candidates:
+ if clue not in must_have:
+ clues.remove(clue)
+ if has_unique_solution(puzzle, clues):
+ # print(f"Removed {clue=}")
+ removed.add(clue)
+ continue # we were fine to remove this clue
+ clues.add(clue)
+
+ return clues, removed
+
+
+def reduce_clues(puzzle: Puzzle, clues: Set[Clue], must_have=set()) -> Tuple[Set[Clue], Set[Clue]]:
+ """
+ Reduce a set of clues to a minimally solvable set.
+
+ A minimally solvable 5-house, 4-attribute puzzle takes between 10 and 20 clues to solve. The
+ original set of clues will likely be in the hundreds, and the majority are likely to be
+ redundant. We can quickly reduce the set of clues by batch removing clues from the large
+ candidate pool.
+
+ The algorithm for batch reduction:
+ 1. shuffle all the clues
+ 2. attempt to remove 10% of the clues; with this 90%-clue set, test if the puzzle is solvable.
+ 3a. if yes: keep them removed, go back to 2 and repeat
+ 3b. if no: add them back, keep going to 4
+ 4. the same as step (3), but this time trying to remove 5% of the clues
+ 5. the same as step (3), but this time trying to remove a single clue
+
+ After we've tried and failed to remove a *single* clue, then the (first part of the) reduction
+ algorithm is done; having that clue was necessary for us to have a unique solution. This doesn't
+ necessarily mean that *all* the clues are need, though, which is what the secondary reduction
+ is for.
+
+ The *secondary reduction process* is much simpler: now that the set is narrowed substantially,
+ we can be more brute-forcey. Attempt to remove each clue and see if the puzzle is still
+ solvable.
+
+ However, the secondary reduction process can result in a puzzle that is *too hard* to solve
+ (though technically uniquely solvable by a computer or sufficiently skilled human). This
+ algorithm returns a second set of clues that were removed during the secondary reduction
+ process. These can be thought of as extra clues to add or hints to give to anyone solving the
+ puzzle.
+
+ """
+
+ # this is a stupid way to shuffle the set of clues without modifying it
+ minimal_clues = set(sample(list(clues), k=len(clues)))
+ while True:
+ # print(f"There are {len(minimal_clues)} clues in ba sing se")
+
+ # Walrus time!
+ #
+ # If the size of minimal_clues before we try to remove some clues is greater than the size
+ # after, then those clues were fine to remove. Go back to the top of the loop and keep
+ # removing more. But if the size is the same, we needed some of those clues. Try to remove
+ # a smaller amount.
+ #
+ # We use the walrus operator to update minimal_clues in place during the comparison. It'll
+ # either be a reduced set of clues or the original set of clues.
+ if len(minimal_clues) > len(
+ (minimal_clues := try_to_remove(puzzle, minimal_clues, len(minimal_clues) // 10, must_have))
+ ):
+ continue
+
+ if len(minimal_clues) != len(
+ (minimal_clues := try_to_remove(puzzle, minimal_clues, len(minimal_clues) // 20, must_have))
+ ):
+ continue
+
+ if len(minimal_clues) == len((minimal_clues := try_to_remove(puzzle, minimal_clues, 1, must_have))):
+ break
+
+ # secondary reduction time! While we can still remove clues, do so; then we're done.
+ # print(f"Starting the secondary reduction.")
+ removed_clues: Set[Clue] = set()
+ while True:
+ minimal_clues_size = len(minimal_clues)
+ minimal_clues, removed_clues = reduce_individually(puzzle, minimal_clues, removed_clues, must_have)
+ if len(minimal_clues) == minimal_clues_size:
+ # cannot reduce anymore
+ break
+
+ return minimal_clues, removed_clues
+
+
+def question_generation(col_name, table_data):
+ values_by_cols = {}
+ for row in table_data:
+ for idx, value in enumerate(row):
+ c = col_name[idx]
+ if c not in values_by_cols:
+ values_by_cols[c] = []
+ values_by_cols[c].append(value)
+
+ questions_data = []
+ for row in table_data:
+ for cid, col in enumerate(col_name):
+ if cid == 0:
+ continue
+ question = f"What is {col} of the person who lives in House {row[0]}?"
+ options = values_by_cols[col][:]
+ shuffle(options)
+ truth = row[cid]
+ assert truth in options
+ questions_data.append(
+ {"question": question, "choices": options, "truth_idx": options.index(truth), "answer": truth}
+ )
+ assert questions_data[-1]["answer"] in questions_data[-1]["choices"]
+ assert questions_data[-1]["choices"][questions_data[-1]["truth_idx"]] == questions_data[-1]["answer"]
+
+ return questions_data
+
+
+def generate_solution_dict(selected_elements: List[Literal], n: int) -> Dict[Literal, int]:
+ solution = {}
+ house_ids = list(range(1, n + 1))
+ for element in selected_elements:
+ shuffle(house_ids)
+ attributes: List[element] = list(element.__members__.values())
+ for i in range(n):
+ solution[attributes[i]] = house_ids[i]
+ return solution
+
+
+def wrap_up_dict(random_elements, solution, puzzle, reduced, extra_clues, context, K, M):
+ col_names = [e.__name__ for e in random_elements]
+ house_data = {}
+ for item, house in solution.items():
+ element_name, attrname = str(item).split(".")
+ if house not in house_data:
+ house_data[house] = {}
+ house_data[house][element_name] = attrname
+ table_data = []
+ for i in range(1, len(house_data) + 1):
+ row = [i]
+ for c in col_names:
+ row.append(house_data[i][c].replace("_", " "))
+ table_data.append(row)
+
+ col_names = ["House"] + col_names
+
+ table = tabulate(table_data, headers=col_names, tablefmt="grid")
+
+ ## Generate multiple-choice questions
+ q_data = question_generation(col_names, table_data)
+ all_in_one = {}
+ all_in_one["size"] = f"{K}*{M}"
+ all_in_one["puzzle_context"] = context
+ all_in_one["core_rules"] = [str(clue) for clue in reduced]
+ all_in_one["extra_rules"] = [str(clue) for clue in extra_clues]
+ all_in_one["core_rules_types"] = [str(type(clue)) for clue in reduced]
+ all_in_one["extra_rules_types"] = [str(type(clue)) for clue in extra_clues]
+ all_in_one["puzzle"] = str(puzzle)
+ all_in_one["questions"] = q_data
+ all_in_one["solution"] = {"table_str": table, "table_rows": table_data, "table_header": col_names}
+ return all_in_one
+
+
+def check_correctness(p):
+ solutions = itersolve(p.as_cnf())
+ _first_solution = next(solutions)
+ solution_set = [f"{str(k)} {v}" for k, v in p.solution.items()]
+ return set(solution_set) == set(_first_solution)
+
+
+def generate_puzzle(K=2, M=3, mode="train"):
+ elements = [Color, Nationality, Animal, Drink, Cigar, Food, Flower, PhoneModel, Children, Smoothie]
+ clue_types = [
+ generate_found_at,
+ generate_same_house,
+ generate_consecutive_beside,
+ ]
+
+ shuffle(elements)
+ random_elements = [Name] + elements[: M - 1]
+ solution = generate_solution_dict(random_elements, K)
+
+ # set up the puzzle with default constraints
+ puzzle = Puzzle(element_types=random_elements, elements=solution.keys(), n_houses=K).set_constraints()
+ puzzle.solution = solution
+ context = str(puzzle)
+
+ # generate all the clues
+ clues: Set[Clue] = set()
+
+ for generate_function in clue_types:
+ clues = clues.union(generate_function(puzzle, solution))
+
+ reduced, _ = reduce_clues(puzzle, clues)
+ extra_clues = clues - reduced
+ extra_clues = set(sample(list(extra_clues), min(len(extra_clues), 30)))
+ for clue in reduced:
+ puzzle.add_clue(clue)
+
+ assert has_unique_solution(puzzle, puzzle.clues, remove_after=False)
+ assert check_correctness(puzzle)
+ all_in_one = wrap_up_dict(random_elements, solution, puzzle, reduced, extra_clues, context, K, M)
+ return all_in_one, puzzle
+
+
+# def main():
+# mode = sys.argv[1]
+# print(f"mode={mode}")
+# if mode.startswith("train"):
+# seed(1337)
+# N = 30
+# if mode.endswith("_large"):
+# N = 150
+# if mode.endswith("_xl"):
+# N = 1000
+# Ks = [2,3,4]
+# Ms = [2,3,4]
+
+# if mode.endswith("_xxl"):
+# N = 500
+# Ks = [2,3,4,5,6]
+# Ms = [2,3,4,5,6]
+
+# elif mode == "dev" or mode.startswith("test_"):
+# seed(42+len(mode))
+# N = 10
+# Ks = [2,3,4,5]
+# Ms = [2,3,4,5]
+# if mode.startswith("test_id_xl"):
+# Ks = [2,3,4,5,6]
+# Ms = [2,3,4,5,6]
+# if mode.startswith("test_id_xxl"):
+# Ks = [2,3,4,5,6,7]
+# Ms = [2,3,4,5,6,7]
+# if mode.endswith("_50"):
+# N = 50
+
+# instances = []
+# puzzle_objs = []
+# for K, M, idx in tqdm(list(product(Ks, Ms, list(range(N))))):
+# if mode.startswith("test_id_xl"):
+# if K != 6 and M != 6:
+# continue
+# if mode.startswith("test_id_xxl"):
+# if K != 7 and M != 7:
+# continue
+# instance, puzzle = generate_puzzle(K, M, mode)
+# instance["idx"] = f"lgp-{mode}-{K}x{M}-{idx}"
+# instances.append(instance)
+# puzzle_objs.append({"idx": instance["idx"], "puzzle": puzzle})
+
+# with open(f"logic_grid_puzzles.{mode}.pkl", "wb") as f:
+# pickle.dump(puzzle_objs, f)
+
+# with open(f"logic_grid_puzzles.{mode}.json", "w") as f:
+# json.dump(instances, f, indent=2)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/graph/check_clues.py b/reasoning_gym/logic/contrib/logic_puzzle/graph/check_clues.py
new file mode 100644
index 00000000..407cfb23
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/graph/check_clues.py
@@ -0,0 +1,57 @@
+import itertools
+from collections import defaultdict
+
+import sat_utils
+import solver as solver
+from z3 import *
+
+
+def check_singe_clues(puzzle_idx, answer_header, answer_value, running_clues, idx2clue, used_clue, all_cells):
+ solved_cell = defaultdict(int)
+ valid_solutions = defaultdict(int)
+ print("check single clues")
+ for idx in idx2clue:
+ if idx not in used_clue:
+ new_clues = running_clues + idx2clue[idx].as_cnf()
+ numbered_cnf, num2var = sat_utils.translate(new_clues)
+ single_solver = solver.my_solver(puzzle_idx, answer_header, answer_value, numbered_cnf, num2var)
+ cell_info = single_solver.check_cell_difficulty()
+ num_cell = 0
+ for cell in cell_info:
+ if cell not in all_cells:
+ num_cell += 1
+ solved_cell[idx] = num_cell
+ del single_solver
+ idx = sorted(solved_cell.items(), key=lambda x: x[1], reverse=True)[0][0]
+ if solved_cell[idx] != 0:
+ return [idx]
+ else:
+ return [1000]
+
+
+def check(puzzle_idx, answer_header, answer_value, running_clues, idx2clue, used_clue, all_cells):
+ idx = check_singe_clues(puzzle_idx, answer_header, answer_value, running_clues, idx2clue, used_clue, all_cells)
+ if idx[0] < 1000:
+ return idx
+ else:
+ print("check multiple clues")
+ solved_cell = defaultdict(int)
+ all_left_clues_idx = [i for i in idx2clue if i not in used_clue]
+ print(all_left_clues_idx)
+ for n in range(2, 10):
+ combinations = itertools.combinations(all_left_clues_idx, n)
+ for comb in combinations:
+ new_clues = [clues for clues in running_clues]
+ for comb_idx in comb:
+ new_clues += idx2clue[comb_idx].as_cnf()
+ numbered_cnf, num2var = sat_utils.translate(new_clues)
+ single_solver = solver.my_solver(puzzle_idx, answer_header, answer_value, numbered_cnf, num2var)
+ cell_info = single_solver.check_cell_difficulty()
+ num_cell = 0
+ for cell in cell_info:
+ if cell not in all_cells:
+ num_cell += 1
+ solved_cell[comb] = num_cell
+ del single_solver
+ if num_cell != 0:
+ return comb
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/graph/graph_literals.py b/reasoning_gym/logic/contrib/logic_puzzle/graph/graph_literals.py
new file mode 100644
index 00000000..ef90dc4b
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/graph/graph_literals.py
@@ -0,0 +1,32 @@
+def print_clue(clue_idxs, idx2clue, new_cell, first=False):
+ if len(clue_idxs) == 1:
+ # print(type(idx2clue[clue_idxs[0]]))
+ # if type(idx2clue[clue_idxs[0]]) == clues.found_at:
+ return single_find_at(idx2clue[clue_idxs[0]], new_cell, first)
+ else:
+ if first:
+ result = "First combining clues: "
+ else:
+ result = "Then combining clues: "
+ for current_idx in clue_idxs:
+ result += "<{}>".format(idx2clue[current_idx])
+ result += " Unique Values Rules and the fixed table structure. We know that "
+ for cell in new_cell:
+ result += cell
+ result += ". "
+ return result
+
+
+def single_find_at(clue, new_cell, first=False):
+ if first:
+ result = "First applying clue: "
+ else:
+ result = "Then applying clue: "
+ if len(new_cell) > 1:
+ result += "<{}> and Unique Values We know that ".format(clue)
+ else:
+ result += "<{}> We know that ".format(clue)
+ for cell in new_cell:
+ result += cell
+ result += ". "
+ return result
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/graph/main.py b/reasoning_gym/logic/contrib/logic_puzzle/graph/main.py
new file mode 100644
index 00000000..763a31ab
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/graph/main.py
@@ -0,0 +1,49 @@
+import argparse
+import json
+import pickle
+
+import sat_utils
+import solver
+from tqdm import tqdm
+from z3 import *
+
+
+def solve_logic_grid_puzzle(inputfile, ground_truth):
+ answers = [json.loads(answer) for answer in list(open(ground_truth, "r"))][0]
+ cell_difficulty = {}
+ with open(inputfile, "rb") as f:
+ puzzles = pickle.load(f)
+ for i in tqdm(range(len(puzzles[:]))):
+ d = puzzles[i]
+ assert d["idx"] == answers[i]["idx"]
+ answer_header = answers[i]["solution"]["table_header"]
+ answer_value = answers[i]["solution"]["table_rows"]
+ # read cnf form
+ symbolic_cnf = d["puzzle"].as_cnf()
+ numbered_cnf, num2var = sat_utils.translate(symbolic_cnf)
+
+ single_solver = solver.my_solver(d["idx"], answer_header, answer_value, numbered_cnf, num2var)
+ solution, unique = single_solver.check_solution()
+ if unique:
+ for row_num in range(len(solution)):
+ for column_num in range(len(solution[row_num])):
+ assert row_num == solution[row_num][column_num]
+ difficulty = single_solver.check_problem_difficulty()
+ cell_difficulty[d["idx"]] = difficulty
+
+ with open("./logic_grid_puzzles.test.difficulty.pkl", "wb") as outputfile:
+ pickle.dump(cell_difficulty, outputfile)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_data", type=str, default="../logic_grid_puzzles.test.pkl")
+ parser.add_argument("--ground_truth", type=str, default="../logic_grid_puzzles.test.json")
+ args = parser.parse_args()
+
+ solve_logic_grid_puzzle(args.input_data, args.ground_truth)
+ return 1
+
+
+if __name__ == "__main__":
+ main()
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/graph/puzzle_analysis.py b/reasoning_gym/logic/contrib/logic_puzzle/graph/puzzle_analysis.py
new file mode 100644
index 00000000..c09a50ef
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/graph/puzzle_analysis.py
@@ -0,0 +1,190 @@
+import argparse
+import json
+import os
+from copy import deepcopy
+from pathlib import Path
+from typing import Sequence
+
+from tqdm import tqdm
+
+
+def parse_table(table_string):
+ flat_rows = [r for r in table_string.split("\n") if r.startswith("$")]
+ table = {}
+ for flat_row in flat_rows:
+ flat_cells = [c for c in flat_row.split("|") if c]
+ house, attributes = flat_cells[0], flat_cells[1:]
+ h_prefix, h_number = house.split(":")
+ assert h_prefix == "$ House"
+ house_key = f"House {h_number.strip()}"
+ table[house_key] = {}
+
+ for att in attributes:
+ if not ":" in att:
+ continue
+ name, value = att.split(":")
+ table[house_key][name.strip()] = value.strip()
+
+ return table
+
+
+def copy_table(table):
+ new_table = {}
+ for house, attributes in table.items():
+ new_table[house] = {}
+ for att_name, att_value in attributes.items():
+ new_table[house][att_name] = None
+ return new_table
+
+
+def parse_step(step_text, table_to_fill):
+ parsed_clues = [s.split(">")[0] for s in step_text.split("<") if ">" in s]
+ ans_texts = [s.strip() for s in step_text.split("We know that ")[-1].split(".") if s.strip().startswith("The ")]
+ for ans_text in ans_texts:
+ try:
+ att_name = ans_text.split(" in house ")[0].split("The ")[-1].strip()
+ house_num = ans_text.split(" in house ")[1].split(" is ")[0].strip()
+ att_value = ans_text.split(" is ")[-1].strip()
+ table_to_fill[f"House {house_num}"][att_name] = att_value
+ except:
+ print(parsed_clues)
+ print(ans_text)
+ print(table_to_fill)
+ print()
+ continue
+ return parsed_clues, table_to_fill
+
+
+def pre_process(src_file: os.PathLike, dest_dir: os.PathLike, overwrite: bool = False):
+ outputs = json.load(open(src_file, "r"))["data"]
+
+ src_file = Path(src_file)
+ os.makedirs(dest_dir, exist_ok=True)
+ save_file = os.path.join(dest_dir, f"parsed_{src_file.name}")
+
+ if not overwrite and os.path.exists(save_file):
+ return
+
+ items = []
+ for output in tqdm(outputs, desc=f"{src_file.stem}"):
+ groundtruth_table = parse_table(output["truth_outputs"][0])
+ if type(output["output_text"]) == list:
+ output_text = output["output_text"][0].split("\n")
+ else:
+ output_text = output["output_text"].split("\n")
+ step_texts = [o for o in output_text if o.startswith("Step ")]
+ steps = []
+ partial_table = copy_table(groundtruth_table)
+ for s_text in step_texts:
+ clues, filled_table = parse_step(s_text, partial_table)
+ steps.append(
+ {
+ "clues": clues,
+ "partial_table": filled_table,
+ }
+ )
+ partial_table = deepcopy(filled_table)
+
+ predict_table_texts = [o for o in output_text if o.startswith("$ House")]
+ if not predict_table_texts:
+ predicted_table = partial_table
+ else:
+ predicted_table = parse_table("\n".join(predict_table_texts))
+
+ items.append({"groundtruth_table": groundtruth_table, "predicted_table": predicted_table, "steps": steps})
+
+ with open(save_file, "w") as f:
+ json.dump(items, f, indent=4)
+
+
+def error_analysis(home_path: os.PathLike, file_names: Sequence[str] = None):
+ if not file_names:
+ parsed_files = [p for p in Path(home_path).glob("parsed_table*judged.json")]
+ else:
+ parsed_files = [os.path.join(home_path, f) for f in file_names]
+
+ datasets = []
+ for parsed_file in parsed_files:
+ datasets.extend([json.loads(s) for s in open(parsed_file, "r").readlines()])
+
+ error_types = ["correct", "type1", "type2", "type3"]
+ error_stats = {}
+ for data in tqdm(datasets):
+ gt_table = data["groundtruth_table"]
+ previous_type = None
+ for i, step in enumerate(data["steps"]):
+ partial_table = step["partial_table"]
+ # whether value correct:
+ value_to_check = []
+ for house_name, attributes in partial_table.items():
+ for att, value in attributes.items():
+ if not i:
+ previous_empty = True
+ else:
+ if att not in data["steps"][i - 1]["partial_table"][house_name]:
+ previous_empty = False
+ else:
+ previous_empty = data["steps"][i - 1]["partial_table"][house_name][att] is None
+ if value is not None and previous_empty:
+ value_to_check.append((house_name, att))
+
+ value_correct = all(
+ [partial_table[house_name][att] == gt_table[house_name][att] for house_name, att in value_to_check]
+ )
+ operation_correct = step["label"]
+ if not i:
+ ancestor_correct = True
+ else:
+ ancestor_correct = previous_type == "correct"
+
+ if value_correct:
+ if ancestor_correct:
+ node_type = "correct"
+ else:
+ node_type = "type3"
+ else:
+ if operation_correct:
+ node_type = "type2"
+ else:
+ node_type = "type1"
+
+ previous_type = node_type
+
+ if i not in error_stats.keys():
+ error_stats[i] = {e: 0 for e in error_types}
+ error_stats[i][node_type] += 1
+
+ print(json.dumps(error_stats))
+ percent_stats = {}
+ number_nodes = {}
+ for key, values in error_stats.items():
+ percent_stats[key] = {}
+ total_nodes = sum([values[t] for t in error_types])
+ number_nodes[key] = total_nodes
+ for t in error_types:
+ percent_stats[key][t] = values[t] / total_nodes
+ print(json.dumps(percent_stats))
+ print(number_nodes)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--cot_dir", type=str, default=None)
+ parser.add_argument("--output_dir", type=str, default="tmp")
+ parser.add_argument("--overwrite", action="store_true", default=False)
+
+ args = parser.parse_args()
+
+ cot_dir = Path(args.cot_dir)
+ assert cot_dir.exists(), "cot_dir does not exist"
+ assert cot_dir.is_dir(), "cot_dir is not a directory"
+
+ for cot_file in cot_dir.glob("table*.json"):
+ pre_process(cot_file, args.output_dir, args.overwrite)
+
+ error_analysis(args.output_dir)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/graph/reasoning_path.py b/reasoning_gym/logic/contrib/logic_puzzle/graph/reasoning_path.py
new file mode 100644
index 00000000..18a286c9
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/graph/reasoning_path.py
@@ -0,0 +1,113 @@
+import argparse
+import json
+import pickle
+from collections import defaultdict
+
+import check_clues
+import graph_literals
+import sat_utils
+import solver as solver
+from tqdm import tqdm
+from z3 import *
+
+
+def get_idx2clue(clues):
+ clue_num2_clue = defaultdict(int)
+ clue_type2_clue_num = defaultdict(list)
+ clue_num = 0
+ for clue in list(clues):
+ clue_num2_clue[clue_num] = clue
+ clue_type2_clue_num[type(clue)].append(clue_num)
+ clue_num += 1
+ return clue_num2_clue, clue_type2_clue_num
+
+
+def logic_grid_puzzle(inputfile, ground_truth, size, lower_part, higher_part):
+ reasoning_result = []
+ answers = json.load(open(ground_truth, "r"))
+ puzzles = pickle.load(open(inputfile, "rb"))
+ cell_difficulty = {}
+ mode = inputfile[inputfile.find("puzzles.") + 8 : inputfile.find(".pkl")]
+ print("Number of puzzles", len(answers))
+ assert len(answers) == len(puzzles)
+ new_data = []
+ for i in tqdm(range(len(puzzles[:]))):
+ d = puzzles[i]
+ per_size_idx = int(puzzles[i]["idx"].split("-")[-1])
+ if (
+ d["idx"].startswith("lgp-" + mode + "-" + size)
+ and per_size_idx >= lower_part
+ and per_size_idx < higher_part
+ ):
+ print("Puzzle id:", d["idx"])
+ print("Puzzle", d)
+ print("Solving puzzle" + "===============" * 7 + "Solving puzzle")
+
+ assert d["idx"] == answers[i]["idx"]
+ answer_header = answers[i]["solution"]["table_header"]
+ answer_value = answers[i]["solution"]["table_rows"]
+
+ puzzle_clues = d["puzzle"].clues
+ idx2clue, clue_type2_idx = get_idx2clue(puzzle_clues)
+
+ running_clues = []
+ all_cells = []
+ used_clue = []
+ self_constraints = d["puzzle"].constraints
+ running_clues.extend(self_constraints)
+
+ # for i in range(len(idx2clue)):
+ first = True
+ step_num = 1
+ reasoning = ""
+ while len(used_clue) < len(idx2clue):
+ reasoning += "Step {}: ".format(step_num)
+ step_num += 1
+ clue_idxs = check_clues.check(
+ d["idx"], answer_header, answer_value, running_clues, idx2clue, used_clue, all_cells
+ )
+ for current_clue_idx in clue_idxs:
+ running_clues.extend(idx2clue[current_clue_idx].as_cnf())
+ used_clue.append(current_clue_idx)
+
+ numbered_cnf, num2var = sat_utils.translate(running_clues)
+ single_solver = solver.my_solver(d["idx"], answer_header, answer_value, numbered_cnf, num2var)
+ cell_info = single_solver.check_cell_difficulty()
+ new_cell = []
+ for cell in cell_info:
+ if cell not in all_cells:
+ new_cell.append(cell)
+ all_cells.append(cell)
+ reasoning += graph_literals.print_clue(clue_idxs, idx2clue, new_cell, first)
+ first = False
+ reasoning += "\n"
+ assert len(all_cells) == len(answer_value) * (len(answer_value[0]) - 1)
+ reasoning += "The puzzle is solved."
+ reasoning_result.append(reasoning)
+ print(reasoning)
+ single_data = answers[i]
+ single_data["reasoning"] = reasoning
+ new_data.append(single_data)
+
+ with open(
+ "./data/logic_grid_puzzles.reasoning." + mode + size + "-" + str(lower_part) + "_" + str(higher_part) + ".json",
+ "w",
+ ) as outputfile:
+ json.dump(new_data, outputfile)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--input_data", type=str, default="./data/logic_grid_puzzles.test_id_xl.pkl")
+ parser.add_argument("--ground_truth", type=str, default="./data/ogic_grid_puzzles.test_id_xl.json")
+ parser.add_argument("--size", type=str, default="2x")
+ parser.add_argument("--lower_part", type=int, default=0) # min data index
+ parser.add_argument("--higher_part", type=int, default=100) # max data index
+ args = parser.parse_args()
+
+ logic_grid_puzzle(args.input_data, args.ground_truth, args.size, args.lower_part, args.higher_part)
+ return 1
+
+
+if __name__ == "__main__":
+ main()
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/graph/solver.py b/reasoning_gym/logic/contrib/logic_puzzle/graph/solver.py
new file mode 100644
index 00000000..ed462270
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/graph/solver.py
@@ -0,0 +1,143 @@
+import collections
+
+import numpy as np
+import solver_utils
+from z3 import *
+
+
+class my_solver:
+ def __init__(self, id, feature_name, feature_value, numbered_cnf, num2var):
+ self.id = id
+ self.num2var = num2var
+ self.clue_num = 0
+
+ self.feature_name = feature_name
+ self.feature_value = feature_value
+ self.numbered_cnf = numbered_cnf
+
+ self.init_solver(feature_name, feature_value, numbered_cnf)
+
+ def init_solver(self, feature_name, feature_value, numbered_cnf):
+ set_param(proof=True)
+ self.s = Solver()
+
+ self.s.set(unsat_core=True)
+ house_total = len(feature_value)
+ self.house_total = house_total
+ features, mapping = self.define_feature(feature_name, feature_value)
+
+ self.houses = [
+ [
+ solver_utils.instanciate_int_constrained("%s%d" % (prop, n), self.s, house_total)
+ for prop in features.keys()
+ ]
+ for n in range(house_total)
+ ]
+ self.mapping = mapping
+ self.feature = features
+ # because we read feature from GT, we can make this assumption to calculate difficulty.
+ self.ground_truth = self.feature
+ for single_cnf in numbered_cnf:
+ if len(single_cnf) == 1:
+ self.add_single_clue(single_cnf)
+ else:
+ self.add_mul_val_clue(single_cnf)
+ self.clue_num += 1
+
+ def define_feature(self, feature_name, feature_value):
+ features = collections.OrderedDict()
+ for i in range(len(feature_name)):
+ if feature_name[i] != "House":
+ features[feature_name[i]] = []
+ for house_j in feature_value:
+ features[feature_name[i]].append(house_j[i])
+
+ feature_mapping = {list(features.keys())[i]: i for i in range(len(features.keys()))}
+ return features, feature_mapping
+
+ def decode_var(self, var_num):
+ attribute, house_num = self.num2var[var_num].split(" ")
+ attribute_name, attribute_value = attribute.split(".")
+ if "_" in attribute_value:
+ attribute_value = attribute_value.replace("_", " ")
+ house_num = int(house_num) - 1
+ return house_num, attribute_name, attribute_value
+
+ def normal_vs_not_constraints(self, house_num, attribute_name, attribute_value):
+ if attribute_name.startswith("~"):
+ attribute_name = attribute_name[1:]
+ return Not(
+ self.houses[house_num][self.mapping[attribute_name]]
+ == self.feature[attribute_name].index(attribute_value)
+ )
+ else:
+ return self.houses[house_num][self.mapping[attribute_name]] == self.feature[attribute_name].index(
+ attribute_value
+ )
+
+ def add_single_clue(self, single_cnf):
+ house_num, attribute_name, attribute_value = self.decode_var(single_cnf[0])
+ self.s.assert_and_track(
+ self.normal_vs_not_constraints(house_num, attribute_name, attribute_value), "a" + str(self.clue_num)
+ )
+
+ def add_mul_val_clue(self, single_cnf):
+ cons = []
+ for i in range(len(single_cnf)):
+ house_num, attribute_name, attribute_value = self.decode_var(single_cnf[i])
+ cons.append(self.normal_vs_not_constraints(house_num, attribute_name, attribute_value))
+ self.s.assert_and_track(Or(*cons), "a" + str(self.clue_num))
+
+ def check_solution(self):
+ if self.s.check() == unsat:
+ c = self.s.unsat_core()
+ print("Size of the unsat core:", len(c))
+ print("Unsat core:", ", ".join([str(i) for i in c]))
+
+ proof_str = str(self.s.proof())
+ print(self.s.proof())
+ print("Proof length:", len(proof_str))
+
+ else:
+ m = self.s.model()
+ solution = [[m[case].as_long() for case in line] for line in self.houses]
+ unique = True
+ if count_solutions(self.s) != 1:
+ print(self.id)
+ unique = False
+ # print("Number of solutions:", count_solutions(self.s))
+ return solution, unique
+
+ def print_solution(self, solution):
+ print("Solution:")
+ print(self.feature)
+
+ def check_cell_difficulty(self):
+ results = []
+ clue_num = 0
+ for i in range(self.house_total):
+ for feature in self.ground_truth:
+ self.s.assert_and_track(Not(self.houses[i][self.mapping[feature]] == i), "check" + str(clue_num))
+ if self.s.check() == unsat:
+ results.append("The {} in house {} is {}".format(feature, i + 1, self.ground_truth[feature][i]))
+ self.s.reset()
+ self.init_solver(self.feature_name, self.feature_value, self.numbered_cnf)
+ return results
+
+ def check_statement_difficulty(self):
+ clue_num = 0
+ proof_length = np.zeros((self.house_total, len(self.ground_truth)))
+ for i in range(self.house_total):
+ for feature in self.ground_truth:
+ self.s.assert_and_track(Not(self.houses[i][self.mapping[feature]] == i), "check" + str(clue_num))
+ clue_num += 1
+ if self.s.check() == unsat:
+ c = self.s.unsat_core()
+ proof_length[i, self.mapping[feature]] = self.s.statistics().propagations
+ print(proof_length)
+ return proof_length
+
+ def check_problem_difficulty(self):
+ proof_length = np.zeros((1, 1))
+ proof_length[0, 0] = self.s.statistics().propagations
+ return proof_length
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/graph/solver_utils.py b/reasoning_gym/logic/contrib/logic_puzzle/graph/solver_utils.py
new file mode 100644
index 00000000..021d0763
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/graph/solver_utils.py
@@ -0,0 +1,36 @@
+from z3 import *
+
+
+def column(matrix, i):
+ return [matrix[j][i] for j in range(len(matrix))]
+
+
+def instanciate_int_constrained(name, s, card):
+ x = Int(name)
+ # Each int represent an index in p[name]
+ s.add(x >= 0, x <= card - 1)
+ return x
+
+
+def count_solutions(s, max=1e9):
+ count = 0
+ while s.check() == sat:
+ count += 1
+ if count >= max:
+ return count
+ m = s.model()
+
+ # Create a new constraint the blocks the current model
+ block = []
+ for d in m:
+ # d is a declaration
+ if d.arity() > 0:
+ raise Z3Exception("uninterpreted functions are not supported")
+ # create a constant from declaration
+ c = d()
+ if is_array(c) or c.sort().kind() == Z3_UNINTERPRETED_SORT:
+ raise Z3Exception("arrays and uninterpreted sorts are not supported")
+ block.append(c != m[d])
+
+ s.add(Or(block))
+ return count
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/literals.py b/reasoning_gym/logic/contrib/logic_puzzle/literals.py
new file mode 100644
index 00000000..f1ecfabc
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/literals.py
@@ -0,0 +1,524 @@
+"""
+literals.py
+
+This is a collection of "puzzle elements"---categories, literals, whatever you want to call
+them---which are used as the building blocks of zebra puzzles. Examples include people's
+favorite colors, preferred drinks, pets, etc.
+
+Each class must provide (but we have no way of enforcing this) a description of each
+puzzle element. These get used to make human-readable clues. The classes must also provide
+a custom __repr__ method that gets used in the puzzle description.
+
+Included is a base Literal class from which all literals should inherit. To extend these,
+just import Literal and implement a class like the ones here.
+
+"""
+
+from enum import Enum
+
+
+class Literal(Enum):
+ """
+ Common parent class for all puzzle elements (colors, occupations, pets, etc.).
+
+ We can't make this an ABC because ABC and Enum have different metaclasses, and that'd be
+ super confusing. But override the description method!
+ """
+
+ @classmethod
+ def description(cls) -> str:
+ return "".join(cls.__members__) # type:ignore
+
+
+class Color(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each person has a favorite color:"
+
+ yellow = "the person who loves yellow"
+ red = "the person whose favorite color is red"
+ white = "the person who loves white"
+ green = "the person whose favorite color is green"
+ blue = "the person who loves blue"
+ purple = "the person who loves purple"
+ brown = "the person who loves brown"
+
+
+class Nationality(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"The people are of nationalities:"
+
+ dane = "the Dane"
+ brit = "the British person"
+ swede = "the Swedish person"
+ norwegian = "the Norwegian"
+ german = "the German"
+ chinese = "the Chinese"
+ mexican = "the Mexican"
+
+
+class Animal(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"The people keep different animals:"
+
+ horse = "the person who keeps horses"
+ cat = "the cat lover"
+ bird = "the bird keeper"
+ fish = "the fish enthusiast"
+ dog = "the dog owner"
+ rabbit = "the rabbit owner"
+ pig = "the pig owner"
+
+
+class Drink(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each person has a favorite drink:"
+
+ water = "the one who only drinks water"
+ tea = "the tea drinker"
+ milk = "the person who likes milk"
+ coffee = "the coffee drinker"
+ root_beer = "the root beer lover"
+ boba_tea = "the boba tea drinker"
+ wine = "the wine drinker"
+
+
+class Cigar(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Everyone has a different favorite cigar:"
+
+ pall_mall = "the person partial to Pall Mall"
+ prince = "the Prince smoker"
+ blue_master = "the person who smokes Blue Master"
+ dunhill = "the Dunhill smoker"
+ blends = "the person who smokes many different blends"
+ yellow_monster = "the person who smokes Yellow Monster"
+ red_eye = "the person who smokes Red Eye"
+
+
+class Mother(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"The mothers' names are:"
+
+ aniya = "The person whose mother's name is Aniya"
+ holly = "The person whose mother's name is Holly"
+ janelle = "The person whose mother's name is Janelle"
+ kailyn = "The person whose mother's name is Kailyn"
+ penny = "The person whose mother's name is Penny"
+
+
+class Children(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each mother is accompanied by their child: Bella, Fred, Meredith, Samantha, Timothy."
+
+ bella = "the person's child is named Bella"
+ fred = "the person's child is named Fred"
+ meredith = "the person's child is named Meredith"
+ samantha = "the person's child is named Samantha"
+ timothy = "the person who is the mother of Timothy"
+ alice = "the person's child is named Alice"
+ billy = "the person who is the mother of Billy"
+
+
+class Flower(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"They all have a different favorite flower:"
+
+ carnations = "the person who loves a carnations arrangement"
+ daffodils = "the person who loves a bouquet of daffodils"
+ lilies = "the person who loves the boquet of lilies"
+ roses = "the person who loves the rose bouquet"
+ tulips = "the person who loves the vase of tulips"
+ iris = "the person who loves the boquet of iris"
+ orchid = "the person who loves the boquet of orchid"
+
+
+class Food(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Everyone has something different for lunch:"
+
+ grilled_cheese = "the person who loves eating grilled cheese"
+ pizza = "the person who is a pizza lover"
+ spaghetti = "the person who loves the spaghetti eater"
+ stew = "the person who loves the stew"
+ stir_fry = "the person who loves stir fry"
+ soup = "the person who loves the soup"
+ sushi = "the person who loves the sushi"
+
+
+class Kiro(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each house has a different type of Kiro:"
+
+ kaya = "the Kaya Kiro"
+ sugar_sketch = "the Sugar Sketch Kiro"
+ silosaur = "the Kiro disguised as a Silosaur"
+ skyrant = "the Kiro in a Skyrant costume "
+ traptor_costume = "the Kiro in a Traptor costume"
+ terasaur_costume = "the Terasaur Costume Kiro"
+ skeleko = "the Skeleko Kiro"
+ zodiac_dragon = "the Zodiac Dragon Kiro"
+ gem_dragon = "the Gem Dragon Kiro"
+ plushie = "the Plushie Kiro"
+ gloray = "the Gloray Kiro"
+ rabbit = "the Rabbit Kiro"
+ holiday_sweets = "the Holiday Sweets Kiro"
+ baby = "the Baby Kiro"
+ zaeris = "the Zaeris Kiro"
+
+
+class Smoothie(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Everyone has a favorite smoothie:"
+
+ cherry = "the person who likes Cherry smoothies"
+ desert = "the Desert smoothie lover"
+ watermelon = "the Watermelon smoothie lover"
+ dragonfruit = "the Dragonfruit smoothie lover"
+ lime = "the person who drinks Lime smoothies"
+ blueberry = "the person who drinks Blueberry smoothies"
+ lemon = "the Lemon smoothie lover"
+ dusk = "the person whose favorite smoothie is Dusk"
+ dawn = "the person who likes Dawn smoothies"
+ spring = "the person who likes Spring smoothies"
+ seafoam = "the person who likes Seafoam smoothies"
+ phantom_spring = "the person who likes Phantom Spring smoothies"
+ abyss = "the person whose favorite smoothie is Abyss"
+ butterscotch = "the Butterscotch smoothie drinker"
+ lilac = "the Lilac smoothie drinker"
+ sakura = "the person whose favorite smoothie is Sakura"
+ life = "the Life smoothie drinker"
+ darkness = "the Darkness smoothie drinker"
+ earth = "the person who likes Earth smoothies"
+
+
+class Bottlecap(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Everyone keeps a certain type of Bottlecap:"
+
+ red = "the person who has red bottlecaps"
+ yellow = "the person who likes YBC"
+ green = "the GBC keeper"
+ blue = "the blue bottlecap hoarder"
+ silver = "the SBC winner"
+
+
+class RecolorMedal(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Everyone has a recolor or medal:"
+
+ top_level = "the top level person"
+ second_ed = "the 2nd edition person"
+ ghost = "the ghost recolor"
+ pink = "the pink person"
+ gold = "the person with a heart of gold"
+
+
+class NPC(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each is an NPC on the site:"
+
+ jim = "Dirt Digger Jim"
+ amelia = "Amelia"
+ chip = "Fishin' Chip"
+ riley = "Ringmaster Riley"
+ crowley = "Crowley"
+ silver = "Silver the Kua"
+ jagger = "Jagger"
+
+
+class FavoriteGame(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Everyone has a favorite game:"
+
+ dirt_digger = "the person who likes Dirt Digger"
+ guess_the_number = "the one who plays Guess the Number"
+ fishing_fever = "the Fishing Fever lover"
+ sock_summoning = "the person who plays Sock Summoning"
+ wonder_wheel = "the person who spins the Wonder Wheel"
+ fetch = "the person playing Fetch"
+ quality_assurance = "the person who plays Quality Assurance"
+ stop_and_swap = "the one who often plays Stop and Swap"
+ uchi_fusion = "the one who plays Uchi Fusion"
+ freedom_forest = "the one in Freedom Forest"
+
+
+class Tribe(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Everyone has an Altazan tribe:"
+
+ quake = "the one in the Quake tribe"
+ cursed = "the Cursed tribe member"
+ forest = "the Forest tribe member"
+ volcano = "the person in the Volcano tribe"
+ storm = "the person in the Storm tribe"
+
+
+class Kaya(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"They are five different types of Kaya:"
+
+ joy = "the Kaya of Joy"
+ life = "the Kaya of Life"
+ harmony = "the Kaya of Harmony"
+ wisdom = "the Kaya of Wisdom"
+ love = "the Kaya of Love"
+
+
+class TraptorPrimary(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"They have different primary colors:"
+
+ majestic = "the Majestic Traptor"
+ grand = "the Grand Traptor"
+ stunning = "the Stunning Traptor"
+ marvellous = "the Marvellous Traptor"
+ heroic = "the Heroic Traptor"
+
+
+class TraptorSecondary(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"They have different secondary colors:"
+
+ sky = "the Sky Traptor"
+ forest = "the Forest Traptor"
+ night = "the Night Traptor"
+ sun = "the Sun Traptor"
+ sand = "the Sand Traptor"
+
+
+class TraptorTertiary(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"They have different tertiary colors:"
+
+ soarer = "the Soarer Traptor"
+ diver = "the Diver Traptor"
+ screecher = "the Screecher Traptor"
+ hunter = "the Hunter Traptor"
+ nurturer = "the Nurturer Traptor"
+
+
+class Egg(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"They are each giving out a type of egg:"
+
+ golden = "the one giving out Golden Eggs"
+ trollden = "the one who keeps Trollden Eggs"
+ topaz = "the one with Topaz Eggs"
+ crystal = "the one giving out Crystal Eggs"
+ traptor = "the one who has Traptor Eggs"
+ marinodon = "the one giving out Marinodon Eggs"
+
+
+class Dinomon(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each is a different species of Dinomon:"
+
+ terasaur = "the Terasaur"
+ carnodon = "the Carnodon"
+ silosaur = "the Silosaur"
+ marinodon = "the Marinodon"
+ traptor = "the Traptor"
+
+
+# added by yuchenl
+
+
+class Name(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each person has a unique name:"
+
+ arnold = "Arnold"
+ eric = "Eric"
+ peter = "Peter"
+ alice = "Alice"
+ bob = "Bob"
+ carol = "Carol"
+ david = "David"
+ emily = "Emily"
+
+
+class Age(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each person has a unique age:"
+
+ age_7 = "the person who is 7 years old"
+ age_8 = "the person who is 8"
+ age_9 = "the person who is 9"
+ age_10 = "the person who is 10"
+ age_11 = "the person who is 11"
+ age_12 = "the person who is 12"
+ age_13 = "the person who is 13"
+ age_14 = "the person who is 14"
+
+
+class Birthday(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return f"Each person has a unique birthday month:"
+
+ april = "the person whose birthday is in April"
+ sept = "the person whose birthday is in September"
+ jan = "the person whose birthday is in January"
+ feb = "the person whose birthday is in February"
+ mar = "the person whose birthday is in March"
+ may = "the person whose birthday is in May"
+ june = "the person whose birthday is in June"
+ july = "the person whose birthday is in July"
+
+
+# ChatGPT generated
+
+"""
+generate more classes like this according to the common properties that you would describe a person or a house, such as the occupation, age, height, hair color, and so on
+
+"""
+
+
+class Occupation(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return "Each person has an occupation: doctor, engineer, teacher, artist, lawyer, nurse, or accountant."
+
+ doctor = "the person who is a doctor"
+ engineer = "the person who is an engineer"
+ teacher = "the person who is a teacher"
+ artist = "the person who is an artist"
+ lawyer = "the person who is a lawyer"
+ nurse = "the person who is a nurse"
+ accountant = "the person who is an accountant"
+
+
+# class AgeGroup(Literal):
+# @classmethod
+# def description(cls) -> str:
+# return "People are grouped by age: infant, child, teenager, young adult, adult, or senior."
+
+# infant = "the person who is an infant"
+# child = "the person who is a child"
+# teenager = "the person who is a teenager"
+# young_adult = "the person who is a young adult"
+# adult = "the person who is an adult"
+# senior = "the person who is a senior"
+
+
+class Height(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return "People have different heights: very short, short, average, tall, very tall, or super tall."
+
+ very_short = "the person who is very short"
+ short = "the person who is short"
+ average = "the person who has an average height"
+ tall = "the person who is tall"
+ very_tall = "the person who is very tall"
+ super_tall = "the person who is super tall"
+
+
+class HairColor(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return "People have different hair colors: black, brown, blonde, red, gray, auburn, or white."
+
+ black = "the person who has black hair"
+ brown = "the person who has brown hair"
+ blonde = "the person who has blonde hair"
+ red = "the person who has red hair"
+ gray = "the person who has gray hair"
+ auburn = "the person who has auburn hair"
+ white = "the person who has white hair"
+
+
+class CarModel(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return "People own different car models: Tesla Model 3, Ford F-150, Toyota Camry, Honda Civic, BMW 3 Series, Chevrolet Silverado, or Audi A4."
+
+ tesla_model_3 = "the person who owns a Tesla Model 3"
+ ford_f150 = "the person who owns a Ford F-150"
+ toyota_camry = "the person who owns a Toyota Camry"
+ honda_civic = "the person who owns a Honda Civic"
+ bmw_3_series = "the person who owns a BMW 3 Series"
+ chevrolet_silverado = "the person who owns a Chevrolet Silverado"
+ audi_a4 = "the person who owns an Audi A4"
+
+
+class PhoneModel(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return "People use different phone models: iPhone 13, Samsung Galaxy S21, Google Pixel 6, OnePlus 9, Huawei P50, Xiaomi Mi 11, or Sony Xperia 5."
+
+ iphone_13 = "the person who uses an iPhone 13"
+ samsung_galaxy_s21 = "the person who uses a Samsung Galaxy S21"
+ google_pixel_6 = "the person who uses a Google Pixel 6"
+ oneplus_9 = "the person who uses a OnePlus 9"
+ huawei_p50 = "the person who uses a Huawei P50"
+ xiaomi_mi_11 = "the person who uses a Xiaomi Mi 11"
+ sony_xperia_5 = "the person who uses a Sony Xperia 5"
+
+
+class FavoriteSport(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return "People have different favorite sports: soccer, basketball, tennis, swimming, baseball, volleyball, or golf."
+
+ soccer = "the person who loves soccer"
+ basketball = "the person who loves basketball"
+ tennis = "the person who loves tennis"
+ swimming = "the person who loves swimming"
+ baseball = "the person who loves baseball"
+ volleyball = "the person who loves volleyball"
+ golf = "the person who loves golf"
+
+
+class MusicGenre(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return (
+ "People have different favorite music genres: rock, pop, classical, jazz, hip-hop, country, or electronic."
+ )
+
+ rock = "the person who loves rock music"
+ pop = "the person who loves pop music"
+ classical = "the person who loves classical music"
+ jazz = "the person who loves jazz music"
+ hip_hop = "the person who loves hip-hop music"
+ country = "the person who loves country music"
+ electronic = "the person who loves electronic music"
+
+
+class BookGenre(Literal):
+ @classmethod
+ def description(cls) -> str:
+ return "People have different favorite book genres: mystery, science fiction, romance, fantasy, biography, historical fiction, or non-fiction."
+
+ mystery = "the person who loves mystery books"
+ science_fiction = "the person who loves science fiction books"
+ romance = "the person who loves romance books"
+ fantasy = "the person who loves fantasy books"
+ biography = "the person who loves biography books"
+ historical_fiction = "the person who loves historical fiction books"
+ non_fiction = "the person who loves non-fiction books"
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/puzzle.py b/reasoning_gym/logic/contrib/logic_puzzle/puzzle.py
new file mode 100644
index 00000000..a72c2bf7
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/puzzle.py
@@ -0,0 +1,286 @@
+"""solver.py
+
+Solve the Einstein puzzle using Raymond Hettinger's approach.
+"""
+
+from __future__ import annotations
+
+from contextlib import contextmanager
+from random import shuffle
+from typing import Dict, Generator, Iterable, List, Set, Tuple, Type
+
+from reasoning_gym.logic.contrib.logic_puzzle.clues import (
+ Clue,
+ beside,
+ comb,
+ consecutive,
+ found_at,
+ left_of,
+ one_between,
+ right_of,
+ same_house,
+ two_between,
+)
+from reasoning_gym.logic.contrib.logic_puzzle.literals import (
+ Animal,
+ Children,
+ Cigar,
+ Color,
+ Drink,
+ Flower,
+ Food,
+ Literal,
+ Mother,
+ Nationality,
+)
+from reasoning_gym.logic.contrib.logic_puzzle.sat_utils import one_of, solve_all
+
+
+class Puzzle:
+ """
+ A Puzzle is defined as a collection of constraints and clues.
+
+ Clues are subclassess of Clue. They represent information about the puzzle that can be used by
+ a human to solve it, like "the man who drinks tea owns a cat." Clues aren't converted to CNF
+ until the `as_cnf` method is called.
+
+ Constraints are structural properties of the puzzle, given to us in CNF to start. They're
+ things like "each house gets exactly one type of flower" and "each flower must be assigned
+ to one house." These will be the same for every Puzzle, so we have a default `set_constraints`
+ method that takes care of them.
+
+ We can add clues with `add_clue`. This returns the instance, so they can be chained together.
+
+ Since in constraint satisfaction land, clues and constraints are the same thing (they're just
+ logical clauses), we lump them all together at solve time.
+ """
+
+ def __init__(
+ self,
+ *,
+ element_types: Iterable[Type[Literal]],
+ elements: Iterable[Literal] = None,
+ n_houses: int = 5,
+ ) -> None:
+ """
+ Initialize a puzzle with different kinds of elements. The `element_types` is a list of the
+ *kinds* of literals we're using, i.e., Smoothie, FavoriteFood, FavoriteColor, etc. The
+ `elements` is a list of the literals themselves, since some of the literal types have more
+ than `n_houses` elements.
+
+ If `elements` is not provided, we assume that every member of each of `element_types` is
+ part of the puzzle. This is the case in example puzzles, but rarely the case for generated
+ ones.
+ """
+
+ self.element_classes = list(element_types)
+ if elements is None:
+ self.literals = [el for el_class in self.element_classes for el in el_class]
+ else:
+ self.literals = list(elements)
+
+ self.houses = tuple(range(1, n_houses + 1))
+ self.clues: Set[Clue] = set()
+ self.constraints: List[Tuple[str]] = []
+ self.extra_clues: Set[Clue] = set()
+ self.solution = None
+
+ def _add_constraint(self, constraints: List[Tuple[str]]) -> Puzzle:
+ self.constraints.extend(constraints)
+ return self
+
+ def set_constraints(self) -> Puzzle:
+ # each house gets exactly one value from each set of literals
+ for house in self.houses:
+ for element_type in self.element_classes:
+ literals_of_that_type = [l for l in self.literals if isinstance(l, element_type)]
+ self._add_constraint(one_of(comb(value, house) for value in literals_of_that_type))
+
+ # each value gets assigned to exactly one house
+ for literal in self.literals:
+ self._add_constraint(one_of(comb(literal, house) for house in self.houses))
+
+ return self
+
+ def add_clue(self, clue: Clue) -> Puzzle:
+ self.clues.add(clue)
+ return self
+
+ def remove_clue(self, clue: Clue) -> Puzzle:
+ self.clues.remove(clue)
+ return self
+
+ @contextmanager
+ def with_clues(self, clues: Iterable[Clue], remove_after=True) -> Generator[Puzzle]:
+ """Create a context in which this Puzzle temporarily has clues added to it"""
+
+ clues = list(clues) # so we don't accidentally exhaust the iterable
+ empty_clue = len(self.clues) == 0
+ for clue in clues:
+ self.add_clue(clue)
+
+ yield self
+ if empty_clue:
+ for clue in clues:
+ self.remove_clue(clue)
+
+ return self
+
+ def as_cnf(self) -> List[Tuple[str]]:
+ """Express puzzle as solvable CNF"""
+
+ # this would be a comprehension if we could use iterable unpacking
+ cnf = []
+ for clue in self.clues:
+ cnf.extend(clue.as_cnf())
+
+ cnf.extend(self.constraints)
+ return cnf
+
+ def __repr__(self) -> str:
+ s = f"This is a logic puzzle. "
+ s += f"There are {len(self.houses)} houses (numbered {self.houses[0]} on the left, "
+ s += f"{self.houses[-1]} on the right), from the perspective of someone standing across "
+ s += f"the street from them. Each has a different person in them. "
+ s += f"They have different characteristics:\n"
+ for element_type in self.element_classes:
+ literals = [l for l in self.literals if isinstance(l, element_type)]
+ shuffle(literals)
+ desc = element_type.description()
+ idx = desc.index(":") if ":" in desc else None
+ desc = desc[:idx]
+ s += f" - {desc}: " + ", ".join(e.name.replace("_", " ") for e in literals) + "\n"
+
+ s += "\n"
+ for i, clue in enumerate(self.clues):
+ s += f"{i + 1}. {clue}\n"
+
+ return s
+
+
+"""
+Original version
+
+Taken straight from rhettinger.github.io and the associated talk.
+
+Entities:
+ * There are five houses in unique colors: Blue, green, red, white and yellow.
+ * In each house lives a person of unique nationality: British, Danish, German, Norwegian and Swedish.
+ * Each person drinks a unique beverage: Beer, coffee, milk, tea and water.
+ * Each person smokes a unique cigar brand: Blue Master, Dunhill, Pall Mall, Prince and blend.
+ * Each person keeps a unique pet: Cats, birds, dogs, fish and horses.
+
+Constraints:
+ * The Brit lives in a red house.
+ * The Swede keeps dogs as pets.
+ * The Dane drinks tea.
+ * The green house is on the left of the white, next to it.
+ * The green house owner drinks coffee.
+ * The person who smokes Pall Mall rears birds.
+ * The owner of the yellow house smokes Dunhill.
+ * The man living in the house right in the center drinks milk.
+ * The Norwegian lives in the first house.
+ * The man who smokes blend lives next to the one who keeps cats.
+ * The man who keeps horses lives next to the man who smokes Dunhill.
+ * The owner who smokes Blue Master drinks beer.
+ * The German smokes Prince.
+ * The Norwegian lives next to the blue house.
+ * The man who smokes blend has a neighbor who drinks water.
+
+For each house, find out what color it is, who lives there, what they drinkk, what
+they smoke, and what pet they own.
+"""
+
+if __name__ == "__main__":
+ enum_classes: List[Type[Literal]] = [Color, Nationality, Animal, Drink, Cigar]
+ literals: List[Literal] = [el for group in enum_classes for el in group]
+
+ # set up the puzzle with constraints and clues
+ puzzle = Puzzle(element_types=[Color, Nationality, Drink, Cigar, Animal])
+
+ puzzle = (
+ puzzle.set_constraints()
+ .add_clue(same_house(Nationality.brit, Color.red))
+ .add_clue(same_house(Nationality.swede, Animal.dog))
+ .add_clue(same_house(Nationality.dane, Drink.tea))
+ .add_clue(consecutive(Color.green, Color.white))
+ .add_clue(same_house(Color.green, Drink.coffee))
+ .add_clue(same_house(Cigar.pall_mall, Animal.bird))
+ .add_clue(same_house(Color.yellow, Cigar.dunhill))
+ .add_clue(found_at(Drink.milk, 3))
+ .add_clue(found_at(Nationality.norwegian, 1))
+ .add_clue(beside(Cigar.blends, Animal.cat))
+ .add_clue(beside(Animal.horse, Cigar.dunhill))
+ .add_clue(same_house(Cigar.blue_master, Drink.root_beer))
+ .add_clue(same_house(Nationality.german, Cigar.prince))
+ .add_clue(beside(Nationality.norwegian, Color.blue))
+ .add_clue(beside(Cigar.blends, Drink.water))
+ )
+
+ print(puzzle)
+
+ sols = solve_all(puzzle.as_cnf())
+ print(f"{len(sols)} solutions found")
+ for sol in sols:
+ print(sol)
+
+"""
+Quag's version
+
+In honor of Mother's Day, a feast is being held to celebrate five Moms: Aniya, Holly,
+Janelle, Kailyn, and Penny. Each Mom will be served by their son or daughter (Bella,
+Fred, Meredith, Samantha, and Timothy), who will also place a bouquet of flowers
+(Carnations, Daffodils, Lilies, Roses, or Tulips) at their Mom's place setting and
+prepare a meal for them (Grilled Cheese, Pizza, Spaghetti, Stew, or Stir Fry).
+
+The seats are arranged in a straight line at the head table, with the first being
+the furthest to the left (from our perspective, not the Mom's perspectives).
+
+Also, when it says there is "one chair" between two people, it means one person might
+be in the second chair while the other person is in the fourth (i.e. there is one chair
+in between them that neither is sitting in).
+"""
+
+if __name__ == "__main__":
+ enum_classes: List[Type[Literal]] = [Mother, Children, Flower, Food]
+ literals = [el for group in enum_classes for el in group]
+
+ # set up the puzzle with constraints and clues
+ puzzle = Puzzle(element_types=[Mother, Children, Flower, Food])
+
+ puzzle = (
+ puzzle.set_constraints()
+ # 1. There is one chair between the place setting with Lilies and the one eating Grilled Cheese.
+ .add_clue(one_between(Flower.lilies, Food.grilled_cheese))
+ # 2. There is one chair between Timothy's Mom and the one eating Stew.
+ .add_clue(one_between(Children.timothy, Food.stew))
+ # 3. There are two chairs between the Bella's Mom and Penny's seat on the right.
+ .add_clue(two_between(Children.bella, Mother.penny))
+ .add_clue(right_of(Mother.penny, Children.bella))
+ # 4. There is one chair between the place setting with Roses and the one eating Spaghetti on the left.
+ .add_clue(one_between(Flower.roses, Food.spaghetti))
+ .add_clue(left_of(Food.spaghetti, Flower.roses))
+ # 5. There are two chairs between the place setting with Carnations and Samantha's Mom.
+ .add_clue(two_between(Flower.carnations, Children.samantha))
+ # 6. There is one chair between Meredith's Mom and Timothy's Mom on the left.
+ .add_clue(one_between(Children.meredith, Children.timothy))
+ .add_clue(left_of(Children.timothy, Children.meredith))
+ # 7. Aniya's place setting has a lovely Carnation bouquet.
+ .add_clue(same_house(Mother.aniya, Flower.carnations))
+ # 8. There are two chairs between the one eating Grilled Cheese and the one eating Spaghetti.
+ .add_clue(two_between(Food.grilled_cheese, Food.spaghetti))
+ # 9. The person in the first chair (left-most) is eating Pizza.
+ .add_clue(found_at(Food.pizza, 1))
+ # 10. The Tulips were placed at one of the place settings somewhere to the left of Penny's chair.
+ .add_clue(left_of(Flower.tulips, Mother.penny))
+ # 11. There are two chairs between the one eating Spaghetti and Kailyn's seat.
+ .add_clue(two_between(Food.spaghetti, Mother.kailyn))
+ # 12. There is one chair between the one eating Pizza and Holly's chair on the right.
+ .add_clue(one_between(Food.pizza, Mother.holly))
+ .add_clue(right_of(Mother.holly, Food.pizza))
+ )
+
+ print(puzzle)
+ all_solutions = solve_all(puzzle.as_cnf())
+ print(f"{len(all_solutions)} solutions found")
+ print(all_solutions)
diff --git a/reasoning_gym/logic/contrib/logic_puzzle/sat_utils.py b/reasoning_gym/logic/contrib/logic_puzzle/sat_utils.py
new file mode 100644
index 00000000..df5e26e9
--- /dev/null
+++ b/reasoning_gym/logic/contrib/logic_puzzle/sat_utils.py
@@ -0,0 +1,170 @@
+"""Utility functions to humanize interaction with pycosat"""
+
+__author__ = "Raymond Hettinger"
+
+from functools import lru_cache
+from itertools import combinations
+from typing import Dict, FrozenSet, Iterable, List, Set, Tuple
+
+import pycosat
+
+Element = str # literal; any string, but here it's e.g., "tushar 5" or "chai 2"
+CNF = List[Tuple[Element, ...]]
+
+
+def make_translate(cnf: CNF) -> Tuple[Dict[Element, int], Dict[int, Element]]:
+ """Make a translator from symbolic CNF to pycosat's numbered clauses.
+
+ Return literal to number dictionary and reverse lookup dict.
+
+ >>> make_translate([['~a', 'b', '~c'], ['a', '~c']])
+ ({'a': 1, 'c': 3, 'b': 2, '~a': -1, '~b': -2, '~c': -3},
+ {1: 'a', 2: 'b', 3: 'c', -1: '~a', -3: '~c', -2: '~b'})
+ """
+
+ lit2num: Dict[Element, int] = {}
+ for clause in cnf:
+ for literal in clause:
+ if literal not in lit2num:
+ var = literal[1:] if literal[0] == "~" else literal
+ num = len(lit2num) // 2 + 1
+ lit2num[var] = num
+ lit2num["~" + var] = -num
+
+ num2var = {num: lit for lit, num in lit2num.items()}
+
+ return lit2num, num2var
+
+
+def translate(cnf: CNF, uniquify=False) -> Tuple[List[Tuple[int, ...]], Dict[int, Element]]:
+ """Translate a symbolic CNF to a numbered CNF and return reverse mapping.
+
+ >>> translate([['~P', 'Q'],['~P', 'R']])
+ [(-1, 2), (-1, 3)], {-3: '~R', -2: '~Q', -1: '~P', 1: 'P', 2: 'Q', 3: 'R'}
+ """
+
+ # DIMACS CNF file format:
+ # http://people.sc.fsu.edu/~jburkardt/data/cnf/cnf.html
+ if uniquify:
+ cnf = list(dict.fromkeys(cnf))
+
+ lit2num, num2var = make_translate(cnf)
+ numbered_cnf = [tuple([lit2num[lit] for lit in clause]) for clause in cnf]
+
+ return numbered_cnf, num2var
+
+
+def itersolve(symbolic_cnf: CNF, include_neg: bool = False):
+ numbered_cnf, num2var = translate(symbolic_cnf)
+ for solution in pycosat.itersolve(numbered_cnf):
+ yield [num2var[n] for n in solution if include_neg or n > 0]
+
+
+def solve_all(symbolic_cnf: CNF, include_neg: bool = False):
+ return list(itersolve(symbolic_cnf, include_neg))
+
+
+def solve_one(symbolic_cnf: CNF, include_neg: bool = False):
+ return next(itersolve(symbolic_cnf, include_neg))
+
+
+############### Support for Building CNFs ##########################
+
+
+@lru_cache(maxsize=None)
+def neg(element: str) -> str:
+ """Negate a single element"""
+
+ return element[1:] if element.startswith("~") else "~" + element
+
+
+def from_dnf(groups: Iterable[Tuple[str, ...]]) -> CNF:
+ """Convert from or-of-ands to and-of-ors
+
+ >>> from_dnf([['~P'], ['Q', 'R']])
+ [('~P', 'Q'), ('~P', 'R')]
+ """
+
+ cnf: Set[FrozenSet[str]] = {frozenset()}
+ for group in groups:
+ nl = {frozenset([literal]): neg(literal) for literal in group}
+ # The "clause | literal" prevents dup lits: {x, x, y} -> {x, y}
+ # The nl check skips over identities: {x, ~x, y} -> True
+ cnf = {clause | literal for literal in nl for clause in cnf if nl[literal] not in clause}
+ # The sc check removes clauses with superfluous terms:
+ # {{x}, {x, z}, {y, z}} -> {{x}, {y, z}}
+ # Should this be left until the end?
+ sc = min(cnf, key=len) # XXX not deterministic
+ cnf -= {clause for clause in cnf if clause > sc}
+
+ return [tuple(clause) for clause in cnf]
+
+
+class Q:
+ """Quantifier for the number of elements that are true
+
+ >>> Q(['A', 'B', 'C']) <= 1
+ [('~A', '~B'),
+ ('~A', '~C'),
+ ('~B', '~C')]
+ """
+
+ def __init__(self, elements: Iterable[Element]):
+ self.elements = tuple(elements)
+
+ def __lt__(self, n: int) -> CNF:
+ return list(combinations(map(neg, self.elements), n))
+
+ def __le__(self, n: int) -> CNF:
+ return self < n + 1
+
+ def __gt__(self, n: int) -> CNF:
+ return list(combinations(self.elements, len(self.elements) - n))
+
+ def __ge__(self, n: int) -> CNF:
+ return self > n - 1
+
+ def __eq__(self, n: int) -> CNF: # type:ignore
+ return (self <= n) + (self >= n)
+
+ def __ne__(self, n) -> CNF: # type:ignore
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(elements={self.elements!r})"
+
+
+def all_of(elements: List[Element]) -> CNF:
+ """Forces inclusion of matching rows on a truth table"""
+ return Q(elements) == len(elements)
+
+
+def some_of(elements: Iterable[Element]) -> CNF:
+ """At least one of the elements must be true
+
+ >>> some_of(['A', 'B', 'C'])
+ [['A', 'B', 'C']]
+ """
+ return Q(elements) >= 1
+
+
+def one_of(elements: Iterable[Element]) -> CNF:
+ """Exactly one of the elements is true
+
+ >>> one_of(['A', 'B', 'C'])
+ [('A', 'B', 'C'),
+ ('~A', '~B'),
+ ('~A', '~C'),
+ ('~B', '~C')]
+ """
+ return Q(elements) == 1
+
+
+def basic_fact(element: Element) -> CNF:
+ """Assert that this one element always matches"""
+ return Q([element]) == 1
+
+
+def none_of(elements: Iterable[Element]) -> CNF:
+ """Forces exclusion of matching rows on a truth table"""
+ return Q(elements) == 0
diff --git a/reasoning_gym/logic/zebra_puzzles.py b/reasoning_gym/logic/zebra_puzzles.py
new file mode 100644
index 00000000..3ba177a7
--- /dev/null
+++ b/reasoning_gym/logic/zebra_puzzles.py
@@ -0,0 +1,78 @@
+from dataclasses import dataclass
+from random import Random, seed
+from typing import Dict, List, Optional, Tuple
+
+from ..factory import ProceduralDataset, register_dataset
+from .contrib.logic_puzzle.generate import generate_puzzle
+
+
+@dataclass
+class ZebraConfig:
+ """Configuration for zebra puzzle generation"""
+
+ num_people: int = 4
+ num_characteristics: int = 4
+ seed: Optional[int] = None
+ size: int = 500
+
+ def validate(self):
+ """Validate configuration parameters"""
+ assert 2 <= self.num_people <= 7, "num_people must be between 2 and 7"
+ assert 2 <= self.num_characteristics <= 7, "num_characteristics must be between 2 and 7"
+
+
+class ZebraDataset(ProceduralDataset):
+ """Generates [Zebra Puzzles](https://en.wikipedia.org/wiki/Zebra_Puzzle) with configurable parameters"""
+
+ def __init__(self, config: ZebraConfig):
+ super().__init__(config=config, seed=config.seed, size=config.size)
+
+ def __getitem__(self, idx: int) -> dict:
+ """Generate a single Zebra task
+
+ Returns:
+ dict with keys:
+ - question: str, the task description
+ - answer: str, a solution string
+ - metadata: dict with generation parameters
+ """
+ seed(self.seed + idx)
+
+ K = self.config.num_people
+ M = self.config.num_characteristics
+ instance, puzzle = generate_puzzle(K, M, "train")
+ q = instance["questions"][0]["question"]
+ answer = instance["questions"][0]["answer"]
+ question = str(puzzle) + "\n" + q
+
+ return {
+ "question": question,
+ "answer": answer,
+ "metadata": {
+ "num_people": K,
+ "num_characteristics": M,
+ },
+ }
+
+ def score_answer(self, answer: Optional[str], entry: Dict[str, any]) -> float:
+ """Determine if the solution provided solves the Zebra task.
+
+ The function awards 1.0 for a correct answer.
+
+ Args:
+ answer (Optional[str]): The user's answer.
+ entry (Dict[str, any]): The original dataset entry containing the correct answer.
+
+ Returns:
+ float: The computed score between 0.0 and 1.0.
+ """
+
+ if answer == None:
+ return 0.0
+ if answer.lower().replace("\n", "") != entry["answer"].lower().replace("\n", ""):
+ return 0.01
+ else:
+ return 1.0 # Yay
+
+
+register_dataset("zebra_puzzles", ZebraDataset, ZebraConfig)
diff --git a/tests/test_zebra.py b/tests/test_zebra.py
new file mode 100644
index 00000000..d233c438
--- /dev/null
+++ b/tests/test_zebra.py
@@ -0,0 +1,20 @@
+import pytest
+
+from reasoning_gym.logic.zebra_puzzles import ZebraConfig, ZebraDataset
+
+
+def test_zebra_puzzles():
+ """Test basic properties and solution of generated items"""
+
+ config = ZebraConfig(seed=42, size=10, num_people=4, num_characteristics=4)
+ dataset = ZebraDataset(config)
+
+ for item in dataset:
+ assert isinstance(item, dict)
+ assert "question" in item
+ assert "answer" in item
+ assert "metadata" in item
+
+ # Test the scoring
+ assert dataset.score_answer(answer=item["answer"], entry=item) == 1.0
+ assert dataset.score_answer(answer=None, entry=item) == 0.0