Skip to content

Commit

Permalink
fixed problems in knights_knaves (#251)
Browse files (browse the repository at this point in the history)
* remove unnecessary variables

* added depth logic

* add depth tests
  • Loading branch information
vncntt authored Mar 2, 2025
1 parent 24828e1 commit 3149edf
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 31 deletions.
102 changes: 71 additions & 31 deletions reasoning_gym/logic/knights_knaves.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _sample_statement(self, person_id: int, depth_constraint: int):
while True:
knight_or_knave = self.rng.choice(["telling-truth", "lying"])
person = self.rng.integers(0, self.n_people)
# prevent the contradiction "I am lying"
if not (knight_or_knave == "lying" and person == person_id):
return (knight_or_knave, person)
if dice == 1:
Expand Down Expand Up @@ -215,21 +216,11 @@ def __init__(self, rand_seed, problem):
self.rng = np.random.default_rng(rand_seed)
self.problem = problem

def format_problem(
self,
random_names=True,
random_saying_template=True,
random_knight_knave_pairs=True,
flip_knight_knave_pair=False,
uncommon_name=False,
reorder_statement=False,
):
def format_problem(self):
statements = copy.deepcopy(self.problem["statements"])
n_people = len(statements)
names = list(self.rng.choice(COMMON_NAMES, size=n_people, replace=False))
knight_knave = ["a knight", "a knave"]
if random_knight_knave_pairs:
knight_knave = self.rng.choice(KNIGHT_KNAVE_PAIRS)
knight_knave = self.rng.choice(KNIGHT_KNAVE_PAIRS)
knight_knave = {
"knight": knight_knave[0].split()[1],
"knave": knight_knave[1].split()[1],
Expand Down Expand Up @@ -277,27 +268,78 @@ def format_problem(
"solution_text": solution_text,
}

def _format_statement(self, names, knight_knave, statement, depth=0):
    """
    Recursively format a logical statement, parenthesising nested sub-statements.

    Args:
        names: List of people's names
        knight_knave: Dictionary with knight/knave terminology
        statement: Logical statement tuple to format; ``statement[0]`` is the
            operator tag ("telling-truth", "lying", "not", "and", "or",
            "->" or "<=>") and the remaining items are its operands
        depth: Current nesting depth (0 = top level). It is only forwarded,
            incremented by one, to recursive calls.

    Returns:
        The formatted statement as a string.

    Raises:
        ValueError: If ``statement[0]`` is not one of the known operator tags.
    """
    primitives = ("telling-truth", "lying")
    op = statement[0]

    def render(sub_stmt):
        # Format an operand one level deeper; anything that is not a
        # primitive statement is wrapped in parentheses to keep it unambiguous.
        text = self._format_statement(names, knight_knave, sub_stmt, depth + 1)
        return text if sub_stmt[0] in primitives else f"({text})"

    # Base case: this is a primitive statement
    if op in primitives:
        return self._format_knight_knave(names, knight_knave, statement)

    # Handle negation
    if op == "not":
        inner = statement[1]
        if inner[0] in primitives:
            # Negating a primitive: use the complementary term directly
            # ("telling-truth" <-> "lying") about the same person.
            complementary_statement = (
                "lying" if inner[0] == "telling-truth" else "telling-truth",
                inner[1],
            )
            return self._format_knight_knave(names, knight_knave, complementary_statement)
        # For complex statements, use the verbose "it is not the case that" format
        return f"it is not the case that {render(inner)}"

    # Handle AND/OR (any number of operands)
    if op in ("and", "or"):
        return f" {op} ".join(render(sub_stmt) for sub_stmt in statement[1:])

    # Handle implication
    if op == "->":
        return f"if {render(statement[1])} then {render(statement[2])}"

    # Handle biconditional
    if op == "<=>":
        return f"{render(statement[1])} if and only if {render(statement[2])}"

    # This should not happen with well-formed statements
    raise ValueError(f"Unknown statement type: {op}")

def _format_knight_knave(self, names, knight_knave, statement, negation=False):
assert statement[0] in ("telling-truth", "lying")
Expand Down Expand Up @@ -404,8 +446,6 @@ def __generate_problem(self, rng: Random) -> dict[str, Any]:
problems = sampler.sample_valid_problems(1, skip_no_solution=True, skip_multiple_solutions=True)
problem = problems[0]

# Format the problem using the original KKProblemFormatter logic

# Format the problem
formatter = KKProblemFormatter(rand_seed=rng.randint(0, 2**32), problem=problem)
formatted = formatter.format_problem()
Expand Down
50 changes: 50 additions & 0 deletions tests/test_knights_knaves.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_items():
assert "question" in item
assert "answer" in item
assert "metadata" in item
assert "solution" in item["metadata"]


def test_solution():
Expand Down Expand Up @@ -185,3 +186,52 @@ def test_satisfiability():
assert not KnightsKnavesDataset.test_satisfiability(
("<=>", ("telling-truth", 0), ("telling-truth", 1)), (True, False)
)


def test_depth_constraint():
    """Generate a small dataset with depth_constraint=4 and sanity-check every item."""
    config = KnightsKnavesConfig(
        n_people=2,
        depth_constraint=4,
        width_constraint=2,
        size=5,
        seed=42,
    )
    dataset = KnightsKnavesDataset(config)
    assert len(dataset) == 5
    for i in range(len(dataset)):
        # Index the dataset once per item instead of once per field.
        metadata = dataset[i]["metadata"]
        # The solution must contain exactly one entry per named person.
        assert len(metadata["solution"]) == len(metadata["names"])


def test_depth_constraint_specific_problem():
    """Solve a fixed, deeply nested three-statement puzzle and pin its single solution."""
    # First statement: a three-way disjunction of nested sub-statements.
    first_claim = (
        "or",
        ("not", ("and", ("telling-truth", 0), ("telling-truth", 1), ("lying", 1))),
        (
            "and",
            ("not", ("telling-truth", 0)),
            ("->", ("telling-truth", 0), ("lying", 1)),
            ("<=>", ("telling-truth", 1), ("lying", 2)),
        ),
        (
            "and",
            ("or", ("lying", 2), ("lying", 1), ("telling-truth", 0)),
            ("or", ("telling-truth", 0), ("lying", 2)),
            ("or", ("telling-truth", 2), ("telling-truth", 0)),
        ),
    )
    # Second statement: negation of a disjunction.
    second_claim = (
        "not",
        (
            "or",
            ("and", ("telling-truth", 1), ("telling-truth", 2), ("telling-truth", 0)),
            ("or", ("telling-truth", 2), ("telling-truth", 1), ("telling-truth", 0)),
        ),
    )
    # Third statement: negation of a primitive.
    third_claim = ("not", ("telling-truth", 0))

    puzzle = (first_claim, second_claim, third_claim)
    found = KnightsKnavesDataset.find_solution(puzzle)

    assert len(found) == 1, "Should have exactly one solution"
    assert found[0] == (True, False, False)

0 comments on commit 3149edf

Please sign in to comment.