diff --git a/reasoning_gym/logic/knights_knaves.py b/reasoning_gym/logic/knights_knaves.py index fa3d4279..7cc16ecf 100644 --- a/reasoning_gym/logic/knights_knaves.py +++ b/reasoning_gym/logic/knights_knaves.py @@ -171,6 +171,7 @@ def _sample_statement(self, person_id: int, depth_constraint: int): while True: knight_or_knave = self.rng.choice(["telling-truth", "lying"]) person = self.rng.integers(0, self.n_people) + # prevent the contradiction "I am lying" if not (knight_or_knave == "lying" and person == person_id): return (knight_or_knave, person) if dice == 1: @@ -215,21 +216,11 @@ def __init__(self, rand_seed, problem): self.rng = np.random.default_rng(rand_seed) self.problem = problem - def format_problem( - self, - random_names=True, - random_saying_template=True, - random_knight_knave_pairs=True, - flip_knight_knave_pair=False, - uncommon_name=False, - reorder_statement=False, - ): + def format_problem(self): statements = copy.deepcopy(self.problem["statements"]) n_people = len(statements) names = list(self.rng.choice(COMMON_NAMES, size=n_people, replace=False)) - knight_knave = ["a knight", "a knave"] - if random_knight_knave_pairs: - knight_knave = self.rng.choice(KNIGHT_KNAVE_PAIRS) + knight_knave = self.rng.choice(KNIGHT_KNAVE_PAIRS) knight_knave = { "knight": knight_knave[0].split()[1], "knave": knight_knave[1].split()[1], @@ -277,27 +268,78 @@ def format_problem( "solution_text": solution_text, } - def _format_statement(self, names, knight_knave, statement): + def _format_statement(self, names, knight_knave, statement, depth=0): + """ + Recursively format a logical statement with appropriate parentheses based on depth. + + Args: + names: List of people's names + knight_knave: Dictionary with knight/knave terminology + statement: Logical statement tuple to format + depth: Current nesting depth (0 = top level) + """ + # Base case: this is a primitive statement + if statement[0] in ("telling-truth", "lying"): + return self._format_knight_knave(names, knight_knave, statement) + + # Handle negation if statement[0] == "not": - return self._format_knight_knave(names, knight_knave, statement[1], negation=True) + + # Special case: If negating a primitive statement, use the complementary term directly + if statement[1][0] in ("telling-truth", "lying"): + # Map "telling-truth" to "lying" and vice versa + complementary_statement = ( + "lying" if statement[1][0] == "telling-truth" else "telling-truth", + statement[1][1], + ) + return self._format_knight_knave(names, knight_knave, complementary_statement) + else: + # For complex statements, use the verbose "it is not the case that" format + inner_content = self._format_statement(names, knight_knave, statement[1], depth + 1) + if statement[1][0] not in ("telling-truth", "lying"): + inner_content = f"({inner_content})" + return f"it is not the case that {inner_content}" + + # Handle AND/OR if statement[0] in ["and", "or"]: - return (" " + statement[0] + " ").join( - self._format_knight_knave(names, knight_knave, sub_stmt) for sub_stmt in statement[1:] - ) + formatted_substmts = [] + for sub_stmt in statement[1:]: + sub_content = self._format_statement(names, knight_knave, sub_stmt, depth + 1) + # Always add parentheses for complex subexpressions in AND/OR + if sub_stmt[0] not in ("telling-truth", "lying"): + sub_content = f"({sub_content})" + formatted_substmts.append(sub_content) + connector = f" {statement[0]} " + return connector.join(formatted_substmts) + + # Handle implication if statement[0] == "->": - return ( - "If " - + self._format_knight_knave(names, knight_knave, statement[1]) - + " then " - + self._format_knight_knave(names, knight_knave, statement[2]) - ) + antecedent = self._format_statement(names, knight_knave, statement[1], depth + 1) + consequent = self._format_statement(names, knight_knave, statement[2], depth + 1) + + # Always add parentheses for complex expressions in implications + if statement[1][0] not in ("telling-truth", "lying"): + antecedent = f"({antecedent})" + if statement[2][0] not in ("telling-truth", "lying"): + consequent = f"({consequent})" + + return f"if {antecedent} then {consequent}" + + # Handle biconditional if statement[0] == "<=>": - return ( - self._format_knight_knave(names, knight_knave, statement[1]) - + " if and only if " - + self._format_knight_knave(names, knight_knave, statement[2]) - ) - return self._format_knight_knave(names, knight_knave, statement) + left = self._format_statement(names, knight_knave, statement[1], depth + 1) + right = self._format_statement(names, knight_knave, statement[2], depth + 1) + + # Always add parentheses for complex expressions in biconditionals + if statement[1][0] not in ("telling-truth", "lying"): + left = f"({left})" + if statement[2][0] not in ("telling-truth", "lying"): + right = f"({right})" + + return f"{left} if and only if {right}" + + # This should not happen with well-formed statements + raise ValueError(f"Unknown statement type: {statement[0]}") def _format_knight_knave(self, names, knight_knave, statement, negation=False): assert statement[0] in ("telling-truth", "lying") @@ -404,8 +446,6 @@ def __generate_problem(self, rng: Random) -> dict[str, Any]: problems = sampler.sample_valid_problems(1, skip_no_solution=True, skip_multiple_solutions=True) problem = problems[0] - # Format the problem using the original KKProblemFormatter logic - # Format the problem formatter = KKProblemFormatter(rand_seed=rng.randint(0, 2**32), problem=problem) formatted = formatter.format_problem() diff --git a/tests/test_knights_knaves.py b/tests/test_knights_knaves.py index fe539c0a..0c595c0b 100644 --- a/tests/test_knights_knaves.py +++ b/tests/test_knights_knaves.py @@ -45,6 +45,7 @@ def test_items(): assert "question" in item assert "answer" in item assert "metadata" in item + assert "solution" in item["metadata"] def test_solution(): @@ -185,3 +186,52 @@ def test_satisfiability(): assert not KnightsKnavesDataset.test_satisfiability( ("<=>", ("telling-truth", 0), ("telling-truth", 1)), (True, False) ) + + +def test_depth_constraint(): + config = KnightsKnavesConfig( + n_people=2, + depth_constraint=4, + width_constraint=2, + size=5, + seed=42, + ) + dataset = KnightsKnavesDataset(config) + assert len(dataset) == 5 + for i in range(len(dataset)): + # make sure there's a unique solution + assert len(dataset[i]["metadata"]["solution"]) == len(dataset[i]["metadata"]["names"]) + + +def test_depth_constraint_specific_problem(): + test_statements = ( + ( + "or", + ("not", ("and", ("telling-truth", 0), ("telling-truth", 1), ("lying", 1))), + ( + "and", + ("not", ("telling-truth", 0)), + ("->", ("telling-truth", 0), ("lying", 1)), + ("<=>", ("telling-truth", 1), ("lying", 2)), + ), + ( + "and", + ("or", ("lying", 2), ("lying", 1), ("telling-truth", 0)), + ("or", ("telling-truth", 0), ("lying", 2)), + ("or", ("telling-truth", 2), ("telling-truth", 0)), + ), + ), + ( + "not", + ( + "or", + ("and", ("telling-truth", 1), ("telling-truth", 2), ("telling-truth", 0)), + ("or", ("telling-truth", 2), ("telling-truth", 1), ("telling-truth", 0)), + ), + ), + ("not", ("telling-truth", 0)), + ) + + solutions = KnightsKnavesDataset.find_solution(test_statements) + assert len(solutions) == 1, "Should have exactly one solution" + assert solutions[0] == (True, False, False)